diff --git a/brainpy/__init__.py b/brainpy/__init__.py index e26a4cc7b..0d534106f 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -104,7 +104,7 @@ # Part: Others # # ---------------- # -from brainpy._src.visualization import (visualize as visualize) +import brainpy._src.visualization as visualize # Part: Deprecations # # -------------------- # diff --git a/brainpy/_src/_delay.py b/brainpy/_src/_delay.py index 0f12796f2..c37031454 100644 --- a/brainpy/_src/_delay.py +++ b/brainpy/_src/_delay.py @@ -11,292 +11,291 @@ from brainpy import check from brainpy import math as bm +from brainpy._src.context import share from brainpy._src.dynsys import DynamicalSystem from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE -from brainpy._src.context import share - __all__ = [ - 'Delay', + 'Delay', ] class Delay(DynamicalSystem): - """Delay variable which has a fixed delay length. - - The data in this delay variable is arranged as:: - - delay = 0 [ data - delay = 1 data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - - Parameters:: - - latest: Variable - The initial delay data. - length: int - The delay data length. - before_t0: Any - The delay data. It can be a Python number, like float, int, boolean values. - It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: - - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - method: str - The method used for updating delay. 
- - """ - - latest: bm.Variable - data: Optional[bm.Variable] - length: int - - def __init__( - self, - latest: bm.Variable, - length: int = 0, - before_t0: Optional[Union[float, int, bool, bm.Array, jax.Array, Callable]] = None, - entries: Optional[Dict] = None, - name: Optional[str] = None, - method: str = ROTATE_UPDATE, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - if method is None: - if self.mode.is_a(bm.NonBatchingMode): - method = ROTATE_UPDATE - elif self.mode.is_a(bm.TrainingMode): - method = CONCAT_UPDATE - else: - method = ROTATE_UPDATE - assert method in [ROTATE_UPDATE, CONCAT_UPDATE] - self.method = method - - # target - if not isinstance(latest, bm.Variable): - raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(latest)}') - self.latest = latest - - # delay length - assert isinstance(length, int) - self.length = length - - # delay data - if before_t0 is not None: - assert isinstance(before_t0, (int, float, bool, bm.BaseArray, jax.Array, Callable)) - self._before_t0 = before_t0 - if length > 0: - self._init_data(length) - else: - self.data = None - - # other info - self._access_to_step = dict() - for entry, value in entries.items(): - self.register_entry(entry, value) - - def register_entry( - self, - entry: str, - delay_time: Optional[Union[float, bm.Array, Callable]] = None, - delay_step: Optional[Union[int, bm.Array, Callable]] = None, - ) -> 'Delay': - """Register an entry to access the data. - - Args: - entry (str): The entry to access the delay data. - delay_step: The delay step of the entry (must be an integer, denoting the delay step). - delay_time: The delay time of the entry (can be a float). - - Returns: - Return the self. - """ - if entry in self._access_to_step: - raise KeyError(f'Entry {entry} has been registered.') - - if delay_time is not None: - if delay_step is not None: - raise ValueError('Provide either "delay_time" or "delay_step". 
Both you have given both.') - if callable(delay_time): - delay_time = bm.as_jax(delay_time(self.delay_target_shape)) - delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) - elif isinstance(delay_time, float): - delay_step = int(delay_time / bm.get_dt()) - else: - delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) - - # delay steps - if delay_step is None: - delay_type = 'none' - elif isinstance(delay_step, int): - delay_type = 'homo' - elif isinstance(delay_step, (bm.BaseArray, jax.Array, np.ndarray)): - if delay_step.size == 1 and delay_step.ndim == 0: - delay_type = 'homo' - else: - delay_type = 'heter' - delay_step = delay_step - elif callable(delay_step): - delay_step = delay_step(self.delay_target_shape) - delay_type = 'heter' - else: - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' - f'integer, array of integers, callable function, brainpy.init.Initializer.') - if delay_type == 'heter': - if delay_step.dtype not in [jnp.int32, jnp.int64]: - raise ValueError('Only support delay steps of int32, int64. If your ' - 'provide delay time length, please divide the "dt" ' - 'then provide us the number of delay steps.') - if self.delay_target_shape[0] != delay_step.shape[0]: - raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') - if delay_type == 'heter': - max_delay_step = int(max(delay_step)) - elif delay_type == 'homo': - max_delay_step = delay_step - else: - max_delay_step = None - - # delay variable - if max_delay_step is not None: - if self.length < max_delay_step: - self._init_data(max_delay_step) - self.length = max_delay_step - self._access_to_step[entry] = delay_step - return self - - def at(self, entry: str, *indices) -> bm.Array: - """Get the data at the given entry. - - Args: - entry (str): The entry to access the data. - *indices: - - Returns: - The data. 
- """ - assert isinstance(entry, str) - if entry not in self._access_to_step: - raise KeyError(f'Does not find delay entry "{entry}".') - delay_step = self._access_to_step[entry] - if delay_step is None: - return self.latest.value - else: - if self.data is None: - return self.latest.value - else: - if isinstance(delay_step, slice): - return self.retrieve(delay_step, *indices) - elif np.ndim(delay_step) == 0: - return self.retrieve(delay_step, *indices) - else: - if len(indices) == 0 and len(delay_step) == self.latest.shape[0]: - indices = (jnp.arange(delay_step.size),) - return self.retrieve(delay_step, *indices) - - @property - def delay_target_shape(self): - """The data shape of the delay target.""" - return self.latest.shape + """Delay variable which has a fixed delay length. - def __repr__(self): - name = self.__class__.__name__ - return (f'{name}(num_delay_step={self.length}, ' - f'delay_target_shape={self.delay_target_shape}, ' - f'update_method={self.method})') + The data in this delay variable is arranged as:: - def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.length}. ' - f'But we got {delay_len}') - - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] Parameters:: - delay_step: int, ArrayType - The delay length used to retrieve the data. 
- """ - assert delay_step is not None - if check.is_checking(): - check.jit_error(jnp.any(delay_step > self.length), self._check_delay, delay_step) - - if self.method == ROTATE_UPDATE: - i = share.load('i') - delay_idx = (i + delay_step) % (self.length + 1) - delay_idx = stop_gradient(delay_idx) - - elif self.method == CONCAT_UPDATE: - delay_idx = delay_step - - else: - raise ValueError(f'Unknown updating method "{self.method}"') - - # the delay index - if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') - indices = (delay_idx,) + tuple(indices) - - # the delay data - return self.data[indices] - - def update( - self, - latest_value: Optional[Union[bm.Array, jax.Array]] = None - ) -> None: - """Update delay variable with the new data. + latest: Variable + The initial delay data. + length: int + The delay data length. + before_t0: Any + The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + method: str + The method used for updating delay. 
+ """ - if self.data is not None: - # get the latest target value - if latest_value is None: - latest_value = self.latest.value - - # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: - i = share.load('i') - idx = bm.as_jax((i - 1) % (self.length + 1)) - self.data[idx] = latest_value - - # update the delay data at the first position - elif self.method == CONCAT_UPDATE: - if self.length >= 2: - self.data.value = bm.vstack([latest_value, self.data[1:]]) + + latest: bm.Variable + data: Optional[bm.Variable] + length: int + + def __init__( + self, + latest: bm.Variable, + length: int = 0, + before_t0: Optional[Union[float, int, bool, bm.Array, jax.Array, Callable]] = None, + entries: Optional[Dict] = None, + name: Optional[str] = None, + method: str = ROTATE_UPDATE, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + if method is None: + if self.mode.is_a(bm.NonBatchingMode): + method = ROTATE_UPDATE + elif self.mode.is_a(bm.TrainingMode): + method = CONCAT_UPDATE + else: + method = ROTATE_UPDATE + assert method in [ROTATE_UPDATE, CONCAT_UPDATE] + self.method = method + + # target + if not isinstance(latest, bm.Variable): + raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(latest)}') + self.latest = latest + + # delay length + assert isinstance(length, int) + self.length = length + + # delay data + if before_t0 is not None: + assert isinstance(before_t0, (int, float, bool, bm.BaseArray, jax.Array, Callable)) + self._before_t0 = before_t0 + if length > 0: + self._init_data(length) + else: + self.data = None + + # other info + self._access_to_step = dict() + for entry, value in entries.items(): + self.register_entry(entry, value) + + def register_entry( + self, + entry: str, + delay_time: Optional[Union[float, bm.Array, Callable]] = None, + delay_step: Optional[Union[int, bm.Array, Callable]] = None, + ) -> 'Delay': + """Register an entry to access the data. 
+ + Args: + entry (str): The entry to access the delay data. + delay_step: The delay step of the entry (must be an integer, denoting the delay step). + delay_time: The delay time of the entry (can be a float). + + Returns: + Return the self. + """ + if entry in self._access_to_step: + raise KeyError(f'Entry {entry} has been registered.') + + if delay_time is not None: + if delay_step is not None: + raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') + if callable(delay_time): + delay_time = bm.as_jax(delay_time(self.delay_target_shape)) + delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) + elif isinstance(delay_time, float): + delay_step = int(delay_time / bm.get_dt()) + else: + delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) + + # delay steps + if delay_step is None: + delay_type = 'none' + elif isinstance(delay_step, int): + delay_type = 'homo' + elif isinstance(delay_step, (bm.BaseArray, jax.Array, np.ndarray)): + if delay_step.size == 1 and delay_step.ndim == 0: + delay_type = 'homo' + else: + delay_type = 'heter' + delay_step = delay_step + elif callable(delay_step): + delay_step = delay_step(self.delay_target_shape) + delay_type = 'heter' + else: + raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' + f'integer, array of integers, callable function, brainpy.init.Initializer.') + if delay_type == 'heter': + if delay_step.dtype not in [jnp.int32, jnp.int64]: + raise ValueError('Only support delay steps of int32, int64. 
If your ' + 'provide delay time length, please divide the "dt" ' + 'then provide us the number of delay steps.') + if self.delay_target_shape[0] != delay_step.shape[0]: + raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') + if delay_type == 'heter': + max_delay_step = int(max(delay_step)) + elif delay_type == 'homo': + max_delay_step = delay_step else: - self.data[0] = latest_value + max_delay_step = None + + # delay variable + if max_delay_step is not None: + if self.length < max_delay_step: + self._init_data(max_delay_step) + self.length = max_delay_step + self._access_to_step[entry] = delay_step + return self + + def at(self, entry: str, *indices) -> bm.Array: + """Get the data at the given entry. + + Args: + entry (str): The entry to access the data. + *indices: + + Returns: + The data. + """ + assert isinstance(entry, str) + if entry not in self._access_to_step: + raise KeyError(f'Does not find delay entry "{entry}".') + delay_step = self._access_to_step[entry] + if delay_step is None: + return self.latest.value + else: + if self.data is None: + return self.latest.value + else: + if isinstance(delay_step, slice): + return self.retrieve(delay_step, *indices) + elif np.ndim(delay_step) == 0: + return self.retrieve(delay_step, *indices) + else: + if len(indices) == 0 and len(delay_step) == self.latest.shape[0]: + indices = (jnp.arange(delay_step.size),) + return self.retrieve(delay_step, *indices) + + @property + def delay_target_shape(self): + """The data shape of the delay target.""" + return self.latest.shape + + def __repr__(self): + name = self.__class__.__name__ + return (f'{name}(num_delay_step={self.length}, ' + f'delay_target_shape={self.delay_target_shape}, ' + f'update_method={self.method})') + + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.length}. 
' + f'But we got {delay_len}') + + def retrieve(self, delay_step, *indices): + """Retrieve the delay data according to the delay length. + + Parameters:: + + delay_step: int, ArrayType + The delay length used to retrieve the data. + """ + assert delay_step is not None + if check.is_checking(): + check.jit_error(jnp.any(delay_step > self.length), self._check_delay, delay_step) + + if self.method == ROTATE_UPDATE: + i = share.load('i') + delay_idx = (i + delay_step) % (self.length + 1) + delay_idx = stop_gradient(delay_idx) + + elif self.method == CONCAT_UPDATE: + delay_idx = delay_step - def reset_state(self, batch_size: int = None): - """Reset the delay data. - """ - # initialize delay data - if self.data is not None: - self._init_data(self.length, batch_size) - - def _init_data(self, length, batch_size: int = None): - if batch_size is not None: - if self.latest.batch_size != batch_size: - raise ValueError(f'The batch sizes of delay variable and target variable differ ' - f'({self.latest.batch_size} != {batch_size}). ' - 'Please reset the target variable first, because delay data ' - 'depends on the target variable. 
') - - if self.latest.batch_axis is None: - batch_axis = None - else: - batch_axis = self.latest.batch_axis + 1 - self.data = bm.Variable(jnp.zeros((length + 1,) + self.latest.shape, dtype=self.latest.dtype), - batch_axis=batch_axis) - # update delay data - self.data[0] = self.latest.value - if isinstance(self._before_t0, (bm.BaseArray, jax.Array, float, int, bool)): - self.data[1:] = self._before_t0 - elif callable(self._before_t0): - self.data[1:] = self._before_t0((length,) + self.latest.shape, dtype=self.latest.dtype) + else: + raise ValueError(f'Unknown updating method "{self.method}"') + + # the delay index + if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + tuple(indices) + + # the delay data + return self.data[indices] + + def update( + self, + latest_value: Optional[Union[bm.Array, jax.Array]] = None + ) -> None: + """Update delay variable with the new data. + """ + if self.data is not None: + # get the latest target value + if latest_value is None: + latest_value = self.latest.value + + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + i = share.load('i') + idx = bm.as_jax((i - 1) % (self.length + 1)) + self.data[idx] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.length >= 2: + self.data.value = bm.vstack([latest_value, self.data[1:]]) + else: + self.data[0] = latest_value + + def reset_state(self, batch_size: int = None): + """Reset the delay data. + """ + # initialize delay data + if self.data is not None: + self._init_data(self.length, batch_size) + + def _init_data(self, length, batch_size: int = None): + if batch_size is not None: + if self.latest.batch_size != batch_size: + raise ValueError(f'The batch sizes of delay variable and target variable differ ' + f'({self.latest.batch_size} != {batch_size}). 
' + 'Please reset the target variable first, because delay data ' + 'depends on the target variable. ') + + if self.latest.batch_axis is None: + batch_axis = None + else: + batch_axis = self.latest.batch_axis + 1 + self.data = bm.Variable(jnp.zeros((length + 1,) + self.latest.shape, dtype=self.latest.dtype), + batch_axis=batch_axis) + # update delay data + self.data[0] = self.latest.value + if isinstance(self._before_t0, (bm.BaseArray, jax.Array, float, int, bool)): + self.data[1:] = self._before_t0 + elif callable(self._before_t0): + self.data[1:] = self._before_t0((length,) + self.latest.shape, dtype=self.latest.dtype) diff --git a/brainpy/_src/analysis/__init__.py b/brainpy/_src/analysis/__init__.py index 9ea60642a..3aa431fcc 100644 --- a/brainpy/_src/analysis/__init__.py +++ b/brainpy/_src/analysis/__init__.py @@ -14,12 +14,9 @@ Details in the following. """ +from . import constants as C, stability, plotstyle, utils from .base import * - +from .constants import * from .highdim.slow_points import * - -from .lowdim.lowdim_phase_plane import * from .lowdim.lowdim_bifurcation import * - -from .constants import * -from . 
import constants as C, stability, plotstyle, utils +from .lowdim.lowdim_phase_plane import * diff --git a/brainpy/_src/analysis/base.py b/brainpy/_src/analysis/base.py index 188cfbcf2..2272d240c 100644 --- a/brainpy/_src/analysis/base.py +++ b/brainpy/_src/analysis/base.py @@ -2,11 +2,10 @@ __all__ = [ - 'DSAnalyzer' + 'DSAnalyzer' ] class DSAnalyzer(object): - """Base class of analyzers for dynamical systems in BrainPy""" - pass - + """Base class of analyzers for dynamical systems in BrainPy""" + pass diff --git a/brainpy/_src/analysis/constants.py b/brainpy/_src/analysis/constants.py index 16898bac4..de9346c07 100644 --- a/brainpy/_src/analysis/constants.py +++ b/brainpy/_src/analysis/constants.py @@ -2,46 +2,44 @@ __all__ = [ - 'F_vmap_fx', - 'F_vmap_fy', - 'F_vmap_brentq_fx', - 'F_vmap_brentq_fy', - 'F_vmap_fp_aux', - 'F_vmap_fp_opt', - 'F_vmap_dfxdx', - 'F_fx', - 'F_fy', - 'F_fz', - 'F_dfxdx', - 'F_dfxdy', - 'F_dfydx', - 'F_dfydy', - 'F_jacobian', - 'F_vmap_jacobian', - 'F_fixed_point_aux', - 'F_fixed_point_opt', - 'F_x_by_y', - 'F_y_by_x', - 'F_y_convert', - 'F_x_convert', - 'F_int_x', - 'F_int_y', - 'x_by_y', - 'y_by_x', - 'y_by_x_in_fy', - 'y_by_x_in_fx', - 'x_by_y_in_fx', - 'x_by_y_in_fy', - 'F_y_by_x_in_fy', - 'F_x_by_y_in_fy', - 'F_y_by_x_in_fx', - 'F_x_by_y_in_fx', - 'fx_nullcline_points', - 'fy_nullcline_points', + 'F_vmap_fx', + 'F_vmap_fy', + 'F_vmap_brentq_fx', + 'F_vmap_brentq_fy', + 'F_vmap_fp_aux', + 'F_vmap_fp_opt', + 'F_vmap_dfxdx', + 'F_fx', + 'F_fy', + 'F_fz', + 'F_dfxdx', + 'F_dfxdy', + 'F_dfydx', + 'F_dfydy', + 'F_jacobian', + 'F_vmap_jacobian', + 'F_fixed_point_aux', + 'F_fixed_point_opt', + 'F_x_by_y', + 'F_y_by_x', + 'F_y_convert', + 'F_x_convert', + 'F_int_x', + 'F_int_y', + 'x_by_y', + 'y_by_x', + 'y_by_x_in_fy', + 'y_by_x_in_fx', + 'x_by_y_in_fx', + 'x_by_y_in_fy', + 'F_y_by_x_in_fy', + 'F_x_by_y_in_fy', + 'F_y_by_x_in_fx', + 'F_x_by_y_in_fx', + 'fx_nullcline_points', + 'fy_nullcline_points', ] - - CONTINUOUS = 'continuous' DISCRETE = 
'discrete' @@ -83,4 +81,3 @@ fx_nullcline_points = 'fx_nullcline_points' fy_nullcline_points = 'fy_nullcline_points' prefix = '\t' - diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py index 78ae6c5a8..cb1b001b2 100644 --- a/brainpy/_src/analysis/highdim/slow_points.py +++ b/brainpy/_src/analysis/highdim/slow_points.py @@ -6,848 +6,848 @@ import warnings from typing import Callable, Union, Dict, Sequence, Tuple +import jax import jax.numpy as jnp import numpy as np -import jax from jax.scipy.optimize import minimize from jax.tree_util import tree_flatten, tree_map import brainpy._src.math as bm from brainpy import optim, losses from brainpy._src.analysis import utils, base, constants -from brainpy._src.dynsys import DynamicalSystem from brainpy._src.context import share +from brainpy._src.deprecations import _input_deprecate_msg +from brainpy._src.dynsys import DynamicalSystem from brainpy._src.helpers import clear_input from brainpy._src.runners import check_and_format_inputs, _f_ops from brainpy.errors import AnalyzerError, UnsupportedError from brainpy.types import ArrayType -from brainpy._src.deprecations import _input_deprecate_msg - __all__ = [ - 'SlowPointFinder', + 'SlowPointFinder', ] F_OPT_SOLVER = 'function_for_opt_solver' F_GRADIENT_DESCENT = 'function_for_gradient_descent' SUPPORTED_OPT_SOLVERS = { - 'BFGS': lambda f, x0: minimize(f, x0, method='BFGS') + 'BFGS': lambda f, x0: minimize(f, x0, method='BFGS') } class SlowPointFinder(base.DSAnalyzer): - """Find fixed/slow points by numerical optimization. - - This class can help you: - - - optimize to find the closest fixed points / slow points - - exclude any fixed points whose fixed point loss is above threshold - - exclude any non-unique fixed points according to a tolerance - - exclude any far-away "outlier" fixed points - - Parameters:: - - f_cell : callable, function, DynamicalSystem - The target of computing the recurrent units. 
- - f_type : str - The system's type: continuous system or discrete system. - - - 'continuous': continuous derivative function, denotes this is a continuous system, or - - 'discrete': discrete update function, denotes this is a discrete system. - - verbose : bool - Whether output the optimization progress. - - f_loss: callable - The loss function. - - If ``f_type`` is `"discrete"`, the loss function must receive three - arguments, i.e., ``loss(outputs, targets, axis)``. - - If ``f_type`` is `"continuous"`, the loss function must receive two - arguments, i.e., ``loss(outputs, axis)``. - - .. versionadded:: 2.2.0 - t: float - Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. - The time to evaluate the fixed points. Default is 0. - - .. versionadded:: 2.2.0 - dt: float - Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. - The numerical integration step, which can be used when . - The default is given by `brainpy.math.get_dt()`. - - .. versionadded:: 2.2.0 - inputs: sequence, callable - Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. - Same as ``inputs`` in :py:class:`~.DSRunner`. - - .. versionadded:: 2.2.0 - excluded_vars: sequence, dict - Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. - The excluded variables (can be a sequence of `Variable` instances). - These variables will not be included for optimization of fixed points. - - .. versionadded:: 2.2.0 - target_vars: dict - Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. - The target variables (can be a dict of `Variable` instances). - These variables will be included for optimization of fixed points. - The candidate points later provided should have same keys as in ``target_vars``. - - .. versionadded:: 2.2.0 - f_loss_batch : callable, function - Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. - The function to compute the loss. - - .. deprecated:: 2.2.0 - Has been removed. 
Please use ``f_loss`` to set different loss function. - fun_inputs: callable - - .. deprecated:: 2.3.1 - Will be removed since version 2.4.0. - """ - - def __init__( - self, - f_cell: Union[Callable, DynamicalSystem], - f_type: str = None, - f_loss: Callable = None, - verbose: bool = True, - args: Tuple = (), - - # parameters for `f_cell` is DynamicalSystem instance - inputs: Sequence = None, - t: float = None, - dt: float = None, - target_vars: Dict[str, bm.Variable] = None, - excluded_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, - - # deprecated - f_loss_batch: Callable = None, - fun_inputs: Callable = None, - ): - super().__init__() - - # static arguments - if not isinstance(args, tuple): - raise ValueError(f'args must be an instance of tuple, but we got {type(args)}') - self.args = args - - # update function - if target_vars is None: - self.target_vars = bm.ArrayCollector() - else: - if not isinstance(target_vars, dict): - raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}') - self.target_vars = bm.ArrayCollector(target_vars) - excluded_vars = () if excluded_vars is None else excluded_vars - if isinstance(excluded_vars, dict): - excluded_vars = tuple(excluded_vars.values()) - if not isinstance(excluded_vars, (tuple, list)): - raise TypeError(f'"excluded_vars" must be a sequence but we got {type(excluded_vars)}') - for v in excluded_vars: - if not isinstance(v, bm.Variable): - raise TypeError(f'"excluded_vars" must be a sequence of Variable, ' - f'but we got {type(v)}') - self.excluded_vars = {f'_exclude_v{i}': v for i, v in enumerate(excluded_vars)} - if len(self.target_vars) > 0 and len(self.excluded_vars) > 0: - raise ValueError('"target_vars" and "excluded_vars" cannot be provided simultaneously.') - self.target = f_cell - - if isinstance(f_cell, DynamicalSystem): - # included variables - all_vars = f_cell.vars(method='relative', level=-1, include_self=True).unique() - - # exclude variables - if 
len(self.target_vars) > 0: - _all_ids = [id(v) for v in self.target_vars.values()] - for k, v in all_vars.items(): - if id(v) not in _all_ids: - self.excluded_vars[k] = v - else: - self.target_vars = all_vars - if len(excluded_vars): - excluded_vars = [id(v) for v in excluded_vars] - for key, val in tuple(self.target_vars.items()): - if id(val) in excluded_vars: - self.target_vars.pop(key) - - # input function - if callable(inputs): - self._inputs = inputs - else: - if inputs is None: - self._inputs = None - else: - self._inputs = check_and_format_inputs(host=self.target, inputs=inputs) - - # check included variables - for var in self.target_vars.values(): - if var.batch_axis is not None: - if var.shape[var.batch_axis] != 1: - raise ValueError(f'Batched variables should has only one batch. ' - f'But we got {var.shape[var.batch_axis]}. Maybe ' - f'you need to call ".reset_state(batch_size=1)" ' - f'for your system.') - - # update function - self.f_cell = self._generate_ds_cell_function(self.target, t, dt) - - # check function type - if f_type is not None: - if f_type != constants.DISCRETE: - raise ValueError(f'"f_type" must be "{constants.DISCRETE}" when "f_cell" ' - f'is instance of {DynamicalSystem.__name__}') - f_type = constants.DISCRETE - - # original data - self.target_data = {k: v.value for k, v in self.target_vars.items()} - self.excluded_data = {k: v.value for k, v in self.excluded_vars.items()} - - elif callable(f_cell): - if len(self.args) > 0: - self.f_cell = lambda x: f_cell(x, *self.args) - else: - self.f_cell = f_cell - if inputs is not None: - raise UnsupportedError('Do not support "inputs" when "f_cell" is not instance of ' - f'{DynamicalSystem.__name__}') - self._inputs = inputs - if t is not None: - raise UnsupportedError('Do not support "t" when "f_cell" is not instance of ' - f'{DynamicalSystem.__name__}') - if dt is not None: - raise UnsupportedError('Do not support "dt" when "f_cell" is not instance of ' - f'{DynamicalSystem.__name__}') - if 
target_vars is not None: - raise UnsupportedError('Do not support "target_vars" when "f_cell" is not instance of ' - f'{DynamicalSystem.__name__}') - if len(excluded_vars) > 0: - raise UnsupportedError('Do not support "excluded_vars" when "f_cell" is not instance of ' - f'{DynamicalSystem.__name__}') - else: - raise ValueError(f'Unknown type of "f_type": {type(f_cell)}') - if f_type not in [constants.DISCRETE, constants.CONTINUOUS]: - raise AnalyzerError(f'Only support "{constants.CONTINUOUS}" (continuous derivative function) or ' - f'"{constants.DISCRETE}" (discrete update function), not {f_type}.') - self.verbose = verbose - self.f_type = f_type - - # loss functon - if f_loss_batch is not None: - raise UnsupportedError('"f_loss_batch" is no longer supported, please ' - 'use "f_loss" instead.') - if fun_inputs is not None: - raise UnsupportedError('"fun_inputs" is no longer supported.') - if f_loss is None: - f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square - self.f_loss = f_loss - - # essential variables - self._losses = None - self._fixed_points = None - self._selected_ids = None - self._opt_losses = None - - # functions - self._opt_functions = dict() - - @property - def opt_losses(self) -> np.ndarray: - """The optimization losses.""" - return np.asarray(self._opt_losses) - - @opt_losses.setter - def opt_losses(self, val): - raise UnsupportedError('Do not support set "opt_losses" by users.') - - @property - def fixed_points(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: - """The final fixed points found.""" - return tree_map(lambda a: np.asarray(a), self._fixed_points) - - @fixed_points.setter - def fixed_points(self, val): - raise UnsupportedError('Do not support set "fixed_points" by users.') - - @property - def num_fps(self) -> int: - if isinstance(self._fixed_points, dict): - return tuple(self._fixed_points.values())[0].shape[0] - else: - return self._fixed_points.shape[0] - - @property - def losses(self) -> 
np.ndarray: - """Losses of fixed points.""" - return np.asarray(self._losses) - - @losses.setter - def losses(self, val): - raise UnsupportedError('Do not support set "losses" by users.') - - @property - def selected_ids(self) -> np.ndarray: - """The selected ids of candidate points.""" - return np.asarray(self._selected_ids) - - @selected_ids.setter - def selected_ids(self, val): - raise UnsupportedError('Do not support set "selected_ids" by users.') - - def find_fps_with_gd_method( - self, - candidates: Union[ArrayType, Dict[str, ArrayType]], - tolerance: Union[float, Dict[str, float]] = 1e-5, - num_batch: int = 100, - num_opt: int = 10000, - optimizer: optim.Optimizer = None, - ): - """Optimize fixed points with gradient descent methods. + """Find fixed/slow points by numerical optimization. + + This class can help you: + + - optimize to find the closest fixed points / slow points + - exclude any fixed points whose fixed point loss is above threshold + - exclude any non-unique fixed points according to a tolerance + - exclude any far-away "outlier" fixed points Parameters:: - candidates : ArrayType, dict - The array with the shape of (batch size, state dim) of hidden states - of RNN to start training for fixed points. + f_cell : callable, function, DynamicalSystem + The target of computing the recurrent units. + + f_type : str + The system's type: continuous system or discrete system. + + - 'continuous': continuous derivative function, denotes this is a continuous system, or + - 'discrete': discrete update function, denotes this is a discrete system. + + verbose : bool + Whether output the optimization progress. + + f_loss: callable + The loss function. + - If ``f_type`` is `"discrete"`, the loss function must receive three + arguments, i.e., ``loss(outputs, targets, axis)``. + - If ``f_type`` is `"continuous"`, the loss function must receive two + arguments, i.e., ``loss(outputs, axis)``. + + .. 
versionadded:: 2.2.0 + t: float + Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. + The time to evaluate the fixed points. Default is 0. + + .. versionadded:: 2.2.0 + dt: float + Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. + The numerical integration step, which can be used when . + The default is given by `brainpy.math.get_dt()`. + + .. versionadded:: 2.2.0 + inputs: sequence, callable + Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. + Same as ``inputs`` in :py:class:`~.DSRunner`. + + .. versionadded:: 2.2.0 + excluded_vars: sequence, dict + Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. + The excluded variables (can be a sequence of `Variable` instances). + These variables will not be included for optimization of fixed points. + + .. versionadded:: 2.2.0 + target_vars: dict + Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. + The target variables (can be a dict of `Variable` instances). + These variables will be included for optimization of fixed points. + The candidate points later provided should have same keys as in ``target_vars``. + + .. versionadded:: 2.2.0 + f_loss_batch : callable, function + Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. + The function to compute the loss. + + .. deprecated:: 2.2.0 + Has been removed. Please use ``f_loss`` to set different loss function. + fun_inputs: callable + + .. deprecated:: 2.3.1 + Will be removed since version 2.4.0. 
+ """ - tolerance: float - The loss threshold during optimization + def __init__( + self, + f_cell: Union[Callable, DynamicalSystem], + f_type: str = None, + f_loss: Callable = None, + verbose: bool = True, + args: Tuple = (), + + # parameters for `f_cell` is DynamicalSystem instance + inputs: Sequence = None, + t: float = None, + dt: float = None, + target_vars: Dict[str, bm.Variable] = None, + excluded_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, + + # deprecated + f_loss_batch: Callable = None, + fun_inputs: Callable = None, + ): + super().__init__() + + # static arguments + if not isinstance(args, tuple): + raise ValueError(f'args must be an instance of tuple, but we got {type(args)}') + self.args = args + + # update function + if target_vars is None: + self.target_vars = bm.ArrayCollector() + else: + if not isinstance(target_vars, dict): + raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}') + self.target_vars = bm.ArrayCollector(target_vars) + excluded_vars = () if excluded_vars is None else excluded_vars + if isinstance(excluded_vars, dict): + excluded_vars = tuple(excluded_vars.values()) + if not isinstance(excluded_vars, (tuple, list)): + raise TypeError(f'"excluded_vars" must be a sequence but we got {type(excluded_vars)}') + for v in excluded_vars: + if not isinstance(v, bm.Variable): + raise TypeError(f'"excluded_vars" must be a sequence of Variable, ' + f'but we got {type(v)}') + self.excluded_vars = {f'_exclude_v{i}': v for i, v in enumerate(excluded_vars)} + if len(self.target_vars) > 0 and len(self.excluded_vars) > 0: + raise ValueError('"target_vars" and "excluded_vars" cannot be provided simultaneously.') + self.target = f_cell + + if isinstance(f_cell, DynamicalSystem): + # included variables + all_vars = f_cell.vars(method='relative', level=-1, include_self=True).unique() + + # exclude variables + if len(self.target_vars) > 0: + _all_ids = [id(v) for v in self.target_vars.values()] + for k, v in 
all_vars.items(): + if id(v) not in _all_ids: + self.excluded_vars[k] = v + else: + self.target_vars = all_vars + if len(excluded_vars): + excluded_vars = [id(v) for v in excluded_vars] + for key, val in tuple(self.target_vars.items()): + if id(val) in excluded_vars: + self.target_vars.pop(key) + + # input function + if callable(inputs): + self._inputs = inputs + else: + if inputs is None: + self._inputs = None + else: + self._inputs = check_and_format_inputs(host=self.target, inputs=inputs) + + # check included variables + for var in self.target_vars.values(): + if var.batch_axis is not None: + if var.shape[var.batch_axis] != 1: + raise ValueError(f'Batched variables should has only one batch. ' + f'But we got {var.shape[var.batch_axis]}. Maybe ' + f'you need to call ".reset_state(batch_size=1)" ' + f'for your system.') + + # update function + self.f_cell = self._generate_ds_cell_function(self.target, t, dt) + + # check function type + if f_type is not None: + if f_type != constants.DISCRETE: + raise ValueError(f'"f_type" must be "{constants.DISCRETE}" when "f_cell" ' + f'is instance of {DynamicalSystem.__name__}') + f_type = constants.DISCRETE + + # original data + self.target_data = {k: v.value for k, v in self.target_vars.items()} + self.excluded_data = {k: v.value for k, v in self.excluded_vars.items()} + + elif callable(f_cell): + if len(self.args) > 0: + self.f_cell = lambda x: f_cell(x, *self.args) + else: + self.f_cell = f_cell + if inputs is not None: + raise UnsupportedError('Do not support "inputs" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + self._inputs = inputs + if t is not None: + raise UnsupportedError('Do not support "t" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + if dt is not None: + raise UnsupportedError('Do not support "dt" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + if target_vars is not None: + raise UnsupportedError('Do not support "target_vars" when "f_cell" is 
not instance of ' + f'{DynamicalSystem.__name__}') + if len(excluded_vars) > 0: + raise UnsupportedError('Do not support "excluded_vars" when "f_cell" is not instance of ' + f'{DynamicalSystem.__name__}') + else: + raise ValueError(f'Unknown type of "f_type": {type(f_cell)}') + if f_type not in [constants.DISCRETE, constants.CONTINUOUS]: + raise AnalyzerError(f'Only support "{constants.CONTINUOUS}" (continuous derivative function) or ' + f'"{constants.DISCRETE}" (discrete update function), not {f_type}.') + self.verbose = verbose + self.f_type = f_type + + # loss functon + if f_loss_batch is not None: + raise UnsupportedError('"f_loss_batch" is no longer supported, please ' + 'use "f_loss" instead.') + if fun_inputs is not None: + raise UnsupportedError('"fun_inputs" is no longer supported.') + if f_loss is None: + f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square + self.f_loss = f_loss + + # essential variables + self._losses = None + self._fixed_points = None + self._selected_ids = None + self._opt_losses = None + + # functions + self._opt_functions = dict() + + @property + def opt_losses(self) -> np.ndarray: + """The optimization losses.""" + return np.asarray(self._opt_losses) + + @opt_losses.setter + def opt_losses(self, val): + raise UnsupportedError('Do not support set "opt_losses" by users.') + + @property + def fixed_points(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """The final fixed points found.""" + return tree_map(lambda a: np.asarray(a), self._fixed_points) + + @fixed_points.setter + def fixed_points(self, val): + raise UnsupportedError('Do not support set "fixed_points" by users.') + + @property + def num_fps(self) -> int: + if isinstance(self._fixed_points, dict): + return tuple(self._fixed_points.values())[0].shape[0] + else: + return self._fixed_points.shape[0] + + @property + def losses(self) -> np.ndarray: + """Losses of fixed points.""" + return np.asarray(self._losses) + + @losses.setter + 
def losses(self, val): + raise UnsupportedError('Do not support set "losses" by users.') + + @property + def selected_ids(self) -> np.ndarray: + """The selected ids of candidate points.""" + return np.asarray(self._selected_ids) + + @selected_ids.setter + def selected_ids(self, val): + raise UnsupportedError('Do not support set "selected_ids" by users.') + + def find_fps_with_gd_method( + self, + candidates: Union[ArrayType, Dict[str, ArrayType]], + tolerance: Union[float, Dict[str, float]] = 1e-5, + num_batch: int = 100, + num_opt: int = 10000, + optimizer: optim.Optimizer = None, + ): + """Optimize fixed points with gradient descent methods. + + Parameters:: + + candidates : ArrayType, dict + The array with the shape of (batch size, state dim) of hidden states + of RNN to start training for fixed points. + + tolerance: float + The loss threshold during optimization + + num_opt : int + The maximum number of optimization. + + num_batch : int + Print training information during optimization every so often. + + optimizer: optim.Optimizer + The optimizer instance. + + .. 
versionadded:: 2.1.2 + """ + # optimization settings + if optimizer is None: + optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999), + beta1=0.9, beta2=0.999, eps=1e-8) + else: + if not isinstance(optimizer, optim.Optimizer): + raise ValueError(f'Must be an instance of {optim.Optimizer.__name__}, ' + f'while we got {type(optimizer)}') + + # set up optimization + num_candidate = self._check_candidates(candidates) + if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)): + raise ValueError('Candidates must be instance of ArrayType or dict of ArrayType.') + fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.BaseArray)) + f_eval_loss = self._get_f_eval_loss() + + def f_loss(): + return f_eval_loss(tree_map(lambda a: bm.as_jax(a), + fixed_points, + is_leaf=lambda x: isinstance(x, bm.BaseArray))).mean() - num_opt : int - The maximum number of optimization. + grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True) + optimizer.register_train_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points}) - num_batch : int - Print training information during optimization every so often. + def train(idx): + gradients, loss = grad_f() + optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients}) + optimizer.lr.step_epoch() + return loss - optimizer: optim.Optimizer - The optimizer instance. + def batch_train(start_i, n_batch): + return bm.for_loop(train, bm.arange(start_i, start_i + n_batch)) - .. 
versionadded:: 2.1.2 - """ - # optimization settings - if optimizer is None: - optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999), - beta1=0.9, beta2=0.999, eps=1e-8) - else: - if not isinstance(optimizer, optim.Optimizer): - raise ValueError(f'Must be an instance of {optim.Optimizer.__name__}, ' - f'while we got {type(optimizer)}') - - # set up optimization - num_candidate = self._check_candidates(candidates) - if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)): - raise ValueError('Candidates must be instance of ArrayType or dict of ArrayType.') - fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.BaseArray)) - f_eval_loss = self._get_f_eval_loss() - - def f_loss(): - return f_eval_loss(tree_map(lambda a: bm.as_jax(a), - fixed_points, - is_leaf=lambda x: isinstance(x, bm.BaseArray))).mean() - - grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True) - optimizer.register_train_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points}) - - def train(idx): - gradients, loss = grad_f() - optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients}) - optimizer.lr.step_epoch() - return loss - - def batch_train(start_i, n_batch): - return bm.for_loop(train, bm.arange(start_i, start_i + n_batch)) - - # Run the optimization - if self.verbose: - print(f"Optimizing with {optimizer} to find fixed points:") - opt_losses = [] - do_stop = False - num_opt_loops = int(num_opt / num_batch) - for oidx in range(num_opt_loops): - if do_stop: - break - batch_idx_start = oidx * num_batch - start_time = time.time() - train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch) - batch_time = time.time() - start_time - opt_losses.append(train_losses) - - if self.verbose: - print(f" " - f"Batches {batch_idx_start + 1}-{batch_idx_start + num_batch} " - f"in {batch_time:0.2f} sec, Training loss 
{train_losses[-1]:0.10f}") - - if train_losses[-1] < tolerance: - do_stop = True + # Run the optimization if self.verbose: - print(f' ' - f'Stop optimization as mean training loss {train_losses[-1]:0.10f} ' - f'is below tolerance {tolerance:0.10f}.') + print(f"Optimizing with {optimizer} to find fixed points:") + opt_losses = [] + do_stop = False + num_opt_loops = int(num_opt / num_batch) + for oidx in range(num_opt_loops): + if do_stop: + break + batch_idx_start = oidx * num_batch + start_time = time.time() + train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch) + batch_time = time.time() - start_time + opt_losses.append(train_losses) + + if self.verbose: + print(f" " + f"Batches {batch_idx_start + 1}-{batch_idx_start + num_batch} " + f"in {batch_time:0.2f} sec, Training loss {train_losses[-1]:0.10f}") + + if train_losses[-1] < tolerance: + do_stop = True + if self.verbose: + print(f' ' + f'Stop optimization as mean training loss {train_losses[-1]:0.10f} ' + f'is below tolerance {tolerance:0.10f}.') + + self._opt_losses = jnp.concatenate(opt_losses) + self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a), + fixed_points, + is_leaf=lambda x: isinstance(x, bm.BaseArray))) + self._fixed_points = tree_map(lambda a: bm.as_jax(a), + fixed_points, + is_leaf=lambda x: isinstance(x, bm.BaseArray)) + self._selected_ids = jnp.arange(num_candidate) + + if isinstance(self.target, DynamicalSystem): + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.target_vars.items(): + v.value = self.target_data[k] + + def find_fps_with_opt_solver( + self, + candidates: Union[ArrayType, Dict[str, ArrayType]], + opt_solver: str = 'BFGS' + ): + """Optimize fixed points with nonlinear optimization solvers. + + Parameters:: + + candidates: ArrayType, dict + The candidate (initial) fixed points. + opt_solver: str + The solver of the optimization. 
+ """ + # optimization function + num_candidate = self._check_candidates(candidates) + for var in self.target_vars.values(): + if jnp.ndim(var) != 1: + raise ValueError('Cannot use opt solver.') + if self._opt_functions.get(F_OPT_SOLVER, None) is None: + self._opt_functions[F_OPT_SOLVER] = self._get_f_for_opt_solver(candidates, + SUPPORTED_OPT_SOLVERS[opt_solver]) + f_opt = self._opt_functions[F_OPT_SOLVER] - self._opt_losses = jnp.concatenate(opt_losses) - self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a), - fixed_points, - is_leaf=lambda x: isinstance(x, bm.BaseArray))) - self._fixed_points = tree_map(lambda a: bm.as_jax(a), - fixed_points, - is_leaf=lambda x: isinstance(x, bm.BaseArray)) - self._selected_ids = jnp.arange(num_candidate) - - if isinstance(self.target, DynamicalSystem): - for k, v in self.excluded_vars.items(): - v.value = self.excluded_data[k] - for k, v in self.target_vars.items(): - v.value = self.target_data[k] - - def find_fps_with_opt_solver( - self, - candidates: Union[ArrayType, Dict[str, ArrayType]], - opt_solver: str = 'BFGS' - ): - """Optimize fixed points with nonlinear optimization solvers. 
+ if self.verbose: + print(f"Optimizing with {opt_solver} to find fixed points:") + + # optimizing + res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))) + + # results + valid_ids = jnp.where(res.success)[0] + fixed_points = res.x[valid_ids] + if isinstance(candidates, dict): + indices = [0] + for v in candidates.values(): + indices.append(v.shape[1]) + indices = np.cumsum(indices) + keys = tuple(candidates.keys()) + self._fixed_points = {key: fixed_points[:, indices[i]: indices[i + 1]] + for i, key in enumerate(keys)} + else: + self._fixed_points = fixed_points + self._losses = res.fun[valid_ids] + self._selected_ids = jnp.asarray(valid_ids) + if self.verbose: + print(f' ' + f'Found {len(valid_ids)} fixed points from {num_candidate} initial points.') - Parameters:: + def filter_loss(self, tolerance: float = 1e-5): + """Filter fixed points whose speed larger than a given tolerance. - candidates: ArrayType, dict - The candidate (initial) fixed points. - opt_solver: str - The solver of the optimization. 
- """ - # optimization function - num_candidate = self._check_candidates(candidates) - for var in self.target_vars.values(): - if jnp.ndim(var) != 1: - raise ValueError('Cannot use opt solver.') - if self._opt_functions.get(F_OPT_SOLVER, None) is None: - self._opt_functions[F_OPT_SOLVER] = self._get_f_for_opt_solver(candidates, SUPPORTED_OPT_SOLVERS[opt_solver]) - f_opt = self._opt_functions[F_OPT_SOLVER] - - if self.verbose: - print(f"Optimizing with {opt_solver} to find fixed points:") - - # optimizing - res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))) - - # results - valid_ids = jnp.where(res.success)[0] - fixed_points = res.x[valid_ids] - if isinstance(candidates, dict): - indices = [0] - for v in candidates.values(): - indices.append(v.shape[1]) - indices = np.cumsum(indices) - keys = tuple(candidates.keys()) - self._fixed_points = {key: fixed_points[:, indices[i]: indices[i + 1]] - for i, key in enumerate(keys)} - else: - self._fixed_points = fixed_points - self._losses = res.fun[valid_ids] - self._selected_ids = jnp.asarray(valid_ids) - if self.verbose: - print(f' ' - f'Found {len(valid_ids)} fixed points from {num_candidate} initial points.') - - def filter_loss(self, tolerance: float = 1e-5): - """Filter fixed points whose speed larger than a given tolerance. + Parameters:: - Parameters:: + tolerance: float + Discard fixed points with squared speed larger than this value. 
+ """ + if self.verbose: + print(f"Excluding fixed points with squared speed above " + f"tolerance {tolerance}:") + if isinstance(self._fixed_points, dict): + num_fps = tuple(self._fixed_points.values())[0].shape[0] + else: + num_fps = self._fixed_points.shape[0] + ids = self._losses < tolerance + keep_ids = bm.as_jax(bm.where(ids)[0]) + self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points) + self._losses = self._losses[keep_ids] + self._selected_ids = self._selected_ids[keep_ids] + if self.verbose: + print(f" " + f"Kept {len(keep_ids)}/{num_fps} " + f"fixed points with tolerance under {tolerance}.") - tolerance: float - Discard fixed points with squared speed larger than this value. - """ - if self.verbose: - print(f"Excluding fixed points with squared speed above " - f"tolerance {tolerance}:") - if isinstance(self._fixed_points, dict): - num_fps = tuple(self._fixed_points.values())[0].shape[0] - else: - num_fps = self._fixed_points.shape[0] - ids = self._losses < tolerance - keep_ids = bm.as_jax(bm.where(ids)[0]) - self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points) - self._losses = self._losses[keep_ids] - self._selected_ids = self._selected_ids[keep_ids] - if self.verbose: - print(f" " - f"Kept {len(keep_ids)}/{num_fps} " - f"fixed points with tolerance under {tolerance}.") - - def keep_unique(self, tolerance: float = 2.5e-2): - """Filter unique fixed points by choosing a representative within tolerance. + def keep_unique(self, tolerance: float = 2.5e-2): + """Filter unique fixed points by choosing a representative within tolerance. - Parameters:: + Parameters:: - tolerance: float - Tolerance for determination of identical fixed points. 
- """ - if self.verbose: - print("Excluding non-unique fixed points:") - if isinstance(self._fixed_points, dict): - num_fps = tuple(self._fixed_points.values())[0].shape[0] - else: - num_fps = self._fixed_points.shape[0] - fps, keep_ids = utils.keep_unique(self.fixed_points, tolerance=tolerance) - self._fixed_points = tree_map(lambda a: jnp.asarray(a), fps) - self._losses = self._losses[keep_ids] - self._selected_ids = self._selected_ids[keep_ids] - if self.verbose: - print(f" Kept {keep_ids.shape[0]}/{num_fps} unique fixed points " - f"with uniqueness tolerance {tolerance}.") - - def exclude_outliers(self, tolerance: float = 1e0): - """Exclude points whose closest neighbor is further than threshold. + tolerance: float + Tolerance for determination of identical fixed points. + """ + if self.verbose: + print("Excluding non-unique fixed points:") + if isinstance(self._fixed_points, dict): + num_fps = tuple(self._fixed_points.values())[0].shape[0] + else: + num_fps = self._fixed_points.shape[0] + fps, keep_ids = utils.keep_unique(self.fixed_points, tolerance=tolerance) + self._fixed_points = tree_map(lambda a: jnp.asarray(a), fps) + self._losses = self._losses[keep_ids] + self._selected_ids = self._selected_ids[keep_ids] + if self.verbose: + print(f" Kept {keep_ids.shape[0]}/{num_fps} unique fixed points " + f"with uniqueness tolerance {tolerance}.") - Parameters:: + def exclude_outliers(self, tolerance: float = 1e0): + """Exclude points whose closest neighbor is further than threshold. - tolerance: float - Any point whose closest fixed point is greater than tol is an outlier. - """ - if self.verbose: - print("Excluding outliers:") - if np.isinf(tolerance): - return - if isinstance(self._fixed_points, dict): - num_fps = tuple(self._fixed_points.values())[0].shape[0] - else: - num_fps = self._fixed_points.shape[0] - if num_fps <= 1: - return - - # Compute pairwise distances between all fixed points. 
- distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps)) - - # Find the second smallest element in each column of the pairwise distance matrix. - # This corresponds to the closest neighbor for each fixed point. - closest_neighbor = np.partition(distances, kth=1, axis=0)[1] - - # Return data with outliers removed and indices of kept datapoints. - keep_ids = np.where(closest_neighbor < tolerance)[0] - self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points) - self._selected_ids = self._selected_ids[keep_ids] - self._losses = self._losses[keep_ids] - - if self.verbose: - print(f" " - f"Kept {keep_ids.shape[0]}/{num_fps} fixed points " - f"with within outlier tolerance {tolerance}.") - - def compute_jacobians( - self, - points: Union[ArrayType, Dict[str, ArrayType]], - stack_dict_var: bool = True, - plot: bool = False, - num_col: int = 4, - len_col: int = 3, - len_row: int = 2, - ): - """Compute the Jacobian matrices at the points. + Parameters:: - Parameters:: + tolerance: float + Any point whose closest fixed point is greater than tol is an outlier. + """ + if self.verbose: + print("Excluding outliers:") + if np.isinf(tolerance): + return + if isinstance(self._fixed_points, dict): + num_fps = tuple(self._fixed_points.values())[0].shape[0] + else: + num_fps = self._fixed_points.shape[0] + if num_fps <= 1: + return - points: np.ndarray, bm.ArrayType, jax.ndarray - The fixed points with the shape of (num_point, num_dim). - stack_dict_var: bool - Stack dictionary variables to calculate Jacobian matrix? - plot: bool - Plot the decomposition results of the Jacobian matrix. - num_col: int - The number of the figure column. - len_col: int - The length of each column. - len_row: int - The length of each row. 
- """ - # check data - info = np.asarray([(l.ndim, l.shape[0]) - for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.BaseArray))[0]]) - ndim = np.unique(info[:, 0]) - if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}') - if ndim[0] == 1: - points = tree_map(lambda a: bm.asarray([a]), points) - num_point = 1 - elif ndim[0] == 2: - nsize = np.unique(info[:, 1]) - if len(nsize) != 1: raise ValueError(f'Number of the evaluated points are mis-matched. {nsize}') - num_point = nsize[0] - else: - raise ValueError('Only support points of 1D: (num_feature,) or 2D: (num_point, num_feature)') - if isinstance(points, dict) and stack_dict_var: - points = jnp.hstack(tuple(points.values())) - - # get Jacobian matrix - jacobian = self._get_f_jocabian(stack_dict_var)(points) - - # visualization - if plot: - import matplotlib.pyplot as plt - from brainpy import visualize - jacobian = bm.as_numpy(jacobian) - - num_col = min(num_col, num_point) - num_row = int(math.ceil(num_point / num_col)) - fig, gs = visualize.get_figure(num_row, num_col, len_row, len_col) - for i in range(num_point): - eigval, eigvec = np.linalg.eig(np.asarray(jacobian[i])) - ax = fig.add_subplot(gs[i // num_col, i % num_col]) - ax.scatter(np.real(eigval), np.imag(eigval)) - ax.plot([1, 1] if self.f_type == constants.DISCRETE else [0, 0], [-1, 1], '--') - ax.set_xlabel('Real') - ax.set_ylabel('Imaginary') - ax.set_title(f'Point {i}') - plt.show() - - return jacobian - - @staticmethod - def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False): - """Compute the eigenvalues of the matrices. + # Compute pairwise distances between all fixed points. + distances = np.asarray(utils.euclidean_distance_jax(self.fixed_points, num_fps)) - Parameters:: + # Find the second smallest element in each column of the pairwise distance matrix. + # This corresponds to the closest neighbor for each fixed point. 
+ closest_neighbor = np.partition(distances, kth=1, axis=0)[1] - matrices: np.ndarray, bm.ArrayType, jax.ndarray - A 3D array with the shape of (num_matrices, dim, dim). - sort_by: str - The method of sorting. - do_compute_lefts: bool - Compute the left eigenvectors? Requires a pseudo-inverse call. + # Return data with outliers removed and indices of kept datapoints. + keep_ids = np.where(closest_neighbor < tolerance)[0] + self._fixed_points = tree_map(lambda a: a[keep_ids], self._fixed_points) + self._selected_ids = self._selected_ids[keep_ids] + self._losses = self._losses[keep_ids] - Returns:: + if self.verbose: + print(f" " + f"Kept {keep_ids.shape[0]}/{num_fps} fixed points " + f"with within outlier tolerance {tolerance}.") + + def compute_jacobians( + self, + points: Union[ArrayType, Dict[str, ArrayType]], + stack_dict_var: bool = True, + plot: bool = False, + num_col: int = 4, + len_col: int = 3, + len_row: int = 2, + ): + """Compute the Jacobian matrices at the points. + + Parameters:: + + points: np.ndarray, bm.ArrayType, jax.ndarray + The fixed points with the shape of (num_point, num_dim). + stack_dict_var: bool + Stack dictionary variables to calculate Jacobian matrix? + plot: bool + Plot the decomposition results of the Jacobian matrix. + num_col: int + The number of the figure column. + len_col: int + The length of each column. + len_row: int + The length of each row. + """ + # check data + info = np.asarray([(l.ndim, l.shape[0]) + for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.BaseArray))[0]]) + ndim = np.unique(info[:, 0]) + if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}') + if ndim[0] == 1: + points = tree_map(lambda a: bm.asarray([a]), points) + num_point = 1 + elif ndim[0] == 2: + nsize = np.unique(info[:, 1]) + if len(nsize) != 1: raise ValueError(f'Number of the evaluated points are mis-matched. 
{nsize}') + num_point = nsize[0] + else: + raise ValueError('Only support points of 1D: (num_feature,) or 2D: (num_point, num_feature)') + if isinstance(points, dict) and stack_dict_var: + points = jnp.hstack(tuple(points.values())) + + # get Jacobian matrix + jacobian = self._get_f_jocabian(stack_dict_var)(points) + + # visualization + if plot: + import matplotlib.pyplot as plt + from brainpy import visualize + jacobian = bm.as_numpy(jacobian) + + num_col = min(num_col, num_point) + num_row = int(math.ceil(num_point / num_col)) + fig, gs = visualize.get_figure(num_row, num_col, len_row, len_col) + for i in range(num_point): + eigval, eigvec = np.linalg.eig(np.asarray(jacobian[i])) + ax = fig.add_subplot(gs[i // num_col, i % num_col]) + ax.scatter(np.real(eigval), np.imag(eigval)) + ax.plot([1, 1] if self.f_type == constants.DISCRETE else [0, 0], [-1, 1], '--') + ax.set_xlabel('Real') + ax.set_ylabel('Imaginary') + ax.set_title(f'Point {i}') + plt.show() + + return jacobian + + @staticmethod + def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False): + """Compute the eigenvalues of the matrices. + + Parameters:: + + matrices: np.ndarray, bm.ArrayType, jax.ndarray + A 3D array with the shape of (num_matrices, dim, dim). + sort_by: str + The method of sorting. + do_compute_lefts: bool + Compute the left eigenvectors? Requires a pseudo-inverse call. + + Returns:: + + decompositions : list + A list of dictionaries with sorted eigenvalues components: + (eigenvalues, right eigenvectors, and left eigenvectors). 
+ """ + if sort_by == 'magnitude': + sort_fun = np.abs + elif sort_by == 'real': + sort_fun = np.real + else: + raise ValueError("Not implemented yet.") + matrices = np.asarray(matrices) + + decompositions = [] + for mat in matrices: + eig_values, eig_vectors = np.linalg.eig(mat) + indices = np.flipud(np.argsort(sort_fun(eig_values))) + L = None + if do_compute_lefts: + L = np.linalg.pinv(eig_vectors).T # as columns + L = L[:, indices] + decompositions.append({'eig_values': eig_values[indices], + 'R': eig_vectors[:, indices], + 'L': L}) + return decompositions + + def _step_func_input(self): + if self._inputs is None: + return + elif callable(self._inputs): + try: + ba = inspect.signature(self._inputs).bind(dict()) + self._inputs(share.get_shargs()) + warnings.warn(_input_deprecate_msg, UserWarning) + except TypeError: + self._inputs() + else: + for ops, values in self._inputs['fixed'].items(): + for var, data in values: + _f_ops(ops, var, data) + for ops, values in self._inputs['array'].items(): + if len(values) > 0: + raise UnsupportedError + for ops, values in self._inputs['functional'].items(): + for var, data in values: + _f_ops(ops, var, data(share.get_shargs())) + for ops, values in self._inputs['iterated'].items(): + if len(values) > 0: + raise UnsupportedError + + def _get_f_eval_loss(self, ): + name = 'f_eval_loss' + if name not in self._opt_functions: + self._opt_functions[name] = self._generate_f_eval_loss() + return self._opt_functions[name] + + def _generate_f_eval_loss(self): + # evaluate losses of a batch of inputs + if self.f_type == constants.DISCRETE: + f_eval_loss = lambda h: self.f_loss(h, jax.vmap(self.f_cell)(h), axis=1) + else: + f_eval_loss = lambda h: self.f_loss(jax.vmap(self.f_cell)(h), axis=1) + + if isinstance(self.target, DynamicalSystem): + @jax.jit + def loss_func(h): + r = f_eval_loss(h) + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.target_vars.items(): + v.value = 
self.target_data[k] + return r + + return loss_func + else: + return jax.jit(f_eval_loss) + + def _get_f_for_opt_solver(self, candidates, opt_method): + # loss function + if self.f_type == constants.DISCRETE: + # overall loss function for fixed points optimization + if isinstance(candidates, dict): + keys = tuple(self.target_vars.keys()) + indices = [0] + for v in self.target_vars.values(): + indices.append(v.shape[0]) + indices = np.cumsum(indices) + + def f_loss(h): + h = {key: h[indices[i]: indices[i + 1]] for i, key in enumerate(keys)} + return bm.as_jax(self.f_loss(h, self.f_cell(h))) + else: + def f_loss(h): + return bm.as_jax(self.f_loss(h, self.f_cell(h))) + else: + # overall loss function for fixed points optimization + def f_loss(h): + return self.f_loss(self.f_cell(h)) + + @jax.jit + @jax.vmap + def f_opt(x0): + for k, v in self.target_vars.items(): + v.value = x0[k] if (v.batch_axis is None) else jnp.expand_dims(x0[k], axis=v.batch_axis) + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + if isinstance(x0, dict): + x0 = jnp.concatenate(tuple(x0.values())) + return opt_method(f_loss, x0) + + def call_opt(x): + r = f_opt(x) + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.target_vars.items(): + v.value = self.target_data[k] + return r + + return call_opt if isinstance(self.target, DynamicalSystem) else f_opt + + def _generate_ds_cell_function( + self, target, + t: float = None, + dt: float = None, + ): + if dt is None: dt = bm.get_dt() + if t is None: t = 0. 
+ + def f_cell(h: Dict): + share.save(t=t, i=0, dt=dt) + + # update target variables + for k, v in self.target_vars.items(): + v.value = (bm.asarray(h[k], dtype=v.dtype) + if v.batch_axis is None else + bm.asarray(bm.expand_dims(h[k], axis=v.batch_axis), dtype=v.dtype)) + + # update excluded variables + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + + # add inputs + clear_input(target) + self._step_func_input() + + # call update functions + target(*self.args) + + # get new states + new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis)) + for k, v in self.target_vars.items()} + return new_h + + return f_cell + + def _get_f_jocabian(self, stack=True): + name = f'f_eval_jacobian_stack={stack}' + if name not in self._opt_functions: + self._opt_functions[name] = self._generate_ds_jocabian(stack) + return self._opt_functions[name] + + def _generate_ds_jocabian(self, stack=True): + if stack and isinstance(self.target, DynamicalSystem): + indices = [0] + for var in self.target_vars.values(): + shape = list(var.shape) + if var.batch_axis is not None: + shape.pop(var.batch_axis) + indices.append(np.prod(shape)) + indices = np.cumsum(indices) + + def jacob(x0): + x0 = {k: x0[indices[i]:indices[i + 1]] for i, k in enumerate(self.target_vars.keys())} + r = self.f_cell(x0) + return jnp.concatenate(list(r.values())) + else: + jacob = self.f_cell - decompositions : list - A list of dictionaries with sorted eigenvalues components: - (eigenvalues, right eigenvectors, and left eigenvectors). 
- """ - if sort_by == 'magnitude': - sort_fun = np.abs - elif sort_by == 'real': - sort_fun = np.real - else: - raise ValueError("Not implemented yet.") - matrices = np.asarray(matrices) - - decompositions = [] - for mat in matrices: - eig_values, eig_vectors = np.linalg.eig(mat) - indices = np.flipud(np.argsort(sort_fun(eig_values))) - L = None - if do_compute_lefts: - L = np.linalg.pinv(eig_vectors).T # as columns - L = L[:, indices] - decompositions.append({'eig_values': eig_values[indices], - 'R': eig_vectors[:, indices], - 'L': L}) - return decompositions - - def _step_func_input(self): - if self._inputs is None: - return - elif callable(self._inputs): - try: - ba = inspect.signature(self._inputs).bind(dict()) - self._inputs(share.get_shargs()) - warnings.warn(_input_deprecate_msg, UserWarning) - except TypeError: - self._inputs() - else: - for ops, values in self._inputs['fixed'].items(): - for var, data in values: - _f_ops(ops, var, data) - for ops, values in self._inputs['array'].items(): - if len(values) > 0: - raise UnsupportedError - for ops, values in self._inputs['functional'].items(): - for var, data in values: - _f_ops(ops, var, data(share.get_shargs())) - for ops, values in self._inputs['iterated'].items(): - if len(values) > 0: - raise UnsupportedError - - def _get_f_eval_loss(self, ): - name = 'f_eval_loss' - if name not in self._opt_functions: - self._opt_functions[name] = self._generate_f_eval_loss() - return self._opt_functions[name] - - def _generate_f_eval_loss(self): - # evaluate losses of a batch of inputs - if self.f_type == constants.DISCRETE: - f_eval_loss = lambda h: self.f_loss(h, jax.vmap(self.f_cell)(h), axis=1) - else: - f_eval_loss = lambda h: self.f_loss(jax.vmap(self.f_cell)(h), axis=1) - - if isinstance(self.target, DynamicalSystem): - @jax.jit - def loss_func(h): - r = f_eval_loss(h) - for k, v in self.excluded_vars.items(): - v.value = self.excluded_data[k] - for k, v in self.target_vars.items(): - v.value = 
self.target_data[k] - return r - - return loss_func - else: - return jax.jit(f_eval_loss) - - def _get_f_for_opt_solver(self, candidates, opt_method): - # loss function - if self.f_type == constants.DISCRETE: - # overall loss function for fixed points optimization - if isinstance(candidates, dict): - keys = tuple(self.target_vars.keys()) - indices = [0] - for v in self.target_vars.values(): - indices.append(v.shape[0]) - indices = np.cumsum(indices) - - def f_loss(h): - h = {key: h[indices[i]: indices[i + 1]] for i, key in enumerate(keys)} - return bm.as_jax(self.f_loss(h, self.f_cell(h))) - else: - def f_loss(h): - return bm.as_jax(self.f_loss(h, self.f_cell(h))) - else: - # overall loss function for fixed points optimization - def f_loss(h): - return self.f_loss(self.f_cell(h)) - - @jax.jit - @jax.vmap - def f_opt(x0): - for k, v in self.target_vars.items(): - v.value = x0[k] if (v.batch_axis is None) else jnp.expand_dims(x0[k], axis=v.batch_axis) - for k, v in self.excluded_vars.items(): - v.value = self.excluded_data[k] - if isinstance(x0, dict): - x0 = jnp.concatenate(tuple(x0.values())) - return opt_method(f_loss, x0) - - def call_opt(x): - r = f_opt(x) - for k, v in self.excluded_vars.items(): - v.value = self.excluded_data[k] - for k, v in self.target_vars.items(): - v.value = self.target_data[k] - return r - - return call_opt if isinstance(self.target, DynamicalSystem) else f_opt - - def _generate_ds_cell_function( - self, target, - t: float = None, - dt: float = None, - ): - if dt is None: dt = bm.get_dt() - if t is None: t = 0. 
- - def f_cell(h: Dict): - share.save(t=t, i=0, dt=dt) - - # update target variables - for k, v in self.target_vars.items(): - v.value = (bm.asarray(h[k], dtype=v.dtype) - if v.batch_axis is None else - bm.asarray(bm.expand_dims(h[k], axis=v.batch_axis), dtype=v.dtype)) - - # update excluded variables - for k, v in self.excluded_vars.items(): - v.value = self.excluded_data[k] - - # add inputs - clear_input(target) - self._step_func_input() - - # call update functions - target(*self.args) - - # get new states - new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis)) - for k, v in self.target_vars.items()} - return new_h - - return f_cell - - def _get_f_jocabian(self, stack=True): - name = f'f_eval_jacobian_stack={stack}' - if name not in self._opt_functions: - self._opt_functions[name] = self._generate_ds_jocabian(stack) - return self._opt_functions[name] - - def _generate_ds_jocabian(self, stack=True): - if stack and isinstance(self.target, DynamicalSystem): - indices = [0] - for var in self.target_vars.values(): - shape = list(var.shape) - if var.batch_axis is not None: - shape.pop(var.batch_axis) - indices.append(np.prod(shape)) - indices = np.cumsum(indices) - - def jacob(x0): - x0 = {k: x0[indices[i]:indices[i + 1]] for i, k in enumerate(self.target_vars.keys())} - r = self.f_cell(x0) - return jnp.concatenate(list(r.values())) - else: - jacob = self.f_cell - - f_jac = jax.jit(jax.vmap(bm.jacobian(jacob))) - - if isinstance(self.target, DynamicalSystem): - def jacobian_func(x): - r = f_jac(x) - for k, v in self.excluded_vars.items(): - v.value = self.excluded_data[k] - for k, v in self.target_vars.items(): - v.value = self.target_data[k] - return r - - return jacobian_func - else: - return f_jac - - def _check_candidates(self, candidates): - if isinstance(self.target, DynamicalSystem): - if not isinstance(candidates, dict): - raise ValueError(f'When "f_cell" is instance of {DynamicalSystem.__name__}, ' - f'we should provide 
"candidates" as a dict, in which the key is ' - f'the variable name with relative path, and the value ' - f'is the candidate fixed point values. ') - for key in candidates: - if key not in self.target_vars: - raise KeyError(f'"{key}" is not defined in required variables ' - f'for fixed point optimization of {self.target}. ' - f'Please do not provide its initial values.') - - for key in self.target_vars.keys(): - if key not in candidates: - raise KeyError(f'"{key}" is defined in required variables ' - f'for fixed point optimization of {self.target}. ' - f'Please provide its initial values.') - for key, value in candidates.items(): - if self.target_vars[key].batch_axis is None: - if value.ndim != self.target_vars[key].ndim + 1: - raise ValueError(f'"{key}" is defined in the required variables for fixed ' - f'point optimization of {self.target}. \n' - f'We expect the provided candidate has a batch size, ' - f'but we got {value.shape} for variable with shape of ' - f'{self.target_vars[key].shape}') + f_jac = jax.jit(jax.vmap(bm.jacobian(jacob))) + + if isinstance(self.target, DynamicalSystem): + def jacobian_func(x): + r = f_jac(x) + for k, v in self.excluded_vars.items(): + v.value = self.excluded_data[k] + for k, v in self.target_vars.items(): + v.value = self.target_data[k] + return r + + return jacobian_func + else: + return f_jac + + def _check_candidates(self, candidates): + if isinstance(self.target, DynamicalSystem): + if not isinstance(candidates, dict): + raise ValueError(f'When "f_cell" is instance of {DynamicalSystem.__name__}, ' + f'we should provide "candidates" as a dict, in which the key is ' + f'the variable name with relative path, and the value ' + f'is the candidate fixed point values. ') + for key in candidates: + if key not in self.target_vars: + raise KeyError(f'"{key}" is not defined in required variables ' + f'for fixed point optimization of {self.target}. 
' + f'Please do not provide its initial values.') + + for key in self.target_vars.keys(): + if key not in candidates: + raise KeyError(f'"{key}" is defined in required variables ' + f'for fixed point optimization of {self.target}. ' + f'Please provide its initial values.') + for key, value in candidates.items(): + if self.target_vars[key].batch_axis is None: + if value.ndim != self.target_vars[key].ndim + 1: + raise ValueError(f'"{key}" is defined in the required variables for fixed ' + f'point optimization of {self.target}. \n' + f'We expect the provided candidate has a batch size, ' + f'but we got {value.shape} for variable with shape of ' + f'{self.target_vars[key].shape}') + else: + if value.ndim != self.target_vars[key].ndim: + raise ValueError(f'"{key}" is defined in the required variables for fixed ' + f'point optimization of {self.target}. \n' + f'We expect the provided candidate has a batch size, ' + f'but we got {value.shape} for variable with shape of ' + f'{self.target_vars[key].shape}') + + if isinstance(candidates, dict): + num_candidate = np.unique([leaf.shape[0] for leaf in candidates.values()]) + if len(num_candidate) != 1: + raise ValueError('The numbers of candidates for each variable should be the same. ' + f'But we got {num_candidate}') + num_candidate = num_candidate[0] else: - if value.ndim != self.target_vars[key].ndim: - raise ValueError(f'"{key}" is defined in the required variables for fixed ' - f'point optimization of {self.target}. \n' - f'We expect the provided candidate has a batch size, ' - f'but we got {value.shape} for variable with shape of ' - f'{self.target_vars[key].shape}') - - if isinstance(candidates, dict): - num_candidate = np.unique([leaf.shape[0] for leaf in candidates.values()]) - if len(num_candidate) != 1: - raise ValueError('The numbers of candidates for each variable should be the same. 
' - f'But we got {num_candidate}') - num_candidate = num_candidate[0] - else: - num_candidate = candidates.shape[0] - return num_candidate + num_candidate = candidates.shape[0] + return num_candidate diff --git a/brainpy/_src/analysis/highdim/tests/test_slow_points.py b/brainpy/_src/analysis/highdim/tests/test_slow_points.py index d63c3f383..c30ce6f60 100644 --- a/brainpy/_src/analysis/highdim/tests/test_slow_points.py +++ b/brainpy/_src/analysis/highdim/tests/test_slow_points.py @@ -7,165 +7,165 @@ class HH(bp.dyn.NeuDyn): - def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03, - V_th=20., C=1.0, name=None): - super(HH, self).__init__(size=size, name=name) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - - # variables - self.V = bm.Variable(bm.ones(self.num) * -65.) - self.m = bm.Variable(0.5 * bm.ones(self.num)) - self.h = bm.Variable(0.6 * bm.ones(self.num)) - self.n = bm.Variable(0.32 * bm.ones(self.num)) - self.spike = bm.Variable(bm.zeros(size, dtype=bool)) - self.input = bm.Variable(bm.zeros(size)) - - # integral functions - self.int_h = bp.ode.ExponentialEuler(self.dh) - self.int_n = bp.ode.ExponentialEuler(self.dn) - self.int_m = bp.ode.ExponentialEuler(self.dm) - self.int_V = bp.ode.ExponentialEuler(self.dV) - - def dh(self, h, t, V): - alpha = 0.07 * bm.exp(-(V + 65) / 20.) 
- beta = 1 / (1 + bm.exp(-(V + 35) / 10)) - dhdt = alpha * (1 - h) - beta * h - return dhdt - - def dn(self, n, t, V): - alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) - beta = 0.125 * bm.exp(-(V + 65) / 80) - dndt = alpha * (1 - n) - beta * n - return dndt - - def dm(self, m, t, V): - alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) - beta = 4.0 * bm.exp(-(V + 65) / 18) - dmdt = alpha * (1 - m) - beta * m - return dmdt - - def dV(self, V, t, m, h, n, Iext): - I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) - I_K = (self.gK * n ** 4.0) * (V - self.EK) - I_leak = self.gL * (V - self.EL) - dVdt = (- I_Na - I_K - I_leak + Iext) / self.C - return dVdt - - def update(self, tdi): - t, dt = tdi.t, tdi.dt - m = self.int_m(self.m.value, t, self.V.value, dt=dt) - h = self.int_h(self.h.value, t, self.V.value, dt=dt) - n = self.int_n(self.n.value, t, self.V.value, dt=dt) - V = self.int_V(self.V.value, t, self.m.value, self.h.value, self.n.value, self.input.value, dt=dt) - self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) - self.V.value = V - self.h.value = h - self.n.value = n - self.m.value = m - self.input[:] = 0. + def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03, + V_th=20., C=1.0, name=None): + super(HH, self).__init__(size=size, name=name) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + + # variables + self.V = bm.Variable(bm.ones(self.num) * -65.) 
+ self.m = bm.Variable(0.5 * bm.ones(self.num)) + self.h = bm.Variable(0.6 * bm.ones(self.num)) + self.n = bm.Variable(0.32 * bm.ones(self.num)) + self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + self.input = bm.Variable(bm.zeros(size)) + + # integral functions + self.int_h = bp.ode.ExponentialEuler(self.dh) + self.int_n = bp.ode.ExponentialEuler(self.dn) + self.int_m = bp.ode.ExponentialEuler(self.dm) + self.int_V = bp.ode.ExponentialEuler(self.dV) + + def dh(self, h, t, V): + alpha = 0.07 * bm.exp(-(V + 65) / 20.) + beta = 1 / (1 + bm.exp(-(V + 35) / 10)) + dhdt = alpha * (1 - h) - beta * h + return dhdt + + def dn(self, n, t, V): + alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + beta = 0.125 * bm.exp(-(V + 65) / 80) + dndt = alpha * (1 - n) - beta * n + return dndt + + def dm(self, m, t, V): + alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + beta = 4.0 * bm.exp(-(V + 65) / 18) + dmdt = alpha * (1 - m) - beta * m + return dmdt + + def dV(self, V, t, m, h, n, Iext): + I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa) + I_K = (self.gK * n ** 4.0) * (V - self.EK) + I_leak = self.gL * (V - self.EL) + dVdt = (- I_Na - I_K - I_leak + Iext) / self.C + return dVdt + + def update(self, tdi): + t, dt = tdi.t, tdi.dt + m = self.int_m(self.m.value, t, self.V.value, dt=dt) + h = self.int_h(self.h.value, t, self.V.value, dt=dt) + n = self.int_n(self.n.value, t, self.V.value, dt=dt) + V = self.int_V(self.V.value, t, self.m.value, self.h.value, self.n.value, self.input.value, dt=dt) + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.h.value = h + self.n.value = n + self.m.value = m + self.input[:] = 0. class TestFixedPointsFinding(unittest.TestCase): - def test_opt_solver_for_func1(self): - gamma = 0.641 # Saturation factor for gating variable - tau = 0.06 # Synaptic time constant [sec] - a = 270. - b = 108. 
- d = 0.154 - - JE = 0.3725 # self-coupling strength [nA] - JI = -0.1137 # cross-coupling strength [nA] - JAext = 0.00117 # Stimulus input strength [nA] - - mu = 20. # Stimulus firing rate [spikes/sec] - coh = 0.5 # Stimulus coherence [%] - Ib1 = 0.3297 - Ib2 = 0.3297 - - def ds1(s1, t, s2, coh=0.5, mu=20.): - I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh) - r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b))) - return - s1 / tau + (1. - s1) * gamma * r1 - - def ds2(s2, t, s1, coh=0.5, mu=20.): - I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh) - r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b))) - return - s2 / tau + (1. - s2) * gamma * r2 - - def step(s): - return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])]) - - rng = bm.random.RandomState(123) - finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) - finder.find_fps_with_opt_solver(rng.random((100, 2))) - - def test_opt_solver_for_ds1(self): - hh = HH(1) - finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike]) - rng = bm.random.RandomState(123) - - with self.assertRaises(ValueError): - finder.find_fps_with_opt_solver(rng.random((100, 4))) - - finder.find_fps_with_opt_solver({'V': rng.random((100, 1)), - 'm': rng.random((100, 1)), - 'h': rng.random((100, 1)), - 'n': rng.random((100, 1))}) - - def test_gd_method_for_func1(self): - gamma = 0.641 # Saturation factor for gating variable - tau = 0.06 # Synaptic time constant [sec] - a = 270. - b = 108. - d = 0.154 - - JE = 0.3725 # self-coupling strength [nA] - JI = -0.1137 # cross-coupling strength [nA] - JAext = 0.00117 # Stimulus input strength [nA] - - mu = 20. # Stimulus firing rate [spikes/sec] - coh = 0.5 # Stimulus coherence [%] - Ib1 = 0.3297 - Ib2 = 0.3297 - - def ds1(s1, t, s2, coh=0.5, mu=20.): - I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh) - r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b))) - return - s1 / tau + (1. 
- s1) * gamma * r1 - - def ds2(s2, t, s1, coh=0.5, mu=20.): - I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh) - r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b))) - return - s2 / tau + (1. - s2) * gamma * r2 - - def step(s): - return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])]) - - rng = bm.random.RandomState(123) - finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) - finder.find_fps_with_gd_method(rng.random((100, 2)), num_opt=100) - - def test_gd_method_for_func2(self): - hh = HH(1) - finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike]) - rng = bm.random.RandomState(123) - - with self.assertRaises(ValueError): - finder.find_fps_with_opt_solver(rng.random((100, 4))) - - finder.find_fps_with_gd_method( - {'V': rng.random((100, 1)), - 'm': rng.random((100, 1)), - 'h': rng.random((100, 1)), - 'n': rng.random((100, 1))}, - num_opt=100 - ) + def test_opt_solver_for_func1(self): + gamma = 0.641 # Saturation factor for gating variable + tau = 0.06 # Synaptic time constant [sec] + a = 270. + b = 108. + d = 0.154 + + JE = 0.3725 # self-coupling strength [nA] + JI = -0.1137 # cross-coupling strength [nA] + JAext = 0.00117 # Stimulus input strength [nA] + + mu = 20. # Stimulus firing rate [spikes/sec] + coh = 0.5 # Stimulus coherence [%] + Ib1 = 0.3297 + Ib2 = 0.3297 + + def ds1(s1, t, s2, coh=0.5, mu=20.): + I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh) + r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b))) + return - s1 / tau + (1. - s1) * gamma * r1 + + def ds2(s2, t, s1, coh=0.5, mu=20.): + I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh) + r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b))) + return - s2 / tau + (1. 
- s2) * gamma * r2 + + def step(s): + return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])]) + + rng = bm.random.RandomState(123) + finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) + finder.find_fps_with_opt_solver(rng.random((100, 2))) + + def test_opt_solver_for_ds1(self): + hh = HH(1) + finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike]) + rng = bm.random.RandomState(123) + + with self.assertRaises(ValueError): + finder.find_fps_with_opt_solver(rng.random((100, 4))) + + finder.find_fps_with_opt_solver({'V': rng.random((100, 1)), + 'm': rng.random((100, 1)), + 'h': rng.random((100, 1)), + 'n': rng.random((100, 1))}) + + def test_gd_method_for_func1(self): + gamma = 0.641 # Saturation factor for gating variable + tau = 0.06 # Synaptic time constant [sec] + a = 270. + b = 108. + d = 0.154 + + JE = 0.3725 # self-coupling strength [nA] + JI = -0.1137 # cross-coupling strength [nA] + JAext = 0.00117 # Stimulus input strength [nA] + + mu = 20. # Stimulus firing rate [spikes/sec] + coh = 0.5 # Stimulus coherence [%] + Ib1 = 0.3297 + Ib2 = 0.3297 + + def ds1(s1, t, s2, coh=0.5, mu=20.): + I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu * (1. + coh) + r1 = (a * I1 - b) / (1. - bm.exp(-d * (a * I1 - b))) + return - s1 / tau + (1. - s1) * gamma * r1 + + def ds2(s2, t, s1, coh=0.5, mu=20.): + I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu * (1. - coh) + r2 = (a * I2 - b) / (1. - bm.exp(-d * (a * I2 - b))) + return - s2 / tau + (1. 
- s2) * gamma * r2 + + def step(s): + return bm.asarray([ds1(s[0], 0., s[1]), ds2(s[1], 0., s[0])]) + + rng = bm.random.RandomState(123) + finder = bp.analysis.SlowPointFinder(f_cell=step, f_type=bp.analysis.CONTINUOUS) + finder.find_fps_with_gd_method(rng.random((100, 2)), num_opt=100) + + def test_gd_method_for_func2(self): + hh = HH(1) + finder = bp.analysis.SlowPointFinder(f_cell=hh, excluded_vars=[hh.input, hh.spike]) + rng = bm.random.RandomState(123) + + with self.assertRaises(ValueError): + finder.find_fps_with_opt_solver(rng.random((100, 4))) + + finder.find_fps_with_gd_method( + {'V': rng.random((100, 1)), + 'm': rng.random((100, 1)), + 'h': rng.random((100, 1)), + 'n': rng.random((100, 1))}, + num_opt=100 + ) diff --git a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py index 25eaa4325..56c2945eb 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/_src/analysis/lowdim/lowdim_analyzer.py @@ -3,1045 +3,1045 @@ import warnings from functools import partial -import numpy as np import jax +import numpy as np from jax import numpy as jnp from jax import vmap from jax.scipy.optimize import minimize -from brainpy import errors, tools import brainpy._src.math as bm -from brainpy._src.math.object_transform.base import Collector +from brainpy import errors, tools from brainpy._src.analysis import constants as C, utils from brainpy._src.analysis.base import DSAnalyzer +from brainpy._src.math.object_transform.base import Collector pyplot = None __all__ = [ - 'LowDimAnalyzer', - 'Num1DAnalyzer', - 'Num2DAnalyzer', + 'LowDimAnalyzer', + 'Num1DAnalyzer', + 'Num2DAnalyzer', ] class LowDimAnalyzer(DSAnalyzer): - r"""Automatic Analyzer for Low-dimensional Dynamical Systems. - - A dynamical model is characterized by a series of dynamical - variables and parameters: - - .. math:: - - {dF \over dt} = F(v_1, v_2, ..., p_1, p_2, ...) - - where :math:`v_1, v_2` are variables, :math:`p_1, p_2` are parameters. 
- - .. note:: - ``LowDimAnalyzer`` cannot analyze dynamical system depends on time :math:`t`. - - Parameters:: - - model : Any, ODEIntegrator, sequence of ODEIntegrator, DynamicalSystem - A model of the population, the integrator function, - or a list/tuple of integrator functions. - target_vars : dict - The target/dynamical variables. - fixed_vars : dict - The fixed variables. - target_pars : dict, optional - The parameters which can be dynamical varied. - pars_update : dict, optional - The parameters to update. - resolutions : float, dict - The resolution for numerical iterative solvers. Default is 20 equal parts - (:math:`\frac{\mathrm{max} - \mathrm{min}}{20}`). It can - set the numerical resolution of dynamical variables or dynamical parameters. - For example, - - - set ``resolutions=0.1`` will generalize it to all variables and parameters; - - set ``resolutions={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05}`` will specify - the particular resolutions to variables and parameters. - - Moreover, you can also set ``resolutions={var1: Array([...]), var2: 0.1}`` - to specify the search points need to explore for variable `var1`. - This will be useful to set sense search points at some inflection points. - lim_scale: float - The axis limit scale factor. Default is 1.05. The setting means - the axes will be clipped to ``[var_min * (1-lim_scale)/2, var_max * (var_max-1)/2]``. - options : optional, dict - The optional setting. Maybe needed in the individual analyzer. 
- """ - - def __init__( - self, - model, - target_vars, - fixed_vars=None, - target_pars=None, - pars_update=None, - resolutions=None, - jit_device=None, - lim_scale=1.05, - options=None, - ): - # model - # ----- - self.model = utils.model_transform(model) - - # target variables - # ---------------- - if not isinstance(target_vars, dict): - raise errors.AnalyzerError('"target_vars" must be a dict, with the format of ' - '{"var1": (var1_min, var1_max)}.') - self.target_vars = Collector(target_vars) - self.target_var_names = list(self.target_vars.keys()) # list of target vars - for key in self.target_vars.keys(): - if key not in self.model.variables: - raise errors.AnalyzerError(f'{key} is not a dynamical variable in {self.model}.') - value = self.target_vars[key] - if value[0] > value[1]: - raise errors.AnalyzerError( - f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.') - - # fixed variables - # ---------------- - if fixed_vars is None: - fixed_vars = dict() - if not isinstance(fixed_vars, dict): - raise errors.AnalyzerError('"fixed_vars" must be a dict with the format ' - 'of {"var1": val1, "var2": val2}.') - for key in fixed_vars.keys(): - if key not in self.model.variables: - raise ValueError(f'{key} is not a dynamical variable in {self.model}.') - self.fixed_vars = Collector(fixed_vars) - - # check duplicate - for key in self.fixed_vars.keys(): - if key in self.target_vars: - raise errors.AnalyzerError(f'"{key}" is defined as a target variable in "target_vars", ' - f'but also defined as a fixed variable in "fixed_vars".') - - # parameters to update - # --------------------- - if pars_update is None: - pars_update = dict() - if not isinstance(pars_update, dict): - raise errors.AnalyzerError('"pars_update" must be a dict with the format ' - 'of {"par1": val1, "par2": val2}.') - pars_update = Collector(pars_update) - for key in pars_update.keys(): - if key not in self.model.parameters: - raise 
errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.') - self.pars_update = pars_update - - # dynamical parameters - # --------------------- - if target_pars is None: - target_pars = dict() - if not isinstance(target_pars, dict): - raise errors.AnalyzerError('"target_pars" must be a dict with the format of {"par1": (val1, val2)}.') - for key, value in target_pars.items(): - if key not in self.model.parameters: - raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.') - if value[0] > value[1]: - raise errors.AnalyzerError( - f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.') - - self.target_pars = Collector(target_pars) - self.target_par_names = list(self.target_pars.keys()) # list of target_pars - - # check duplicate - for key in self.pars_update.keys(): - if key in self.target_pars: - raise errors.AnalyzerError(f'"{key}" is defined as a target parameter in "target_pars", ' - f'but also defined as a fixed parameter in "pars_update".') - - # resolutions for numerical methods - # --------------------------------- - self.resolutions = dict() - _target_vp = self.target_vars + self.target_pars - if resolutions is None: - for key, lim in self.target_vars.items(): - self.resolutions[key] = jnp.linspace(*lim, 20) - for key, lim in self.target_pars.items(): - self.resolutions[key] = jnp.linspace(*lim, 20) - elif isinstance(resolutions, float): - if len(self.target_pars) >= 1: - warnings.warn('The `resolutions` is specified to all parameters and variables. ' - 'Analysis computation may occupy too much memory if `resolutions` is small. 
' - 'Please specify `resolutions` for each parameter and variable by dict, ' - 'such as resolutions={"V": 0.1}.', - category=UserWarning) - for key, lim in self.target_vars.items(): - self.resolutions[key] = jnp.arange(*lim, resolutions) - for key, lim in self.target_pars.items(): - self.resolutions[key] = jnp.arange(*lim, resolutions) - elif isinstance(resolutions, dict): - for key in resolutions.keys(): - if key in self.target_var_names: - continue - if key in self.target_par_names: - continue - raise errors.AnalyzerError(f'The resolution setting target "{key}" is not found in ' - f'the target variables {self.target_var_names} or ' - f'the target parameters {self.target_par_names}.') - for key in self.target_var_names + self.target_par_names: - if key not in resolutions: - self.resolutions[key] = jnp.linspace(*_target_vp[key], 20) - else: - resolution = resolutions[key] - if isinstance(resolution, float): - self.resolutions[key] = jnp.arange(*_target_vp[key], resolution) - elif isinstance(resolution, (bm.ndarray, np.ndarray, jnp.ndarray)): - if not np.ndim(resolution) == 1: - raise errors.AnalyzerError(f'resolution must be a 1D array, but get its ' - f'shape with {resolution.shape}.') - self.resolutions[key] = bm.as_jax(resolution) - else: - raise errors.AnalyzerError(f'Unknown resolution setting: {key}: {resolution}') - else: - raise errors.AnalyzerError(f'Unknown resolution type: {type(resolutions)}') - - # other settings - # -------------- - if options is None: - options = dict() - self.options = options - self.jit_device = jit_device - self.lim_scale = lim_scale - - # A dict to store the analyzed results - # ------------------------------------- - # 'dxdt' : The differential function ``f`` of the first variable ``x``. - # It can be used as ``dxdt(x, y, ...)``. - # 'dydt' : The differential function ``g`` of the second variable ``y``. - # It can be used as ``dydt(x, y, ...)``. - # 'dfdx' : The derivative of ``f`` by ``x``. 
It can be used as ``dfdx(x, y, ...)``. - # 'dfdy' : The derivative of ``f`` by ``y``. It can be used as ``dfdy(x, y, ...)``. - # 'dgdx' : The derivative of ``g`` by ``x``. It can be used as ``dgdx(x, y, ...)``. - # 'dgdy' : The derivative of ``g`` by ``y``. It can be used as ``dgdy(x, y, ...)``. - # 'jacobian' : The jacobian matrix. It can be used as ``jacobian(x, y, ...)``. - # 'fixed_point' : The fixed point. - # 'y_by_x_in_fy' : - # 'x_by_y_in_fy' : - # 'y_by_x_in_fx' : - # 'x_by_y_in_fx' : - self.analyzed_results = tools.DotDict() - - def show_figure(self): - global pyplot - if pyplot is None: - from matplotlib import pyplot - pyplot.show() - - -class Num1DAnalyzer(LowDimAnalyzer): - r"""Analyzer for one-dimensional dynamical system. - - It supports the analysis for 1D dynamical system. - - .. math:: - - {dx \over dt} = f(x, t) - - Actually, the analysis for 1D system is purely numerically. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.x_var = self.target_var_names[0] - if len(self.target_vars) < 1: - raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system ' - f'with >= 1 variables. But we got {len(self.target_vars)} ' - f'variables in {self.model}.') - - @property - def F_fx(self): - """Make the standard function call of :math:`f_x (*\mathrm{vars}, *\mathrm{pars})`. - - This function has been transformed into the standard call. - For instance, if the user has the ``target_vars=("v1", "v2")`` and - the ``target_pars=("p1", "p2")``, while the first function is defined as: + r"""Automatic Analyzer for Low-dimensional Dynamical Systems. - >>> def f1(v1, t, p1): - >>> return something + A dynamical model is characterized by a series of dynamical + variables and parameters: - However, after the stransformation, this function should be called as: + .. 
math:: - >>> self.F_fx(v1, v2, p1, p2) - """ - if C.F_fx not in self.analyzed_results: - _, arguments = utils.get_args(self.model.f_derivatives[self.x_var]) - wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) - f = wrapper(self.model.f_derivatives[self.x_var]) - f = partial(f, **(self.pars_update + self.fixed_vars)) - f = utils.f_without_jaxarray_return(f) - f = utils.remove_return_shape(f) - self.analyzed_results[C.F_fx] = jax.jit(f, device=self.jit_device) - return self.analyzed_results[C.F_fx] - - @property - def F_vmap_fx(self): - if C.F_vmap_fx not in self.analyzed_results: - self.analyzed_results[C.F_vmap_fx] = jax.jit(vmap(self.F_fx), device=self.jit_device) - return self.analyzed_results[C.F_vmap_fx] - - @property - def F_dfxdx(self): - r"""The function to evaluate :math:`\frac{df_x(*\mathrm{vars}, *\mathrm{pars})}{dx}`.""" - if C.F_dfxdx not in self.analyzed_results: - dfx = bm.vector_grad(self.F_fx, argnums=0) - self.analyzed_results[C.F_dfxdx] = jax.jit(dfx, device=self.jit_device) - return self.analyzed_results[C.F_dfxdx] - - @property - def F_fixed_point_aux(self): - if C.F_fixed_point_aux not in self.analyzed_results: - def aux_fun(x, *args): - return jnp.abs(self.F_fx(x, *args)).sum() - - self.analyzed_results[C.F_fixed_point_aux] = aux_fun - return self.analyzed_results[C.F_fixed_point_aux] - - @property - def F_vmap_fp_aux(self): - if C.F_vmap_fp_aux not in self.analyzed_results: - # The arguments of this function are: - # --- - # "X": a two-dimensional matrix: (num_batch, num_var) - # "args": a list of one-dimensional vectors, each has the shape of (num_batch,) - self.analyzed_results[C.F_vmap_fp_aux] = jax.jit(vmap(self.F_fixed_point_aux)) - return self.analyzed_results[C.F_vmap_fp_aux] - - @property - def F_fixed_point_opt(self): - if C.F_fixed_point_opt not in self.analyzed_results: - def f(start_and_end, *args): - return utils.jax_brentq(self.F_fx)(start_and_end[0], start_and_end[1], args) - - 
self.analyzed_results[C.F_fixed_point_opt] = f - return self.analyzed_results[C.F_fixed_point_opt] - - @property - def F_vmap_fp_opt(self): - if C.F_vmap_fp_opt not in self.analyzed_results: - # The arguments of this function are: - # --- - # "X": a two-dimensional matrix: (num_batch, num_var) - # "args": a list of one-dimensional vectors, each has the shape of (num_batch,) - self.analyzed_results[C.F_vmap_fp_opt] = jax.jit(vmap(self.F_fixed_point_opt)) - return self.analyzed_results[C.F_vmap_fp_opt] - - def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None): - """ + {dF \over dt} = F(v_1, v_2, ..., p_1, p_2, ...) - "candidates" and "args" can be obtained through: + where :math:`v_1, v_2` are variables, :math:`p_1, p_2` are parameters. - >>> all_candidates = [] - >>> all_par1 = [] - >>> all_par2 = [] - >>> for p1 in par1_list: - >>> for p2 in par2_list: - >>> xs = self.resolutions[self.x_var] - >>> all_candidates.append(xs) - >>> all_par1.append(jnp.ones_like(xs) * p1) - >>> all_par2.append(jnp.ones_like(xs) * p2) + .. note:: + ``LowDimAnalyzer`` cannot analyze dynamical system depends on time :math:`t`. Parameters:: - candidates - args - tol_aux - loss_screen - - Returns:: - + model : Any, ODEIntegrator, sequence of ODEIntegrator, DynamicalSystem + A model of the population, the integrator function, + or a list/tuple of integrator functions. + target_vars : dict + The target/dynamical variables. + fixed_vars : dict + The fixed variables. + target_pars : dict, optional + The parameters which can be dynamical varied. + pars_update : dict, optional + The parameters to update. + resolutions : float, dict + The resolution for numerical iterative solvers. Default is 20 equal parts + (:math:`\frac{\mathrm{max} - \mathrm{min}}{20}`). It can + set the numerical resolution of dynamical variables or dynamical parameters. 
+ For example, + + - set ``resolutions=0.1`` will generalize it to all variables and parameters; + - set ``resolutions={var1: 0.1, var2: 0.2, par1: 0.1, par2: 0.05}`` will specify + the particular resolutions to variables and parameters. + - Moreover, you can also set ``resolutions={var1: Array([...]), var2: 0.1}`` + to specify the search points need to explore for variable `var1`. + This will be useful to set sense search points at some inflection points. + lim_scale: float + The axis limit scale factor. Default is 1.05. The setting means + the axes will be clipped to ``[var_min * (1-lim_scale)/2, var_max * (var_max-1)/2]``. + options : optional, dict + The optional setting. Maybe needed in the individual analyzer. """ - # candidates: xs, a vector with the length of self.resolutions[self.x_var] - # args: parameters, a list/tuple of vectors - candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates - selected_ids = np.arange(len(candidates)) - args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args) - for a in args: assert len(a) == len(candidates) - if num_seg is None: - num_seg = len(self.resolutions[self.x_var]) - assert isinstance(num_seg, int) - - # get the signs - signs = jnp.sign(self.F_vmap_fx(candidates, *args)) - signs = signs.reshape((num_seg, -1)) - par_len = signs.shape[1] - signs1 = signs.at[-1].set(1) - signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) - ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] - if len(ids) <= 0: - return [], [], [] - - # selected the proper candidates to optimize fixed points - selected_ids = selected_ids[np.asarray(ids)] - starts = candidates[ids] - ends = candidates[ids + par_len] - X = jnp.stack((starts, ends)).T - args = tuple(a[ids] for a in args) - - # optimize the fixed points - res = self.F_vmap_fp_opt(X, *args) - losses = self.F_vmap_fp_aux(res['root'], *args) - valid_or_not = jnp.logical_and(res['status'] == utils.ECONVERGED, losses <= tol_aux) - 
ids = np.asarray(jnp.where(valid_or_not)[0]) - fps = np.asarray(res['root'])[ids] - args = tuple(a[ids] for a in args) - selected_ids = selected_ids[np.asarray(ids)] - return fps, selected_ids, args - -class Num2DAnalyzer(Num1DAnalyzer): - r"""Analyzer for two-dimensional dynamical system. - - It supports the analysis for 2D dynamical system. + def __init__( + self, + model, + target_vars, + fixed_vars=None, + target_pars=None, + pars_update=None, + resolutions=None, + jit_device=None, + lim_scale=1.05, + options=None, + ): + # model + # ----- + self.model = utils.model_transform(model) + + # target variables + # ---------------- + if not isinstance(target_vars, dict): + raise errors.AnalyzerError('"target_vars" must be a dict, with the format of ' + '{"var1": (var1_min, var1_max)}.') + self.target_vars = Collector(target_vars) + self.target_var_names = list(self.target_vars.keys()) # list of target vars + for key in self.target_vars.keys(): + if key not in self.model.variables: + raise errors.AnalyzerError(f'{key} is not a dynamical variable in {self.model}.') + value = self.target_vars[key] + if value[0] > value[1]: + raise errors.AnalyzerError( + f'The range of variable {key} is reversed, which means {value[0]} should be smaller than {value[1]}.') + + # fixed variables + # ---------------- + if fixed_vars is None: + fixed_vars = dict() + if not isinstance(fixed_vars, dict): + raise errors.AnalyzerError('"fixed_vars" must be a dict with the format ' + 'of {"var1": val1, "var2": val2}.') + for key in fixed_vars.keys(): + if key not in self.model.variables: + raise ValueError(f'{key} is not a dynamical variable in {self.model}.') + self.fixed_vars = Collector(fixed_vars) + + # check duplicate + for key in self.fixed_vars.keys(): + if key in self.target_vars: + raise errors.AnalyzerError(f'"{key}" is defined as a target variable in "target_vars", ' + f'but also defined as a fixed variable in "fixed_vars".') + + # parameters to update + # --------------------- + if 
pars_update is None: + pars_update = dict() + if not isinstance(pars_update, dict): + raise errors.AnalyzerError('"pars_update" must be a dict with the format ' + 'of {"par1": val1, "par2": val2}.') + pars_update = Collector(pars_update) + for key in pars_update.keys(): + if key not in self.model.parameters: + raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.') + self.pars_update = pars_update + + # dynamical parameters + # --------------------- + if target_pars is None: + target_pars = dict() + if not isinstance(target_pars, dict): + raise errors.AnalyzerError('"target_pars" must be a dict with the format of {"par1": (val1, val2)}.') + for key, value in target_pars.items(): + if key not in self.model.parameters: + raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.') + if value[0] > value[1]: + raise errors.AnalyzerError( + f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.') + + self.target_pars = Collector(target_pars) + self.target_par_names = list(self.target_pars.keys()) # list of target_pars + + # check duplicate + for key in self.pars_update.keys(): + if key in self.target_pars: + raise errors.AnalyzerError(f'"{key}" is defined as a target parameter in "target_pars", ' + f'but also defined as a fixed parameter in "pars_update".') + + # resolutions for numerical methods + # --------------------------------- + self.resolutions = dict() + _target_vp = self.target_vars + self.target_pars + if resolutions is None: + for key, lim in self.target_vars.items(): + self.resolutions[key] = jnp.linspace(*lim, 20) + for key, lim in self.target_pars.items(): + self.resolutions[key] = jnp.linspace(*lim, 20) + elif isinstance(resolutions, float): + if len(self.target_pars) >= 1: + warnings.warn('The `resolutions` is specified to all parameters and variables. ' + 'Analysis computation may occupy too much memory if `resolutions` is small. 
' + 'Please specify `resolutions` for each parameter and variable by dict, ' + 'such as resolutions={"V": 0.1}.', + category=UserWarning) + for key, lim in self.target_vars.items(): + self.resolutions[key] = jnp.arange(*lim, resolutions) + for key, lim in self.target_pars.items(): + self.resolutions[key] = jnp.arange(*lim, resolutions) + elif isinstance(resolutions, dict): + for key in resolutions.keys(): + if key in self.target_var_names: + continue + if key in self.target_par_names: + continue + raise errors.AnalyzerError(f'The resolution setting target "{key}" is not found in ' + f'the target variables {self.target_var_names} or ' + f'the target parameters {self.target_par_names}.') + for key in self.target_var_names + self.target_par_names: + if key not in resolutions: + self.resolutions[key] = jnp.linspace(*_target_vp[key], 20) + else: + resolution = resolutions[key] + if isinstance(resolution, float): + self.resolutions[key] = jnp.arange(*_target_vp[key], resolution) + elif isinstance(resolution, (bm.ndarray, np.ndarray, jnp.ndarray)): + if not np.ndim(resolution) == 1: + raise errors.AnalyzerError(f'resolution must be a 1D array, but get its ' + f'shape with {resolution.shape}.') + self.resolutions[key] = bm.as_jax(resolution) + else: + raise errors.AnalyzerError(f'Unknown resolution setting: {key}: {resolution}') + else: + raise errors.AnalyzerError(f'Unknown resolution type: {type(resolutions)}') + + # other settings + # -------------- + if options is None: + options = dict() + self.options = options + self.jit_device = jit_device + self.lim_scale = lim_scale + + # A dict to store the analyzed results + # ------------------------------------- + # 'dxdt' : The differential function ``f`` of the first variable ``x``. + # It can be used as ``dxdt(x, y, ...)``. + # 'dydt' : The differential function ``g`` of the second variable ``y``. + # It can be used as ``dydt(x, y, ...)``. + # 'dfdx' : The derivative of ``f`` by ``x``. 
It can be used as ``dfdx(x, y, ...)``. + # 'dfdy' : The derivative of ``f`` by ``y``. It can be used as ``dfdy(x, y, ...)``. + # 'dgdx' : The derivative of ``g`` by ``x``. It can be used as ``dgdx(x, y, ...)``. + # 'dgdy' : The derivative of ``g`` by ``y``. It can be used as ``dgdy(x, y, ...)``. + # 'jacobian' : The jacobian matrix. It can be used as ``jacobian(x, y, ...)``. + # 'fixed_point' : The fixed point. + # 'y_by_x_in_fy' : + # 'x_by_y_in_fy' : + # 'y_by_x_in_fx' : + # 'x_by_y_in_fx' : + self.analyzed_results = tools.DotDict() + + def show_figure(self): + global pyplot + if pyplot is None: + from matplotlib import pyplot + pyplot.show() - .. math:: - {dx \over dt} = fx(x, t, y) +class Num1DAnalyzer(LowDimAnalyzer): + r"""Analyzer for one-dimensional dynamical system. - {dy \over dt} = fy(y, t, x) - """ + It supports the analysis for 1D dynamical system. - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if len(self.target_vars) < 2: - raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system ' - f'with >= 2 variables. But we got {len(self.target_vars)} ' - f'variables in {self.model}.') - self.y_var = self.target_var_names[1] + .. math:: - @property - def F_fy(self): - r"""The function to evaluate :math:`f_y(*\mathrm{vars}, *\mathrm{pars})`. + {dx \over dt} = f(x, t) - This function has been transformed into the standard call. - For instance, if the user has the ``target_vars=("v1", "v2")`` and - the ``target_pars=("p1", "p2")``, while the first function is defined as: + Actually, the analysis for 1D system is purely numerically. + """ - >>> def f1(v1, t, p1): - >>> return something + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.x_var = self.target_var_names[0] + if len(self.target_vars) < 1: + raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system ' + f'with >= 1 variables. 
But we got {len(self.target_vars)} ' + f'variables in {self.model}.') + + @property + def F_fx(self): + """Make the standard function call of :math:`f_x (*\mathrm{vars}, *\mathrm{pars})`. + + This function has been transformed into the standard call. + For instance, if the user has the ``target_vars=("v1", "v2")`` and + the ``target_pars=("p1", "p2")``, while the first function is defined as: + + >>> def f1(v1, t, p1): + >>> return something + + However, after the stransformation, this function should be called as: + + >>> self.F_fx(v1, v2, p1, p2) + """ + if C.F_fx not in self.analyzed_results: + _, arguments = utils.get_args(self.model.f_derivatives[self.x_var]) + wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) + f = wrapper(self.model.f_derivatives[self.x_var]) + f = partial(f, **(self.pars_update + self.fixed_vars)) + f = utils.f_without_jaxarray_return(f) + f = utils.remove_return_shape(f) + self.analyzed_results[C.F_fx] = jax.jit(f, device=self.jit_device) + return self.analyzed_results[C.F_fx] + + @property + def F_vmap_fx(self): + if C.F_vmap_fx not in self.analyzed_results: + self.analyzed_results[C.F_vmap_fx] = jax.jit(vmap(self.F_fx), device=self.jit_device) + return self.analyzed_results[C.F_vmap_fx] + + @property + def F_dfxdx(self): + r"""The function to evaluate :math:`\frac{df_x(*\mathrm{vars}, *\mathrm{pars})}{dx}`.""" + if C.F_dfxdx not in self.analyzed_results: + dfx = bm.vector_grad(self.F_fx, argnums=0) + self.analyzed_results[C.F_dfxdx] = jax.jit(dfx, device=self.jit_device) + return self.analyzed_results[C.F_dfxdx] + + @property + def F_fixed_point_aux(self): + if C.F_fixed_point_aux not in self.analyzed_results: + def aux_fun(x, *args): + return jnp.abs(self.F_fx(x, *args)).sum() + + self.analyzed_results[C.F_fixed_point_aux] = aux_fun + return self.analyzed_results[C.F_fixed_point_aux] + + @property + def F_vmap_fp_aux(self): + if C.F_vmap_fp_aux not in self.analyzed_results: + # The arguments of this 
function are: + # --- + # "X": a two-dimensional matrix: (num_batch, num_var) + # "args": a list of one-dimensional vectors, each has the shape of (num_batch,) + self.analyzed_results[C.F_vmap_fp_aux] = jax.jit(vmap(self.F_fixed_point_aux)) + return self.analyzed_results[C.F_vmap_fp_aux] + + @property + def F_fixed_point_opt(self): + if C.F_fixed_point_opt not in self.analyzed_results: + def f(start_and_end, *args): + return utils.jax_brentq(self.F_fx)(start_and_end[0], start_and_end[1], args) + + self.analyzed_results[C.F_fixed_point_opt] = f + return self.analyzed_results[C.F_fixed_point_opt] + + @property + def F_vmap_fp_opt(self): + if C.F_vmap_fp_opt not in self.analyzed_results: + # The arguments of this function are: + # --- + # "X": a two-dimensional matrix: (num_batch, num_var) + # "args": a list of one-dimensional vectors, each has the shape of (num_batch,) + self.analyzed_results[C.F_vmap_fp_opt] = jax.jit(vmap(self.F_fixed_point_opt)) + return self.analyzed_results[C.F_vmap_fp_opt] + + def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None): + """ + + "candidates" and "args" can be obtained through: + + >>> all_candidates = [] + >>> all_par1 = [] + >>> all_par2 = [] + >>> for p1 in par1_list: + >>> for p2 in par2_list: + >>> xs = self.resolutions[self.x_var] + >>> all_candidates.append(xs) + >>> all_par1.append(jnp.ones_like(xs) * p1) + >>> all_par2.append(jnp.ones_like(xs) * p2) + + Parameters:: + + candidates + args + tol_aux + loss_screen + + Returns:: + + """ + # candidates: xs, a vector with the length of self.resolutions[self.x_var] + # args: parameters, a list/tuple of vectors + candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates + selected_ids = np.arange(len(candidates)) + args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args) + for a in args: assert len(a) == len(candidates) + if num_seg is None: + num_seg = len(self.resolutions[self.x_var]) + 
assert isinstance(num_seg, int) + + # get the signs + signs = jnp.sign(self.F_vmap_fx(candidates, *args)) + signs = signs.reshape((num_seg, -1)) + par_len = signs.shape[1] + signs1 = signs.at[-1].set(1) + signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) + ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] + if len(ids) <= 0: + return [], [], [] + + # selected the proper candidates to optimize fixed points + selected_ids = selected_ids[np.asarray(ids)] + starts = candidates[ids] + ends = candidates[ids + par_len] + X = jnp.stack((starts, ends)).T + args = tuple(a[ids] for a in args) + + # optimize the fixed points + res = self.F_vmap_fp_opt(X, *args) + losses = self.F_vmap_fp_aux(res['root'], *args) + valid_or_not = jnp.logical_and(res['status'] == utils.ECONVERGED, losses <= tol_aux) + ids = np.asarray(jnp.where(valid_or_not)[0]) + fps = np.asarray(res['root'])[ids] + args = tuple(a[ids] for a in args) + selected_ids = selected_ids[np.asarray(ids)] + return fps, selected_ids, args - However, after the stransformation, this function should be called as: - >>> self.F_fy(v1, v2, p1, p2) - """ - if C.F_fy not in self.analyzed_results: - variables, arguments = utils.get_args(self.model.f_derivatives[self.y_var]) - wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) - f = wrapper(self.model.f_derivatives[self.y_var]) - f = partial(f, **(self.pars_update + self.fixed_vars)) - f = utils.f_without_jaxarray_return(f) - f = utils.remove_return_shape(f) - self.analyzed_results[C.F_fy] = jax.jit(f, device=self.jit_device) - return self.analyzed_results[C.F_fy] - - @property - def F_int_x(self): - if C.F_int_x not in self.analyzed_results: - wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.x_var])[1], - self.target_var_names, self.target_par_names) - init_x = partial(wrap_x(self.model.f_integrals[0]), **(self.pars_update + self.fixed_vars)) - self.analyzed_results[C.F_int_x] = init_x - return 
self.analyzed_results[C.F_int_x] - - @property - def F_int_y(self): - if C.F_int_y not in self.analyzed_results: - wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.y_var])[1], - self.target_var_names, self.target_par_names) - init_x = partial(wrap_x(self.model.f_integrals[1]), **(self.pars_update + self.fixed_vars)) - self.analyzed_results[C.F_int_y] = init_x - return self.analyzed_results[C.F_int_y] - - @property - def F_x_by_y_in_fx(self): - if C.F_x_by_y_in_fx not in self.analyzed_results: - if C.x_by_y_in_fx in self.options: - wrapper = utils.std_func(utils.get_args(self.options[C.x_by_y_in_fx], gather_var=False), - self.target_var_names[1:], - self.target_par_names) - f = wrapper(self.options[C.x_by_y_in_fx]) - f = partial(f, **(self.pars_update + self.fixed_vars)) - f = utils.f_without_jaxarray_return(f) - self.analyzed_results[C.F_x_by_y_in_fx] = f - else: - self.analyzed_results[C.F_x_by_y_in_fx] = None - return self.analyzed_results[C.F_x_by_y_in_fx] - - @property - def F_y_by_x_in_fx(self): - if C.F_y_by_x_in_fx not in self.analyzed_results: - if C.y_by_x_in_fx in self.options: - wrapper = utils.std_func(utils.get_args(self.options[C.y_by_x_in_fx], gather_var=False), - self.target_var_names[:1] + self.target_var_names[2:], - self.target_par_names) - f = wrapper(self.options[C.y_by_x_in_fx]) - f = partial(f, **(self.pars_update + self.fixed_vars)) - f = utils.f_without_jaxarray_return(f) - self.analyzed_results[C.F_y_by_x_in_fx] = f - else: - self.analyzed_results[C.F_y_by_x_in_fx] = None - return self.analyzed_results[C.F_y_by_x_in_fx] - - @property - def F_x_by_y_in_fy(self): - if C.F_x_by_y_in_fy not in self.analyzed_results: - if C.x_by_y_in_fy in self.options: - wrapper = utils.std_func(utils.get_args(self.options[C.x_by_y_in_fy], gather_var=False), - self.target_var_names[1:], - self.target_par_names) - f = wrapper(self.options[C.x_by_y_in_fy]) - f = partial(f, **(self.pars_update + self.fixed_vars)) - f = 
utils.f_without_jaxarray_return(f) - self.analyzed_results[C.F_x_by_y_in_fy] = f - else: - self.analyzed_results[C.F_x_by_y_in_fy] = None - return self.analyzed_results[C.F_x_by_y_in_fy] - - @property - def F_y_by_x_in_fy(self): - if C.F_y_by_x_in_fy not in self.analyzed_results: - if C.y_by_x_in_fy in self.options: - wrapper = utils.std_func(utils.get_args(self.options[C.y_by_x_in_fy], gather_var=False), - self.target_var_names[:1] + self.target_var_names[2:], - self.target_par_names) - f = wrapper(self.options[C.y_by_x_in_fy]) - f = partial(f, **(self.pars_update + self.fixed_vars)) - f = utils.f_without_jaxarray_return(f) - self.analyzed_results[C.F_y_by_x_in_fy] = f - else: - self.analyzed_results[C.F_y_by_x_in_fy] = None - return self.analyzed_results[C.F_y_by_x_in_fy] - - @property - def F_vmap_fy(self): - if C.F_vmap_fy not in self.analyzed_results: - self.analyzed_results[C.F_vmap_fy] = jax.jit(vmap(self.F_fy), device=self.jit_device) - return self.analyzed_results[C.F_vmap_fy] - - @property - def F_dfxdy(self): - r"""The function to evaluate :math:`\frac{df_x (*\mathrm{vars}, *\mathrm{pars})}{dy}`.""" - if C.F_dfxdy not in self.analyzed_results: - dfxdy = bm.vector_grad(self.F_fx, argnums=1) - self.analyzed_results[C.F_dfxdy] = jax.jit(dfxdy, device=self.jit_device) - return self.analyzed_results[C.F_dfxdy] - - @property - def F_dfydx(self): - r"""The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dx}`.""" - if C.F_dfydx not in self.analyzed_results: - dfydx = bm.vector_grad(self.F_fy, argnums=0) - self.analyzed_results[C.F_dfydx] = jax.jit(dfydx, device=self.jit_device) - return self.analyzed_results[C.F_dfydx] - - @property - def F_dfydy(self): - r"""The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dy}`.""" - if C.F_dfydy not in self.analyzed_results: - dfydy = bm.vector_grad(self.F_fy, argnums=1) - self.analyzed_results[C.F_dfydy] = jax.jit(dfydy, device=self.jit_device) - return 
self.analyzed_results[C.F_dfydy] - - @property - def F_jacobian(self): - r"""The function to evaluate :math:`J(*\mathrm{vars}, *\mathrm{pars})`.""" - if C.F_jacobian not in self.analyzed_results: - @partial(bm.jacobian, argnums=(0, 1)) - def f_jacobian(*var_and_pars): - return self.F_fx(*var_and_pars), self.F_fy(*var_and_pars) - - def call(*var_and_pars): - var_and_pars = tuple((vp.value if isinstance(vp, bm.BaseArray) else vp) for vp in var_and_pars) - return jnp.array(jax.jit(f_jacobian, device=self.jit_device)(*var_and_pars)) - - self.analyzed_results[C.F_jacobian] = call - return self.analyzed_results[C.F_jacobian] - - @property - def F_fixed_point_opt(self): - if C.F_fixed_point_opt not in self.analyzed_results: - if self._can_convert_to_one_eq(): - if self.convert_type() == C.x_by_y: - def f(start_and_end, *args): - return utils.jax_brentq(self.F_y_convert[1])(start_and_end[0], start_and_end[1], args) - else: - def f(start_and_end, *args): - return utils.jax_brentq(self.F_x_convert[1])(start_and_end[0], start_and_end[1], args) - self.analyzed_results[C.F_fixed_point_opt] = f - - else: - # If cannot convert to one variable equation - def opt_fun(xy_init, *args): - # "xy_init" is a vector with length 2, - # "args: is a tuple of scalar - return minimize(self.F_fixed_point_aux, xy_init, args=args, method='BFGS') - - self.analyzed_results[C.F_fixed_point_opt] = opt_fun - return self.analyzed_results[C.F_fixed_point_opt] - - @property - def F_fixed_point_aux(self): - if C.F_fixed_point_aux not in self.analyzed_results: - if self._can_convert_to_one_eq(): - if self.convert_type() == C.x_by_y: - f = lambda y, *args: jnp.abs(self.F_y_convert[1](y, *args)).sum() - else: - f = lambda x, *args: jnp.abs(self.F_x_convert[1](x, *args)).sum() - self.analyzed_results[C.F_fixed_point_aux] = f - - else: - def aux_fun(xy, *args): - # "xy" is a vector with length 2, - # "args": is a tuple of scalar - dx = self.F_fx(xy[0], xy[1], *args) - dy = self.F_fy(xy[0], xy[1], *args) - # 
return (jnp.abs(dx) + jnp.abs(dy)).sum() - return (dx ** 2 + dy ** 2).sum() - - self.analyzed_results[C.F_fixed_point_aux] = aux_fun - return self.analyzed_results[C.F_fixed_point_aux] - - def _can_convert_to_one_eq(self): - if self.F_x_by_y_in_fx is not None: - return True - if self.F_x_by_y_in_fy is not None: - return True - if self.F_y_by_x_in_fx is not None: - return True - if self.F_y_by_x_in_fy is not None: - return True - return False - - def convert_type(self): - if self.F_x_by_y_in_fx is not None: - return C.x_by_y - if self.F_x_by_y_in_fy is not None: - return C.x_by_y - if self.F_y_by_x_in_fx is not None: - return C.y_by_x - if self.F_y_by_x_in_fy is not None: - return C.y_by_x - raise errors.AnalyzerError - - @property - def F_y_convert(self): - if C.F_y_convert not in self.analyzed_results: - if self.F_x_by_y_in_fy is not None: - f = lambda y, *pars: self.F_fx(self.F_x_by_y_in_fy(y, *pars), y, *pars) - res = (self.F_x_by_y_in_fy, f) - elif self.F_x_by_y_in_fx is not None: - f = lambda y, *pars: self.F_fy(self.F_x_by_y_in_fx(y, *pars), y, *pars) - res = (self.F_x_by_y_in_fx, f) - else: - res = None - self.analyzed_results[C.F_y_convert] = res - return self.analyzed_results[C.F_y_convert] - - @property - def F_x_convert(self): - if C.F_x_convert not in self.analyzed_results: - if self.F_y_by_x_in_fy is not None: - f = lambda x, *pars: self.F_fx(x, self.F_y_by_x_in_fy(x, *pars), *pars) - res = (self.F_y_by_x_in_fy, f) - elif self.F_y_by_x_in_fx is not None: - f = lambda x, *pars: self.F_fy(x, self.F_y_by_x_in_fx(x, *pars), *pars) - res = (self.F_y_by_x_in_fx, f) - else: - res = None - self.analyzed_results[C.F_x_convert] = res - return self.analyzed_results[C.F_x_convert] - - def _fp_filter(self, x_values, y_values, par_values, aux_filter=0.): - if aux_filter > 0.: - losses = self.F_vmap_fp_aux(jnp.stack([x_values, y_values]).T, *par_values) - ids = jnp.where(losses < aux_filter)[0] - x_values = x_values[ids] - y_values = y_values[ids] - par_values = 
tuple(p[ids] for p in par_values) - return x_values, y_values, par_values - - def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux_filter=0.): - coords = (self.x_var + '-' + self.y_var) if coords is None else coords - key = C.fx_nullcline_points + ',' + coords - if key not in self.analyzed_results: - all_losses = [] - all_x_values_in_fx = [] - all_y_values_in_fx = [] - all_p_values_in_fx = tuple([] for _ in range(len(self.target_par_names))) - - # points of variables and parameters - xs = self.resolutions[self.x_var] - ys = self.resolutions[self.y_var] - par_seg = utils.Segment(targets=tuple(self.resolutions[p] for p in self.target_par_names), - num_segments=num_segments) - - if self.F_x_by_y_in_fx is not None: - utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...") - vmap_f = jax.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device) - for j, pars in enumerate(par_seg): - if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") - mesh_values = jnp.meshgrid(*((ys,) + pars)) - x_values_in_fx = vmap_f(*mesh_values) - y_values_in_fx = mesh_values[0] - p_values_in_fx = mesh_values[1:] - losses = self.F_vmap_fx(x_values_in_fx, y_values_in_fx, *p_values_in_fx) - all_losses.append(losses) - all_x_values_in_fx.append(x_values_in_fx) - all_y_values_in_fx.append(y_values_in_fx) - for i, arg in enumerate(p_values_in_fx): - all_p_values_in_fx[i].append(arg) - - elif self.F_y_by_x_in_fx is not None: - utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...") - vmap_f = jax.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device) - for j, pars in enumerate(par_seg): - if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") - mesh_values = jnp.meshgrid(*((xs,) + pars)) - y_values_in_fx = vmap_f(*mesh_values) - x_values_in_fx = mesh_values[0] - p_values_in_fx = mesh_values[1:] - losses = self.F_vmap_fx(x_values_in_fx, y_values_in_fx, *p_values_in_fx) - all_losses.append(losses) - 
all_x_values_in_fx.append(x_values_in_fx) - all_y_values_in_fx.append(y_values_in_fx) - for i, arg in enumerate(p_values_in_fx): - all_p_values_in_fx[i].append(arg) - - else: - utils.output("I am evaluating fx-nullcline by optimization ...") - # auxiliary functions - f2 = lambda y, x, *pars: self.F_fx(x, y, *pars) - vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) - vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) - vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device) - - # num segments - for _j, Ps in enumerate(par_seg): - if len(par_seg.arg_id_segments[0]) > 1: - utils.output(f"{C.prefix}segment {_j} ...") - if coords == self.x_var + '-' + self.y_var: - x0s, x1s, vps = utils.brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps)) - x_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f1, x0s, x1s, *vps) - y_values_in_fx = out_args[0] - p_values_in_fx = out_args[1:] - x_values_in_fx, y_values_in_fx, p_values_in_fx = \ - self._fp_filter(x_values_in_fx, y_values_in_fx, p_values_in_fx, fp_aux_filter) - elif coords == self.y_var + '-' + self.x_var: - x0s, x1s, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) - y_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f2, x0s, x1s, *vps) - x_values_in_fx = out_args[0] - p_values_in_fx = out_args[1:] - x_values_in_fx, y_values_in_fx, p_values_in_fx = \ - self._fp_filter(x_values_in_fx, y_values_in_fx, p_values_in_fx, fp_aux_filter) - else: - raise ValueError - losses = self.F_vmap_fx(x_values_in_fx, y_values_in_fx, *p_values_in_fx) - all_losses.append(losses) - all_x_values_in_fx.append(x_values_in_fx) - all_y_values_in_fx.append(y_values_in_fx) - for i, arg in enumerate(p_values_in_fx): - all_p_values_in_fx[i].append(arg) - - all_losses = jnp.concatenate(all_losses) - all_x_values_in_fx = jnp.concatenate(all_x_values_in_fx) - all_y_values_in_fx = jnp.concatenate(all_y_values_in_fx) - all_p_values_in_fx = tuple(jnp.concatenate(p) for p in 
all_p_values_in_fx) - ids = jnp.where(all_losses < tol)[0] - all_x_values_in_fx = all_x_values_in_fx[ids] - all_y_values_in_fx = all_y_values_in_fx[ids] - all_p_values_in_fx = tuple(a[ids] for a in all_p_values_in_fx) - all_xy_values = jnp.stack([all_x_values_in_fx, all_y_values_in_fx]).T - self.analyzed_results[key] = (all_xy_values,) + all_p_values_in_fx - return self.analyzed_results[key] - - def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux_filter=0.): - coords = (self.x_var + '-' + self.y_var) if coords is None else coords - key = C.fy_nullcline_points + ',' + coords - if key not in self.analyzed_results: - all_losses = [] - all_x_values_in_fy = [] - all_y_values_in_fy = [] - all_p_values_in_fy = tuple([] for _ in range(len(self.target_par_names))) - - xs = self.resolutions[self.x_var] - ys = self.resolutions[self.y_var] - par_seg = utils.Segment(tuple(self.resolutions[p] for p in self.target_par_names), - num_segments=num_segments) - - if self.F_x_by_y_in_fy is not None: - utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...") - vmap_f = jax.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device) - for j, pars in enumerate(par_seg): - if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") - mesh_values = jnp.meshgrid(*((ys,) + pars)) - x_values_in_fy = vmap_f(*mesh_values) - y_values_in_fy = mesh_values[0] - p_values_in_fy = mesh_values[1:] - losses = self.F_vmap_fy(x_values_in_fy, y_values_in_fy, *p_values_in_fy) - all_losses.append(losses) - all_x_values_in_fy.append(x_values_in_fy) - all_y_values_in_fy.append(y_values_in_fy) - for i, arg in enumerate(p_values_in_fy): - all_p_values_in_fy[i].append(arg) - - elif self.F_y_by_x_in_fy is not None: - utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...") - vmap_f = jax.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device) - for j, pars in enumerate(par_seg): - if len(par_seg.arg_id_segments[0]) > 1: 
utils.output(f"{C.prefix}segment {j} ...") - mesh_values = jnp.meshgrid(*((xs,) + pars)) - y_values_in_fy = vmap_f(*mesh_values) - x_values_in_fy = mesh_values[0] - p_values_in_fy = mesh_values[1:] - losses = self.F_vmap_fy(x_values_in_fy, y_values_in_fy, *p_values_in_fy) - all_losses.append(losses) - all_x_values_in_fy.append(x_values_in_fy) - all_y_values_in_fy.append(y_values_in_fy) - for i, arg in enumerate(p_values_in_fy): - all_p_values_in_fy[i].append(arg) - - else: - utils.output("I am evaluating fy-nullcline by optimization ...") - - # auxiliary functions - f2 = lambda y, x, *pars: self.F_fy(x, y, *pars) - vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) - vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) - vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device) - - for j, Ps in enumerate(par_seg): - if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") - if coords == self.x_var + '-' + self.y_var: - starts, ends, vps = utils.brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps)) - x_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f1, starts, ends, *vps) - y_values_in_fy = out_args[0] - p_values_in_fy = out_args[1:] - x_values_in_fy, y_values_in_fy, p_values_in_fy = \ - self._fp_filter(x_values_in_fy, y_values_in_fy, p_values_in_fy, fp_aux_filter) - elif coords == self.y_var + '-' + self.x_var: - starts, ends, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) - y_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f2, starts, ends, *vps) - x_values_in_fy = out_args[0] - p_values_in_fy = out_args[1:] - x_values_in_fy, y_values_in_fy, p_values_in_fy = \ - self._fp_filter(x_values_in_fy, y_values_in_fy, p_values_in_fy, fp_aux_filter) - else: - raise ValueError - losses = self.F_vmap_fy(x_values_in_fy, y_values_in_fy, *p_values_in_fy) - all_losses.append(losses) - all_x_values_in_fy.append(x_values_in_fy) - all_y_values_in_fy.append(y_values_in_fy) - for i, 
arg in enumerate(p_values_in_fy): - all_p_values_in_fy[i].append(arg) - all_losses = jnp.concatenate(all_losses) - all_x_values_in_fy = jnp.concatenate(all_x_values_in_fy) - all_y_values_in_fy = jnp.concatenate(all_y_values_in_fy) - all_p_values_in_fy = tuple(jnp.concatenate(p) for p in all_p_values_in_fy) - ids = jnp.where(all_losses < tol)[0] - all_x_values_in_fy = all_x_values_in_fy[ids] - all_y_values_in_fy = all_y_values_in_fy[ids] - all_p_values_in_fy = tuple(a[ids] for a in all_p_values_in_fy) - all_xy_values = jnp.stack([all_x_values_in_fy, all_y_values_in_fy]).T - self.analyzed_results[key] = (all_xy_values,) + all_p_values_in_fy - return self.analyzed_results[key] - - def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100): - utils.output(f"I am filtering out fixed point candidates with auxiliary function ...") - all_xs = [] - all_ys = [] - all_ps = tuple([] for _ in range(len(self.target_par_names))) - - # points of variables and parameters - xs = self.resolutions[self.x_var] - ys = self.resolutions[self.y_var] - P = tuple(self.resolutions[p] for p in self.target_par_names) - f_select = jax.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1))) - - # num seguments - if isinstance(num_segments, int): - num_segments = tuple([num_segments] * len(self.target_par_names)) - assert isinstance(num_segments, (tuple, list)) and len(num_segments) == len(self.target_par_names) - arg_lens = tuple(len(p) for p in P) - arg_pre_len = tuple(int(np.ceil(l / num_segments[i])) for i, l in enumerate(arg_lens)) - arg_id_segments = tuple(np.arange(0, l, arg_pre_len[i]) for i, l in enumerate(arg_lens)) - arg_id_segments = tuple(ids.flatten() for ids in np.meshgrid(*arg_id_segments)) - if len(arg_id_segments) == 0: - arg_id_segments = ((0,),) - for _j, ids in enumerate(zip(*arg_id_segments)): - if len(arg_id_segments[0]) > 1: - utils.output(f"{C.prefix}segment {_j} ...") - - ps = tuple(p[ids[i]: ids[i] + arg_pre_len[i]] for i, p in enumerate(P)) - # change the 
position of meshgrid values - vps = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in ((xs, ys) + ps)) - mesh_values = jnp.meshgrid(*vps) - mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) - mesh_values = tuple(m.flatten() for m in mesh_values) - # function outputs - losses = self.F_vmap_fp_aux(jnp.stack([mesh_values[0], mesh_values[1]]).T, *mesh_values[2:]) - shape = (len(xs) * len(ys), -1) - losses = losses.reshape(shape) - argsorts = jnp.argsort(losses, axis=0)[:num_rank] - all_xs.append(f_select(mesh_values[0].reshape(shape), argsorts).flatten()) - all_ys.append(f_select(mesh_values[1].reshape(shape), argsorts).flatten()) - for i, p in enumerate(ps): - all_ps[i].append(f_select(mesh_values[i + 2].reshape(shape), argsorts).flatten()) - all_xys = jnp.vstack([jnp.concatenate(all_xs), jnp.concatenate(all_ys)]).T - all_ps = tuple(jnp.concatenate(p) for p in all_ps) - return (all_xys, all_ps) - - def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, - tol_unique=1e-2, tol_opt_candidate=None, - num_segment=1): - """Get the fixed points according to the initial ``candidates`` and the parameter setting ``args``. - - "candidates" and "args" can be obtained through: - - >>> all_candidates = [] - >>> all_par1 = [] - >>> all_par2 = [] - >>> for p1 in par1_list: - >>> for p2 in par2_list: - >>> nullcline_points = _get_nullcline_points(p1, p2) - >>> all_candidates.append(nullcline_points) - >>> all_par1.append(jnp.ones_like(nullcline_points) * p1) - >>> all_par2.append(jnp.ones_like(nullcline_points) * p2) +class Num2DAnalyzer(Num1DAnalyzer): + r"""Analyzer for two-dimensional dynamical system. - Parameters:: + It supports the analysis for 2D dynamical system. - candidates: np.ndarray, jnp.ndarray - The candidate points (batched) to optimize, like the nullcline points. - args : tuple - The parameters (batched). - tol_aux : float - tol_unique : float - tol_opt_candidate : float, optional + .. 
math:: - Returns:: + {dx \over dt} = fx(x, t, y) - res : tuple - The fixed point results. + {dy \over dt} = fy(y, t, x) """ - if self._can_convert_to_one_eq(): - utils.output("I am trying to find fixed points by brentq optimization ...") - - # candidates: xs, a vector with the length of self.resolutions[self.x_var] - # args: parameters, a list/tuple of vectors - candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates - selected_ids = np.arange(len(candidates)) - args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args) - for a in args: assert len(a) == len(candidates) - - if self.convert_type() == C.x_by_y: - num_seg = len(self.resolutions[self.y_var]) - f_vmap = jax.jit(vmap(self.F_y_convert[1])) - else: - num_seg = len(self.resolutions[self.x_var]) - f_vmap = jax.jit(vmap(self.F_x_convert[1])) - # get the signs - signs = jnp.sign(f_vmap(candidates, *args)) - signs = signs.reshape((num_seg, -1)) - par_len = signs.shape[1] - signs1 = signs.at[-1].set(1) - signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) - ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] - if len(ids) <= 0: - return [], [], [] - - # selected the proper candidates to optimize fixed points - selected_ids = selected_ids[np.asarray(ids)] - starts = candidates[ids] - ends = candidates[ids + par_len] - X = jnp.stack((starts, ends)).T - args = tuple(a[ids] for a in args) - - # optimize the fixed points - res = self.F_vmap_fp_opt(X, *args) - losses = self.F_vmap_fp_aux(res['root'], *args) - valid_or_not = jnp.logical_and(res['status'] == utils.ECONVERGED, losses <= tol_aux) - ids = np.asarray(jnp.where(valid_or_not)[0]) - fps = np.asarray(res['root'])[ids] - args = tuple(a[ids] for a in args) - selected_ids = selected_ids[np.asarray(ids)] - - # get another value - if self.convert_type() == C.x_by_y: - y_values = fps - x_values = jax.jit(vmap(self.F_y_convert[0]))(y_values, *args) - else: - x_values = fps - y_values = 
jax.jit(vmap(self.F_x_convert[0]))(x_values, *args) - fps = jnp.stack([x_values, y_values]).T - return fps, selected_ids, args - - else: - utils.output("I am trying to find fixed points by optimization ...") - utils.output(f"{C.prefix}There are {len(candidates)} candidates") - - candidates = jnp.asarray(candidates) - args = tuple(jnp.asarray(a) for a in args) - - all_ids = [] - all_fps = [] - all_args = tuple([] for _ in range(len(args))) - seg_len = int(np.ceil(len(candidates) / num_segment)) - segment_ids = np.arange(0, len(candidates), seg_len) - selected_ids = jnp.arange(len(candidates)) - - for _j, i in enumerate(segment_ids): - if len(segment_ids) > 1: - utils.output(f"{C.prefix}segment {_j} ...") - seg_fps = candidates[i: i + seg_len] - seg_args = tuple(a[i: i + seg_len] for a in args) - seg_ids = selected_ids[i: i + seg_len] - - if tol_opt_candidate is not None: - # screen by the function loss - losses = self.F_vmap_fp_aux(seg_fps, *seg_args) - ids = jnp.where(losses < tol_opt_candidate)[0] - seg_fps = seg_fps[ids] - seg_args = tuple(a[ids] for a in seg_args) - seg_ids = seg_ids[ids] - if len(seg_fps): - # optimization - seg_fps = self.F_vmap_fp_opt(seg_fps, *seg_args) - # loss - losses = self.F_vmap_fp_aux(seg_fps.x, *seg_args) - # valid indices - ids = jnp.where(losses <= tol_aux)[0] - seg_ids = seg_ids[ids] - all_fps.append(seg_fps.x[ids]) - all_ids.append(seg_ids) - for i in range(len(all_args)): - all_args[i].append(seg_args[i][ids]) - all_fps = jnp.concatenate(all_fps) - all_ids = jnp.concatenate(all_ids) - all_args = tuple(jnp.concatenate(args) for args in all_args) - return all_fps, all_ids, all_args + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if len(self.target_vars) < 2: + raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system ' + f'with >= 2 variables. 
But we got {len(self.target_vars)} ' + f'variables in {self.model}.') + self.y_var = self.target_var_names[1] + + @property + def F_fy(self): + r"""The function to evaluate :math:`f_y(*\mathrm{vars}, *\mathrm{pars})`. + + This function has been transformed into the standard call. + For instance, if the user has the ``target_vars=("v1", "v2")`` and + the ``target_pars=("p1", "p2")``, while the first function is defined as: + + >>> def f1(v1, t, p1): + >>> return something + + However, after the stransformation, this function should be called as: + + >>> self.F_fy(v1, v2, p1, p2) + """ + if C.F_fy not in self.analyzed_results: + variables, arguments = utils.get_args(self.model.f_derivatives[self.y_var]) + wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) + f = wrapper(self.model.f_derivatives[self.y_var]) + f = partial(f, **(self.pars_update + self.fixed_vars)) + f = utils.f_without_jaxarray_return(f) + f = utils.remove_return_shape(f) + self.analyzed_results[C.F_fy] = jax.jit(f, device=self.jit_device) + return self.analyzed_results[C.F_fy] + + @property + def F_int_x(self): + if C.F_int_x not in self.analyzed_results: + wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.x_var])[1], + self.target_var_names, self.target_par_names) + init_x = partial(wrap_x(self.model.f_integrals[0]), **(self.pars_update + self.fixed_vars)) + self.analyzed_results[C.F_int_x] = init_x + return self.analyzed_results[C.F_int_x] + + @property + def F_int_y(self): + if C.F_int_y not in self.analyzed_results: + wrap_x = utils.std_derivative(utils.get_args(self.model.f_derivatives[self.y_var])[1], + self.target_var_names, self.target_par_names) + init_x = partial(wrap_x(self.model.f_integrals[1]), **(self.pars_update + self.fixed_vars)) + self.analyzed_results[C.F_int_y] = init_x + return self.analyzed_results[C.F_int_y] + + @property + def F_x_by_y_in_fx(self): + if C.F_x_by_y_in_fx not in self.analyzed_results: + if 
C.x_by_y_in_fx in self.options: + wrapper = utils.std_func(utils.get_args(self.options[C.x_by_y_in_fx], gather_var=False), + self.target_var_names[1:], + self.target_par_names) + f = wrapper(self.options[C.x_by_y_in_fx]) + f = partial(f, **(self.pars_update + self.fixed_vars)) + f = utils.f_without_jaxarray_return(f) + self.analyzed_results[C.F_x_by_y_in_fx] = f + else: + self.analyzed_results[C.F_x_by_y_in_fx] = None + return self.analyzed_results[C.F_x_by_y_in_fx] + + @property + def F_y_by_x_in_fx(self): + if C.F_y_by_x_in_fx not in self.analyzed_results: + if C.y_by_x_in_fx in self.options: + wrapper = utils.std_func(utils.get_args(self.options[C.y_by_x_in_fx], gather_var=False), + self.target_var_names[:1] + self.target_var_names[2:], + self.target_par_names) + f = wrapper(self.options[C.y_by_x_in_fx]) + f = partial(f, **(self.pars_update + self.fixed_vars)) + f = utils.f_without_jaxarray_return(f) + self.analyzed_results[C.F_y_by_x_in_fx] = f + else: + self.analyzed_results[C.F_y_by_x_in_fx] = None + return self.analyzed_results[C.F_y_by_x_in_fx] + + @property + def F_x_by_y_in_fy(self): + if C.F_x_by_y_in_fy not in self.analyzed_results: + if C.x_by_y_in_fy in self.options: + wrapper = utils.std_func(utils.get_args(self.options[C.x_by_y_in_fy], gather_var=False), + self.target_var_names[1:], + self.target_par_names) + f = wrapper(self.options[C.x_by_y_in_fy]) + f = partial(f, **(self.pars_update + self.fixed_vars)) + f = utils.f_without_jaxarray_return(f) + self.analyzed_results[C.F_x_by_y_in_fy] = f + else: + self.analyzed_results[C.F_x_by_y_in_fy] = None + return self.analyzed_results[C.F_x_by_y_in_fy] + + @property + def F_y_by_x_in_fy(self): + if C.F_y_by_x_in_fy not in self.analyzed_results: + if C.y_by_x_in_fy in self.options: + wrapper = utils.std_func(utils.get_args(self.options[C.y_by_x_in_fy], gather_var=False), + self.target_var_names[:1] + self.target_var_names[2:], + self.target_par_names) + f = wrapper(self.options[C.y_by_x_in_fy]) + f = 
partial(f, **(self.pars_update + self.fixed_vars)) + f = utils.f_without_jaxarray_return(f) + self.analyzed_results[C.F_y_by_x_in_fy] = f + else: + self.analyzed_results[C.F_y_by_x_in_fy] = None + return self.analyzed_results[C.F_y_by_x_in_fy] + + @property + def F_vmap_fy(self): + if C.F_vmap_fy not in self.analyzed_results: + self.analyzed_results[C.F_vmap_fy] = jax.jit(vmap(self.F_fy), device=self.jit_device) + return self.analyzed_results[C.F_vmap_fy] + + @property + def F_dfxdy(self): + r"""The function to evaluate :math:`\frac{df_x (*\mathrm{vars}, *\mathrm{pars})}{dy}`.""" + if C.F_dfxdy not in self.analyzed_results: + dfxdy = bm.vector_grad(self.F_fx, argnums=1) + self.analyzed_results[C.F_dfxdy] = jax.jit(dfxdy, device=self.jit_device) + return self.analyzed_results[C.F_dfxdy] + + @property + def F_dfydx(self): + r"""The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dx}`.""" + if C.F_dfydx not in self.analyzed_results: + dfydx = bm.vector_grad(self.F_fy, argnums=0) + self.analyzed_results[C.F_dfydx] = jax.jit(dfydx, device=self.jit_device) + return self.analyzed_results[C.F_dfydx] + + @property + def F_dfydy(self): + r"""The function to evaluate :math:`\frac{df_y (*\mathrm{vars}, *\mathrm{pars})}{dy}`.""" + if C.F_dfydy not in self.analyzed_results: + dfydy = bm.vector_grad(self.F_fy, argnums=1) + self.analyzed_results[C.F_dfydy] = jax.jit(dfydy, device=self.jit_device) + return self.analyzed_results[C.F_dfydy] + + @property + def F_jacobian(self): + r"""The function to evaluate :math:`J(*\mathrm{vars}, *\mathrm{pars})`.""" + if C.F_jacobian not in self.analyzed_results: + @partial(bm.jacobian, argnums=(0, 1)) + def f_jacobian(*var_and_pars): + return self.F_fx(*var_and_pars), self.F_fy(*var_and_pars) + + def call(*var_and_pars): + var_and_pars = tuple((vp.value if isinstance(vp, bm.BaseArray) else vp) for vp in var_and_pars) + return jnp.array(jax.jit(f_jacobian, device=self.jit_device)(*var_and_pars)) + + 
self.analyzed_results[C.F_jacobian] = call + return self.analyzed_results[C.F_jacobian] + + @property + def F_fixed_point_opt(self): + if C.F_fixed_point_opt not in self.analyzed_results: + if self._can_convert_to_one_eq(): + if self.convert_type() == C.x_by_y: + def f(start_and_end, *args): + return utils.jax_brentq(self.F_y_convert[1])(start_and_end[0], start_and_end[1], args) + else: + def f(start_and_end, *args): + return utils.jax_brentq(self.F_x_convert[1])(start_and_end[0], start_and_end[1], args) + self.analyzed_results[C.F_fixed_point_opt] = f + + else: + # If cannot convert to one variable equation + def opt_fun(xy_init, *args): + # "xy_init" is a vector with length 2, + # "args: is a tuple of scalar + return minimize(self.F_fixed_point_aux, xy_init, args=args, method='BFGS') + + self.analyzed_results[C.F_fixed_point_opt] = opt_fun + return self.analyzed_results[C.F_fixed_point_opt] + + @property + def F_fixed_point_aux(self): + if C.F_fixed_point_aux not in self.analyzed_results: + if self._can_convert_to_one_eq(): + if self.convert_type() == C.x_by_y: + f = lambda y, *args: jnp.abs(self.F_y_convert[1](y, *args)).sum() + else: + f = lambda x, *args: jnp.abs(self.F_x_convert[1](x, *args)).sum() + self.analyzed_results[C.F_fixed_point_aux] = f + + else: + def aux_fun(xy, *args): + # "xy" is a vector with length 2, + # "args": is a tuple of scalar + dx = self.F_fx(xy[0], xy[1], *args) + dy = self.F_fy(xy[0], xy[1], *args) + # return (jnp.abs(dx) + jnp.abs(dy)).sum() + return (dx ** 2 + dy ** 2).sum() + + self.analyzed_results[C.F_fixed_point_aux] = aux_fun + return self.analyzed_results[C.F_fixed_point_aux] + + def _can_convert_to_one_eq(self): + if self.F_x_by_y_in_fx is not None: + return True + if self.F_x_by_y_in_fy is not None: + return True + if self.F_y_by_x_in_fx is not None: + return True + if self.F_y_by_x_in_fy is not None: + return True + return False + + def convert_type(self): + if self.F_x_by_y_in_fx is not None: + return C.x_by_y + if 
self.F_x_by_y_in_fy is not None: + return C.x_by_y + if self.F_y_by_x_in_fx is not None: + return C.y_by_x + if self.F_y_by_x_in_fy is not None: + return C.y_by_x + raise errors.AnalyzerError + + @property + def F_y_convert(self): + if C.F_y_convert not in self.analyzed_results: + if self.F_x_by_y_in_fy is not None: + f = lambda y, *pars: self.F_fx(self.F_x_by_y_in_fy(y, *pars), y, *pars) + res = (self.F_x_by_y_in_fy, f) + elif self.F_x_by_y_in_fx is not None: + f = lambda y, *pars: self.F_fy(self.F_x_by_y_in_fx(y, *pars), y, *pars) + res = (self.F_x_by_y_in_fx, f) + else: + res = None + self.analyzed_results[C.F_y_convert] = res + return self.analyzed_results[C.F_y_convert] + + @property + def F_x_convert(self): + if C.F_x_convert not in self.analyzed_results: + if self.F_y_by_x_in_fy is not None: + f = lambda x, *pars: self.F_fx(x, self.F_y_by_x_in_fy(x, *pars), *pars) + res = (self.F_y_by_x_in_fy, f) + elif self.F_y_by_x_in_fx is not None: + f = lambda x, *pars: self.F_fy(x, self.F_y_by_x_in_fx(x, *pars), *pars) + res = (self.F_y_by_x_in_fx, f) + else: + res = None + self.analyzed_results[C.F_x_convert] = res + return self.analyzed_results[C.F_x_convert] + + def _fp_filter(self, x_values, y_values, par_values, aux_filter=0.): + if aux_filter > 0.: + losses = self.F_vmap_fp_aux(jnp.stack([x_values, y_values]).T, *par_values) + ids = jnp.where(losses < aux_filter)[0] + x_values = x_values[ids] + y_values = y_values[ids] + par_values = tuple(p[ids] for p in par_values) + return x_values, y_values, par_values + + def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux_filter=0.): + coords = (self.x_var + '-' + self.y_var) if coords is None else coords + key = C.fx_nullcline_points + ',' + coords + if key not in self.analyzed_results: + all_losses = [] + all_x_values_in_fx = [] + all_y_values_in_fx = [] + all_p_values_in_fx = tuple([] for _ in range(len(self.target_par_names))) + + # points of variables and parameters + xs = 
self.resolutions[self.x_var] + ys = self.resolutions[self.y_var] + par_seg = utils.Segment(targets=tuple(self.resolutions[p] for p in self.target_par_names), + num_segments=num_segments) + + if self.F_x_by_y_in_fx is not None: + utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...") + vmap_f = jax.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device) + for j, pars in enumerate(par_seg): + if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") + mesh_values = jnp.meshgrid(*((ys,) + pars)) + x_values_in_fx = vmap_f(*mesh_values) + y_values_in_fx = mesh_values[0] + p_values_in_fx = mesh_values[1:] + losses = self.F_vmap_fx(x_values_in_fx, y_values_in_fx, *p_values_in_fx) + all_losses.append(losses) + all_x_values_in_fx.append(x_values_in_fx) + all_y_values_in_fx.append(y_values_in_fx) + for i, arg in enumerate(p_values_in_fx): + all_p_values_in_fx[i].append(arg) + + elif self.F_y_by_x_in_fx is not None: + utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...") + vmap_f = jax.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device) + for j, pars in enumerate(par_seg): + if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") + mesh_values = jnp.meshgrid(*((xs,) + pars)) + y_values_in_fx = vmap_f(*mesh_values) + x_values_in_fx = mesh_values[0] + p_values_in_fx = mesh_values[1:] + losses = self.F_vmap_fx(x_values_in_fx, y_values_in_fx, *p_values_in_fx) + all_losses.append(losses) + all_x_values_in_fx.append(x_values_in_fx) + all_y_values_in_fx.append(y_values_in_fx) + for i, arg in enumerate(p_values_in_fx): + all_p_values_in_fx[i].append(arg) + + else: + utils.output("I am evaluating fx-nullcline by optimization ...") + # auxiliary functions + f2 = lambda y, x, *pars: self.F_fx(x, y, *pars) + vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) + vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) + vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fx)), 
device=self.jit_device) + + # num segments + for _j, Ps in enumerate(par_seg): + if len(par_seg.arg_id_segments[0]) > 1: + utils.output(f"{C.prefix}segment {_j} ...") + if coords == self.x_var + '-' + self.y_var: + x0s, x1s, vps = utils.brentq_candidates(self.F_vmap_fx, *((xs, ys) + Ps)) + x_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f1, x0s, x1s, *vps) + y_values_in_fx = out_args[0] + p_values_in_fx = out_args[1:] + x_values_in_fx, y_values_in_fx, p_values_in_fx = \ + self._fp_filter(x_values_in_fx, y_values_in_fx, p_values_in_fx, fp_aux_filter) + elif coords == self.y_var + '-' + self.x_var: + x0s, x1s, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) + y_values_in_fx, out_args = utils.brentq_roots2(vmap_brentq_f2, x0s, x1s, *vps) + x_values_in_fx = out_args[0] + p_values_in_fx = out_args[1:] + x_values_in_fx, y_values_in_fx, p_values_in_fx = \ + self._fp_filter(x_values_in_fx, y_values_in_fx, p_values_in_fx, fp_aux_filter) + else: + raise ValueError + losses = self.F_vmap_fx(x_values_in_fx, y_values_in_fx, *p_values_in_fx) + all_losses.append(losses) + all_x_values_in_fx.append(x_values_in_fx) + all_y_values_in_fx.append(y_values_in_fx) + for i, arg in enumerate(p_values_in_fx): + all_p_values_in_fx[i].append(arg) + + all_losses = jnp.concatenate(all_losses) + all_x_values_in_fx = jnp.concatenate(all_x_values_in_fx) + all_y_values_in_fx = jnp.concatenate(all_y_values_in_fx) + all_p_values_in_fx = tuple(jnp.concatenate(p) for p in all_p_values_in_fx) + ids = jnp.where(all_losses < tol)[0] + all_x_values_in_fx = all_x_values_in_fx[ids] + all_y_values_in_fx = all_y_values_in_fx[ids] + all_p_values_in_fx = tuple(a[ids] for a in all_p_values_in_fx) + all_xy_values = jnp.stack([all_x_values_in_fx, all_y_values_in_fx]).T + self.analyzed_results[key] = (all_xy_values,) + all_p_values_in_fx + return self.analyzed_results[key] + + def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux_filter=0.): + coords = (self.x_var + '-' 
+ self.y_var) if coords is None else coords + key = C.fy_nullcline_points + ',' + coords + if key not in self.analyzed_results: + all_losses = [] + all_x_values_in_fy = [] + all_y_values_in_fy = [] + all_p_values_in_fy = tuple([] for _ in range(len(self.target_par_names))) + + xs = self.resolutions[self.x_var] + ys = self.resolutions[self.y_var] + par_seg = utils.Segment(tuple(self.resolutions[p] for p in self.target_par_names), + num_segments=num_segments) + + if self.F_x_by_y_in_fy is not None: + utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...") + vmap_f = jax.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device) + for j, pars in enumerate(par_seg): + if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") + mesh_values = jnp.meshgrid(*((ys,) + pars)) + x_values_in_fy = vmap_f(*mesh_values) + y_values_in_fy = mesh_values[0] + p_values_in_fy = mesh_values[1:] + losses = self.F_vmap_fy(x_values_in_fy, y_values_in_fy, *p_values_in_fy) + all_losses.append(losses) + all_x_values_in_fy.append(x_values_in_fy) + all_y_values_in_fy.append(y_values_in_fy) + for i, arg in enumerate(p_values_in_fy): + all_p_values_in_fy[i].append(arg) + + elif self.F_y_by_x_in_fy is not None: + utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...") + vmap_f = jax.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device) + for j, pars in enumerate(par_seg): + if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") + mesh_values = jnp.meshgrid(*((xs,) + pars)) + y_values_in_fy = vmap_f(*mesh_values) + x_values_in_fy = mesh_values[0] + p_values_in_fy = mesh_values[1:] + losses = self.F_vmap_fy(x_values_in_fy, y_values_in_fy, *p_values_in_fy) + all_losses.append(losses) + all_x_values_in_fy.append(x_values_in_fy) + all_y_values_in_fy.append(y_values_in_fy) + for i, arg in enumerate(p_values_in_fy): + all_p_values_in_fy[i].append(arg) + + else: + utils.output("I am evaluating fy-nullcline by optimization ...") + + # 
auxiliary functions + f2 = lambda y, x, *pars: self.F_fy(x, y, *pars) + vmap_f2 = jax.jit(vmap(f2), device=self.jit_device) + vmap_brentq_f2 = jax.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device) + vmap_brentq_f1 = jax.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device) + + for j, Ps in enumerate(par_seg): + if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...") + if coords == self.x_var + '-' + self.y_var: + starts, ends, vps = utils.brentq_candidates(self.F_vmap_fy, *((xs, ys) + Ps)) + x_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f1, starts, ends, *vps) + y_values_in_fy = out_args[0] + p_values_in_fy = out_args[1:] + x_values_in_fy, y_values_in_fy, p_values_in_fy = \ + self._fp_filter(x_values_in_fy, y_values_in_fy, p_values_in_fy, fp_aux_filter) + elif coords == self.y_var + '-' + self.x_var: + starts, ends, vps = utils.brentq_candidates(vmap_f2, *((ys, xs) + Ps)) + y_values_in_fy, out_args = utils.brentq_roots2(vmap_brentq_f2, starts, ends, *vps) + x_values_in_fy = out_args[0] + p_values_in_fy = out_args[1:] + x_values_in_fy, y_values_in_fy, p_values_in_fy = \ + self._fp_filter(x_values_in_fy, y_values_in_fy, p_values_in_fy, fp_aux_filter) + else: + raise ValueError + losses = self.F_vmap_fy(x_values_in_fy, y_values_in_fy, *p_values_in_fy) + all_losses.append(losses) + all_x_values_in_fy.append(x_values_in_fy) + all_y_values_in_fy.append(y_values_in_fy) + for i, arg in enumerate(p_values_in_fy): + all_p_values_in_fy[i].append(arg) + all_losses = jnp.concatenate(all_losses) + all_x_values_in_fy = jnp.concatenate(all_x_values_in_fy) + all_y_values_in_fy = jnp.concatenate(all_y_values_in_fy) + all_p_values_in_fy = tuple(jnp.concatenate(p) for p in all_p_values_in_fy) + ids = jnp.where(all_losses < tol)[0] + all_x_values_in_fy = all_x_values_in_fy[ids] + all_y_values_in_fy = all_y_values_in_fy[ids] + all_p_values_in_fy = tuple(a[ids] for a in all_p_values_in_fy) + all_xy_values = 
jnp.stack([all_x_values_in_fy, all_y_values_in_fy]).T + self.analyzed_results[key] = (all_xy_values,) + all_p_values_in_fy + return self.analyzed_results[key] + + def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100): + utils.output(f"I am filtering out fixed point candidates with auxiliary function ...") + all_xs = [] + all_ys = [] + all_ps = tuple([] for _ in range(len(self.target_par_names))) + + # points of variables and parameters + xs = self.resolutions[self.x_var] + ys = self.resolutions[self.y_var] + P = tuple(self.resolutions[p] for p in self.target_par_names) + f_select = jax.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1))) + + # num seguments + if isinstance(num_segments, int): + num_segments = tuple([num_segments] * len(self.target_par_names)) + assert isinstance(num_segments, (tuple, list)) and len(num_segments) == len(self.target_par_names) + arg_lens = tuple(len(p) for p in P) + arg_pre_len = tuple(int(np.ceil(l / num_segments[i])) for i, l in enumerate(arg_lens)) + arg_id_segments = tuple(np.arange(0, l, arg_pre_len[i]) for i, l in enumerate(arg_lens)) + arg_id_segments = tuple(ids.flatten() for ids in np.meshgrid(*arg_id_segments)) + if len(arg_id_segments) == 0: + arg_id_segments = ((0,),) + for _j, ids in enumerate(zip(*arg_id_segments)): + if len(arg_id_segments[0]) > 1: + utils.output(f"{C.prefix}segment {_j} ...") + + ps = tuple(p[ids[i]: ids[i] + arg_pre_len[i]] for i, p in enumerate(P)) + # change the position of meshgrid values + vps = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in ((xs, ys) + ps)) + mesh_values = jnp.meshgrid(*vps) + mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) + mesh_values = tuple(m.flatten() for m in mesh_values) + # function outputs + losses = self.F_vmap_fp_aux(jnp.stack([mesh_values[0], mesh_values[1]]).T, *mesh_values[2:]) + shape = (len(xs) * len(ys), -1) + losses = losses.reshape(shape) + argsorts = jnp.argsort(losses, axis=0)[:num_rank] + 
all_xs.append(f_select(mesh_values[0].reshape(shape), argsorts).flatten()) + all_ys.append(f_select(mesh_values[1].reshape(shape), argsorts).flatten()) + for i, p in enumerate(ps): + all_ps[i].append(f_select(mesh_values[i + 2].reshape(shape), argsorts).flatten()) + all_xys = jnp.vstack([jnp.concatenate(all_xs), jnp.concatenate(all_ys)]).T + all_ps = tuple(jnp.concatenate(p) for p in all_ps) + return (all_xys, all_ps) + + def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, + tol_unique=1e-2, tol_opt_candidate=None, + num_segment=1): + """Get the fixed points according to the initial ``candidates`` and the parameter setting ``args``. + + "candidates" and "args" can be obtained through: + + >>> all_candidates = [] + >>> all_par1 = [] + >>> all_par2 = [] + >>> for p1 in par1_list: + >>> for p2 in par2_list: + >>> nullcline_points = _get_nullcline_points(p1, p2) + >>> all_candidates.append(nullcline_points) + >>> all_par1.append(jnp.ones_like(nullcline_points) * p1) + >>> all_par2.append(jnp.ones_like(nullcline_points) * p2) + + Parameters:: + + candidates: np.ndarray, jnp.ndarray + The candidate points (batched) to optimize, like the nullcline points. + args : tuple + The parameters (batched). + tol_aux : float + tol_unique : float + tol_opt_candidate : float, optional + + Returns:: + + res : tuple + The fixed point results. 
+ """ + + if self._can_convert_to_one_eq(): + utils.output("I am trying to find fixed points by brentq optimization ...") + + # candidates: xs, a vector with the length of self.resolutions[self.x_var] + # args: parameters, a list/tuple of vectors + candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates + selected_ids = np.arange(len(candidates)) + args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args) + for a in args: assert len(a) == len(candidates) + + if self.convert_type() == C.x_by_y: + num_seg = len(self.resolutions[self.y_var]) + f_vmap = jax.jit(vmap(self.F_y_convert[1])) + else: + num_seg = len(self.resolutions[self.x_var]) + f_vmap = jax.jit(vmap(self.F_x_convert[1])) + # get the signs + signs = jnp.sign(f_vmap(candidates, *args)) + signs = signs.reshape((num_seg, -1)) + par_len = signs.shape[1] + signs1 = signs.at[-1].set(1) + signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) + ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] + if len(ids) <= 0: + return [], [], [] + + # selected the proper candidates to optimize fixed points + selected_ids = selected_ids[np.asarray(ids)] + starts = candidates[ids] + ends = candidates[ids + par_len] + X = jnp.stack((starts, ends)).T + args = tuple(a[ids] for a in args) + + # optimize the fixed points + res = self.F_vmap_fp_opt(X, *args) + losses = self.F_vmap_fp_aux(res['root'], *args) + valid_or_not = jnp.logical_and(res['status'] == utils.ECONVERGED, losses <= tol_aux) + ids = np.asarray(jnp.where(valid_or_not)[0]) + fps = np.asarray(res['root'])[ids] + args = tuple(a[ids] for a in args) + selected_ids = selected_ids[np.asarray(ids)] + + # get another value + if self.convert_type() == C.x_by_y: + y_values = fps + x_values = jax.jit(vmap(self.F_y_convert[0]))(y_values, *args) + else: + x_values = fps + y_values = jax.jit(vmap(self.F_x_convert[0]))(x_values, *args) + fps = jnp.stack([x_values, y_values]).T + return fps, selected_ids, args + + else: + 
utils.output("I am trying to find fixed points by optimization ...") + utils.output(f"{C.prefix}There are {len(candidates)} candidates") + + candidates = jnp.asarray(candidates) + args = tuple(jnp.asarray(a) for a in args) + + all_ids = [] + all_fps = [] + all_args = tuple([] for _ in range(len(args))) + seg_len = int(np.ceil(len(candidates) / num_segment)) + segment_ids = np.arange(0, len(candidates), seg_len) + selected_ids = jnp.arange(len(candidates)) + + for _j, i in enumerate(segment_ids): + if len(segment_ids) > 1: + utils.output(f"{C.prefix}segment {_j} ...") + seg_fps = candidates[i: i + seg_len] + seg_args = tuple(a[i: i + seg_len] for a in args) + seg_ids = selected_ids[i: i + seg_len] + + if tol_opt_candidate is not None: + # screen by the function loss + losses = self.F_vmap_fp_aux(seg_fps, *seg_args) + ids = jnp.where(losses < tol_opt_candidate)[0] + seg_fps = seg_fps[ids] + seg_args = tuple(a[ids] for a in seg_args) + seg_ids = seg_ids[ids] + if len(seg_fps): + # optimization + seg_fps = self.F_vmap_fp_opt(seg_fps, *seg_args) + # loss + losses = self.F_vmap_fp_aux(seg_fps.x, *seg_args) + # valid indices + ids = jnp.where(losses <= tol_aux)[0] + seg_ids = seg_ids[ids] + all_fps.append(seg_fps.x[ids]) + all_ids.append(seg_ids) + for i in range(len(all_args)): + all_args[i].append(seg_args[i][ids]) + all_fps = jnp.concatenate(all_fps) + all_ids = jnp.concatenate(all_ids) + all_args = tuple(jnp.concatenate(args) for args in all_args) + return all_fps, all_ids, all_args class Num3DAnalyzer(Num2DAnalyzer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if len(self.target_vars) < 3: - raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system ' - f'with >= 3 variables. 
But we got {len(self.target_vars)} ' - f'variables in {self.model}.') - self.z_var = self.target_var_names[2] - - @property - def F_fz(self): - r"""The function to evaluate :math:`f_y(*\mathrm{vars}, *\mathrm{pars})`.""" - if C.F_fz not in self.analyzed_results: - variables, arguments = utils.get_args(self.model.f_derivatives[self.z_var]) - wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) - f = wrapper(self.model.f_derivatives[self.z_var]) - f = partial(f, **(self.pars_update + self.fixed_vars)) - self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device) - return self.analyzed_results[C.F_fz] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if len(self.target_vars) < 3: + raise errors.AnalyzerError(f'{Num1DAnalyzer.__name__} only supports dynamical system ' + f'with >= 3 variables. But we got {len(self.target_vars)} ' + f'variables in {self.model}.') + self.z_var = self.target_var_names[2] + + @property + def F_fz(self): + r"""The function to evaluate :math:`f_y(*\mathrm{vars}, *\mathrm{pars})`.""" + if C.F_fz not in self.analyzed_results: + variables, arguments = utils.get_args(self.model.f_derivatives[self.z_var]) + wrapper = utils.std_derivative(arguments, self.target_var_names, self.target_par_names) + f = wrapper(self.model.f_derivatives[self.z_var]) + f = partial(f, **(self.pars_update + self.fixed_vars)) + self.analyzed_results[C.F_fz] = jax.jit(f, device=self.jit_device) + return self.analyzed_results[C.F_fz] diff --git a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py index b20cb7c18..f1befd2dd 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py +++ b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py @@ -1,623 +1,623 @@ # -*- coding: utf-8 -*- +from copy import deepcopy from functools import partial import jax import jax.numpy as jnp -from jax import vmap import numpy as np -from copy import deepcopy +from jax 
import vmap -from brainpy import errors import brainpy._src.math as bm +from brainpy import errors from brainpy._src.analysis import stability, plotstyle, utils, constants as C from brainpy._src.analysis.lowdim.lowdim_analyzer import * pyplot = None __all__ = [ - 'Bifurcation1D', - 'Bifurcation2D', - 'FastSlow1D', - 'FastSlow2D', + 'Bifurcation1D', + 'Bifurcation2D', + 'FastSlow1D', + 'FastSlow2D', ] class Bifurcation1D(Num1DAnalyzer): - """Bifurcation analysis of 1D system. - - Using this class, we can make co-dimension1 or co-dimension2 bifurcation analysis. - """ - - def __init__(self, model, target_pars, target_vars, fixed_vars=None, - pars_update=None, resolutions=None, options=None): - super().__init__(model=model, - target_pars=target_pars, - target_vars=target_vars, - fixed_vars=fixed_vars, - pars_update=pars_update, - resolutions=resolutions, - options=options) - - if len(self.target_pars) == 0: - raise ValueError - - @property - def F_vmap_dfxdx(self): - if C.F_vmap_dfxdx not in self.analyzed_results: - f = jax.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device) - self.analyzed_results[C.F_vmap_dfxdx] = f - return self.analyzed_results[C.F_vmap_dfxdx] - - def plot_bifurcation(self, with_plot=True, show=False, with_return=False, - tol_aux=1e-8, loss_screen=None): - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am making bifurcation analysis ...') - - xs = self.resolutions[self.x_var] - vps = jnp.meshgrid(xs, *tuple(self.resolutions[p] for p in self.target_par_names)) - vps = tuple(jnp.moveaxis(bm.as_jax(vp), 0, 1).flatten() for vp in vps) - candidates = vps[0] - pars = vps[1:] - fixed_points, _, pars = self._get_fixed_points(candidates, *pars, - tol_aux=tol_aux, - loss_screen=loss_screen, - num_seg=len(xs)) - dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars)) - pars = tuple(np.asarray(p) for p in pars) - - if with_plot: - if len(self.target_pars) == 1: - container = {c: {'p': [], 
'x': []} for c in stability.get_1d_stability_types()} - - # fixed point - for p, x, dx in zip(pars[0], fixed_points, dfxdx): - fp_type = stability.stability_analysis(dx) - container[fp_type]['p'].append(p) - container[fp_type]['x'].append(x) - - # visualization - pyplot.figure(self.x_var) - for fp_type, points in container.items(): - if len(points['x']): - plot_style = deepcopy(plotstyle.plot_schema[fp_type]) - pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type) - pyplot.xlabel(self.target_par_names[0]) - pyplot.ylabel(self.x_var) - - scale = (self.lim_scale - 1) / 2 - pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) - pyplot.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) - - pyplot.legend() - if show: - pyplot.show() - - elif len(self.target_pars) == 2: - container = {c: {'p0': [], 'p1': [], 'x': []} for c in stability.get_1d_stability_types()} - - # fixed point - for p0, p1, x, dx in zip(pars[0], pars[1], fixed_points, dfxdx): - fp_type = stability.stability_analysis(dx) - container[fp_type]['p0'].append(p0) - container[fp_type]['p1'].append(p1) - container[fp_type]['x'].append(x) + """Bifurcation analysis of 1D system. 
- # visualization - fig = pyplot.figure(self.x_var) - ax = fig.add_subplot(projection='3d') - for fp_type, points in container.items(): - if len(points['x']): - plot_style = deepcopy(plotstyle.plot_schema[fp_type]) - xs = points['p0'] - ys = points['p1'] - zs = points['x'] - plot_style.pop('linestyle') - plot_style['s'] = plot_style.pop('markersize', None) - ax.scatter(xs, ys, zs, **plot_style, label=fp_type) - - ax.set_xlabel(self.target_par_names[0]) - ax.set_ylabel(self.target_par_names[1]) - ax.set_zlabel(self.x_var) - - scale = (self.lim_scale - 1) / 2 - ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) - ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale)) - ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) - - ax.grid(True) - ax.legend() - if show: - pyplot.show() - - else: - raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} ' - f'bifurcation.') - if with_return: - return fixed_points, pars, dfxdx - - -class Bifurcation2D(Num2DAnalyzer): - """Bifurcation analysis of 2D system. - - Using this class, we can make co-dimension1 or co-dimension2 bifurcation analysis. 
- """ - - def __init__(self, model, target_pars, target_vars, fixed_vars=None, - pars_update=None, resolutions=None, options=None): - super().__init__(model=model, - target_pars=target_pars, - target_vars=target_vars, - fixed_vars=fixed_vars, - pars_update=pars_update, - resolutions=resolutions, - options=options) - - if len(self.target_pars) == 0: - raise ValueError - - self._fixed_points = None - - @property - def F_vmap_jacobian(self): - if C.F_vmap_jacobian not in self.analyzed_results: - f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args), - self.F_fy(xy[0], xy[1], *args)]) - f2 = jax.jit(vmap(bm.jacobian(f1)), device=self.jit_device) - self.analyzed_results[C.F_vmap_jacobian] = f2 - return self.analyzed_results[C.F_vmap_jacobian] - - def plot_bifurcation(self, with_plot=True, show=False, with_return=False, - tol_aux=1e-8, tol_unique=1e-2, tol_opt_candidate=None, - num_par_segments=1, num_fp_segment=1, nullcline_aux_filter=1., - select_candidates='aux_rank', num_rank=100): - r"""Make the bifurcation analysis. - - Parameters:: - - with_plot: bool - Whether plot the bifurcation figure. - show: bool - Whether show the figure. - with_return: bool - Whether return the computed bifurcation results. - tol_aux: float - The loss tolerance of auxiliary function :math:`f_{aux}` to confirm the fixed - point. Default is 1e-7. Once :math:`f_{aux}(x_1) < \mathrm{tol\_aux}`, - :math:`x_1` will be a fixed point. - tol_unique: float - The tolerance of distance between candidate fixed points to confirm they are - the same. Default is 1e-2. If :math:`|x_1 - x_2| > \mathrm{tol\_unique}`, - then :math:`x_1` and :math:`x_2` are unique fixed points. Otherwise, - :math:`x_1` and :math:`x_2` will be treated as a same fixed point. - tol_opt_candidate: float, optional - The tolerance of auxiliary function :math:`f_{aux}` to select candidate - initial points for fixed point optimization. - num_par_segments: int, sequence of int - How to segment parameters. 
- num_fp_segment: int - How to segment fixed points. - nullcline_aux_filter: float - The - select_candidates: str - The method to select candidate fixed points. It can be: - - - ``fx-nullcline``: use the points of fx-nullcline. - - ``fy-nullcline``: use the points of fy-nullcline. - - ``nullclines``: use the points in both of fx-nullcline and fy-nullcline. - - ``aux_rank``: use the minimal value of points for the auxiliary function. - num_rank: int - The number of candidates to be used to optimize the fixed points. - rank to use. - - Returns:: - - results : tuple - Return a tuple of analyzed results: - - - fixed points: a 2D matrix with the shape of (num_point, num_var) - - parameters: a 2D matrix with the shape of (num_point, num_par) - - jacobians: a 3D tensors with the shape of (num_point, 2, 2) + Using this class, we can make co-dimension1 or co-dimension2 bifurcation analysis. """ - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am making bifurcation analysis ...') - - if self._can_convert_to_one_eq(): - if self.convert_type() == C.x_by_y: - X = bm.as_jax(self.resolutions[self.y_var]) - else: - X = bm.as_jax(self.resolutions[self.x_var]) - pars = tuple(bm.as_jax(self.resolutions[p]) for p in self.target_par_names) - mesh_values = jnp.meshgrid(*((X,) + pars)) - mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values) - candidates = mesh_values[0] - parameters = mesh_values[1:] - - else: - if select_candidates == 'fx-nullcline': - fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments, - fp_aux_filter=nullcline_aux_filter) - candidates = fx_nullclines[0] - parameters = fx_nullclines[1:] - elif select_candidates == 'fy-nullcline': - fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments, - fp_aux_filter=nullcline_aux_filter) - candidates = fy_nullclines[0] - parameters = fy_nullclines[1:] - elif select_candidates == 'nullclines': - fx_nullclines = 
self._get_fx_nullcline_points(num_segments=num_par_segments, - fp_aux_filter=nullcline_aux_filter) - fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments, - fp_aux_filter=nullcline_aux_filter) - candidates = jnp.vstack([fx_nullclines[0], fy_nullclines[0]]) - parameters = [jnp.concatenate([fx_nullclines[i], fy_nullclines[i]]) - for i in range(1, len(fy_nullclines))] - elif select_candidates == 'aux_rank': - assert nullcline_aux_filter > 0. - candidates, parameters = self._get_fp_candidates_by_aux_rank(num_segments=num_par_segments, - num_rank=num_rank) - else: - raise ValueError - candidates, _, parameters = self._get_fixed_points(candidates, - *parameters, + + def __init__(self, model, target_pars, target_vars, fixed_vars=None, + pars_update=None, resolutions=None, options=None): + super().__init__(model=model, + target_pars=target_pars, + target_vars=target_vars, + fixed_vars=fixed_vars, + pars_update=pars_update, + resolutions=resolutions, + options=options) + + if len(self.target_pars) == 0: + raise ValueError + + @property + def F_vmap_dfxdx(self): + if C.F_vmap_dfxdx not in self.analyzed_results: + f = jax.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device) + self.analyzed_results[C.F_vmap_dfxdx] = f + return self.analyzed_results[C.F_vmap_dfxdx] + + def plot_bifurcation(self, with_plot=True, show=False, with_return=False, + tol_aux=1e-8, loss_screen=None): + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am making bifurcation analysis ...') + + xs = self.resolutions[self.x_var] + vps = jnp.meshgrid(xs, *tuple(self.resolutions[p] for p in self.target_par_names)) + vps = tuple(jnp.moveaxis(bm.as_jax(vp), 0, 1).flatten() for vp in vps) + candidates = vps[0] + pars = vps[1:] + fixed_points, _, pars = self._get_fixed_points(candidates, *pars, tol_aux=tol_aux, - tol_unique=tol_unique, - tol_opt_candidate=tol_opt_candidate, - num_segment=num_fp_segment) - candidates = 
np.asarray(candidates) - parameters = np.stack(tuple(np.asarray(p) for p in parameters)).T - utils.output('I am trying to filter out duplicate fixed points ...') - final_fps = [] - final_pars = [] - for par in np.unique(parameters, axis=0): - ids = np.where(np.all(parameters == par, axis=1))[0] - fps, ids2 = utils.keep_unique(candidates[ids], tolerance=tol_unique) - final_fps.append(fps) - final_pars.append(parameters[ids[ids2]]) - final_fps = np.vstack(final_fps) # with the shape of (num_point, num_var) - final_pars = np.vstack(final_pars) # with the shape of (num_point, num_par) - jacobians = np.asarray(self.F_vmap_jacobian(jnp.asarray(final_fps), *final_pars.T)) - utils.output(f'{C.prefix}Found {len(final_fps)} fixed points.') - - # remember the fixed points for later limit cycle plotting - self._fixed_points = (final_fps, final_pars) - - if with_plot: - # bifurcation analysis of co-dimension 1 - if len(self.target_pars) == 1: - container = {c: {'p': [], self.x_var: [], self.y_var: []} - for c in stability.get_2d_stability_types()} - - # fixed point - for p, xy, J in zip(final_pars, final_fps, jacobians): - fp_type = stability.stability_analysis(J) - container[fp_type]['p'].append(p[0]) - container[fp_type][self.x_var].append(xy[0]) - container[fp_type][self.y_var].append(xy[1]) + loss_screen=loss_screen, + num_seg=len(xs)) + dfxdx = np.asarray(self.F_vmap_dfxdx(jnp.asarray(fixed_points), *pars)) + pars = tuple(np.asarray(p) for p in pars) + + if with_plot: + if len(self.target_pars) == 1: + container = {c: {'p': [], 'x': []} for c in stability.get_1d_stability_types()} + + # fixed point + for p, x, dx in zip(pars[0], fixed_points, dfxdx): + fp_type = stability.stability_analysis(dx) + container[fp_type]['p'].append(p) + container[fp_type]['x'].append(x) + + # visualization + pyplot.figure(self.x_var) + for fp_type, points in container.items(): + if len(points['x']): + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) + pyplot.plot(points['p'], points['x'], 
**plot_style, label=fp_type) + pyplot.xlabel(self.target_par_names[0]) + pyplot.ylabel(self.x_var) + + scale = (self.lim_scale - 1) / 2 + pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) + pyplot.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) + + pyplot.legend() + if show: + pyplot.show() + + elif len(self.target_pars) == 2: + container = {c: {'p0': [], 'p1': [], 'x': []} for c in stability.get_1d_stability_types()} + + # fixed point + for p0, p1, x, dx in zip(pars[0], pars[1], fixed_points, dfxdx): + fp_type = stability.stability_analysis(dx) + container[fp_type]['p0'].append(p0) + container[fp_type]['p1'].append(p1) + container[fp_type]['x'].append(x) + + # visualization + fig = pyplot.figure(self.x_var) + ax = fig.add_subplot(projection='3d') + for fp_type, points in container.items(): + if len(points['x']): + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) + xs = points['p0'] + ys = points['p1'] + zs = points['x'] + plot_style.pop('linestyle') + plot_style['s'] = plot_style.pop('markersize', None) + ax.scatter(xs, ys, zs, **plot_style, label=fp_type) + + ax.set_xlabel(self.target_par_names[0]) + ax.set_ylabel(self.target_par_names[1]) + ax.set_zlabel(self.x_var) + + scale = (self.lim_scale - 1) / 2 + ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) + ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale)) + ax.set_zlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) + + ax.grid(True) + ax.legend() + if show: + pyplot.show() + + else: + raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} ' + f'bifurcation.') + if with_return: + return fixed_points, pars, dfxdx - # visualization - for var in self.target_var_names: - pyplot.figure(var) - for fp_type, points in container.items(): - if len(points['p']): - plot_style = deepcopy(plotstyle.plot_schema[fp_type]) - pyplot.plot(points['p'], points[var], 
**plot_style, label=fp_type) - pyplot.xlabel(self.target_par_names[0]) - pyplot.ylabel(var) - - scale = (self.lim_scale - 1) / 2 - pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) - pyplot.ylim(*utils.rescale(self.target_vars[var], scale=scale)) - - pyplot.legend() - if show: - pyplot.show() - - # bifurcation analysis of co-dimension 2 - elif len(self.target_pars) == 2: - container = {c: {'p0': [], 'p1': [], self.x_var: [], self.y_var: []} - for c in stability.get_2d_stability_types()} - - # fixed point - for p, xy, J in zip(final_pars, final_fps, jacobians): - fp_type = stability.stability_analysis(J) - container[fp_type]['p0'].append(p[0]) - container[fp_type]['p1'].append(p[1]) - container[fp_type][self.x_var].append(xy[0]) - container[fp_type][self.y_var].append(xy[1]) - # visualization - for var in self.target_var_names: - fig = pyplot.figure(var) - ax = fig.add_subplot(projection='3d') - for fp_type, points in container.items(): - if len(points['p0']): - plot_style = deepcopy(plotstyle.plot_schema[fp_type]) - xs = points['p0'] - ys = points['p1'] - zs = points[var] - plot_style.pop('linestyle') - plot_style['s'] = plot_style.pop('markersize', None) - ax.scatter(xs, ys, zs, **plot_style, label=fp_type) - - ax.set_xlabel(self.target_par_names[0]) - ax.set_ylabel(self.target_par_names[1]) - ax.set_zlabel(var) - scale = (self.lim_scale - 1) / 2 - ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) - ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale)) - ax.set_zlim(*utils.rescale(self.target_vars[var], scale=scale)) - ax.grid(True) - ax.legend() - if show: - pyplot.show() - - else: - raise ValueError('Unknown length of parameters.') - - if with_return: - return final_fps, final_pars, jacobians - - def plot_limit_cycle_by_sim( - self, - duration=100, - with_plot: bool = True, - with_return: bool = False, - plot_style: dict = None, - tol: float = 0.001, - show: bool 
= False, - dt: float = None, - offset: float = 1. - ): - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am plotting the limit cycle ...') - if self._fixed_points is None: - utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.') - return - - final_fps, final_pars = self._fixed_points - dt = bm.get_dt() if dt is None else dt - traject_model = utils.TrajectModel( - initial_vars={self.x_var: final_fps[:, 0] + offset, self.y_var: final_fps[:, 1] + offset}, - integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y}, - pars={p: v for p, v in zip(self.target_par_names, final_pars.T)}, - dt=dt - ) - mon_res = traject_model.run(duration=duration) - - # find limit cycles - vs_limit_cycle = tuple({'min': [], 'max': []} for _ in self.target_var_names) - ps_limit_cycle = tuple([] for _ in self.target_par_names) - for i in range(mon_res[self.x_var].shape[1]): - data = mon_res[self.x_var][:, i] - max_index = utils.find_indexes_of_limit_cycle_max(data, tol=tol) - if max_index[0] != -1: - cycle = data[max_index[0]: max_index[1]] - vs_limit_cycle[0]['max'].append(mon_res[self.x_var][max_index[1], i]) - vs_limit_cycle[0]['min'].append(cycle.min()) - cycle = mon_res[self.y_var][max_index[0]: max_index[1], i] - vs_limit_cycle[1]['max'].append(mon_res[self.y_var][max_index[1], i]) - vs_limit_cycle[1]['min'].append(cycle.min()) - for j in range(len(self.target_par_names)): - ps_limit_cycle[j].append(final_pars[i, j]) - vs_limit_cycle = tuple({k: np.asarray(v) for k, v in lm.items()} for lm in vs_limit_cycle) - ps_limit_cycle = tuple(np.array(p) for p in ps_limit_cycle) - - # visualization - if with_plot: - if plot_style is None: plot_style = dict() - fmt = plot_style.pop('fmt', '*') - - if len(self.target_par_names) == 2: - if len(ps_limit_cycle[0]): - for i, var in enumerate(self.target_var_names): - pyplot.figure(var) - pyplot.plot(ps_limit_cycle[0], - ps_limit_cycle[1], - vs_limit_cycle[i]['max'], - 
**plot_style, - label='limit cycle (max)') - pyplot.plot(ps_limit_cycle[0], - ps_limit_cycle[1], - vs_limit_cycle[i]['min'], - **plot_style, - label='limit cycle (min)') - pyplot.legend() +class Bifurcation2D(Num2DAnalyzer): + """Bifurcation analysis of 2D system. - elif len(self.target_par_names) == 1: - if len(ps_limit_cycle[0]): - for i, var in enumerate(self.target_var_names): - pyplot.figure(var) - pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt, - **plot_style, label='limit cycle (max)') - pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt, - **plot_style, label='limit cycle (min)') - pyplot.legend() + Using this class, we can make co-dimension1 or co-dimension2 bifurcation analysis. + """ - else: - raise errors.AnalyzerError + def __init__(self, model, target_pars, target_vars, fixed_vars=None, + pars_update=None, resolutions=None, options=None): + super().__init__(model=model, + target_pars=target_pars, + target_vars=target_vars, + fixed_vars=fixed_vars, + pars_update=pars_update, + resolutions=resolutions, + options=options) + + if len(self.target_pars) == 0: + raise ValueError + + self._fixed_points = None + + @property + def F_vmap_jacobian(self): + if C.F_vmap_jacobian not in self.analyzed_results: + f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args), + self.F_fy(xy[0], xy[1], *args)]) + f2 = jax.jit(vmap(bm.jacobian(f1)), device=self.jit_device) + self.analyzed_results[C.F_vmap_jacobian] = f2 + return self.analyzed_results[C.F_vmap_jacobian] + + def plot_bifurcation(self, with_plot=True, show=False, with_return=False, + tol_aux=1e-8, tol_unique=1e-2, tol_opt_candidate=None, + num_par_segments=1, num_fp_segment=1, nullcline_aux_filter=1., + select_candidates='aux_rank', num_rank=100): + r"""Make the bifurcation analysis. + + Parameters:: + + with_plot: bool + Whether plot the bifurcation figure. + show: bool + Whether show the figure. + with_return: bool + Whether return the computed bifurcation results. 
+ tol_aux: float + The loss tolerance of auxiliary function :math:`f_{aux}` to confirm the fixed + point. Default is 1e-7. Once :math:`f_{aux}(x_1) < \mathrm{tol\_aux}`, + :math:`x_1` will be a fixed point. + tol_unique: float + The tolerance of distance between candidate fixed points to confirm they are + the same. Default is 1e-2. If :math:`|x_1 - x_2| > \mathrm{tol\_unique}`, + then :math:`x_1` and :math:`x_2` are unique fixed points. Otherwise, + :math:`x_1` and :math:`x_2` will be treated as a same fixed point. + tol_opt_candidate: float, optional + The tolerance of auxiliary function :math:`f_{aux}` to select candidate + initial points for fixed point optimization. + num_par_segments: int, sequence of int + How to segment parameters. + num_fp_segment: int + How to segment fixed points. + nullcline_aux_filter: float + The + select_candidates: str + The method to select candidate fixed points. It can be: + + - ``fx-nullcline``: use the points of fx-nullcline. + - ``fy-nullcline``: use the points of fy-nullcline. + - ``nullclines``: use the points in both of fx-nullcline and fy-nullcline. + - ``aux_rank``: use the minimal value of points for the auxiliary function. + num_rank: int + The number of candidates to be used to optimize the fixed points. + rank to use. 
+ + Returns:: + + results : tuple + Return a tuple of analyzed results: + + - fixed points: a 2D matrix with the shape of (num_point, num_var) + - parameters: a 2D matrix with the shape of (num_point, num_par) + - jacobians: a 3D tensors with the shape of (num_point, 2, 2) + """ + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am making bifurcation analysis ...') + + if self._can_convert_to_one_eq(): + if self.convert_type() == C.x_by_y: + X = bm.as_jax(self.resolutions[self.y_var]) + else: + X = bm.as_jax(self.resolutions[self.x_var]) + pars = tuple(bm.as_jax(self.resolutions[p]) for p in self.target_par_names) + mesh_values = jnp.meshgrid(*((X,) + pars)) + mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values) + candidates = mesh_values[0] + parameters = mesh_values[1:] - if show: - pyplot.show() + else: + if select_candidates == 'fx-nullcline': + fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments, + fp_aux_filter=nullcline_aux_filter) + candidates = fx_nullclines[0] + parameters = fx_nullclines[1:] + elif select_candidates == 'fy-nullcline': + fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments, + fp_aux_filter=nullcline_aux_filter) + candidates = fy_nullclines[0] + parameters = fy_nullclines[1:] + elif select_candidates == 'nullclines': + fx_nullclines = self._get_fx_nullcline_points(num_segments=num_par_segments, + fp_aux_filter=nullcline_aux_filter) + fy_nullclines = self._get_fy_nullcline_points(num_segments=num_par_segments, + fp_aux_filter=nullcline_aux_filter) + candidates = jnp.vstack([fx_nullclines[0], fy_nullclines[0]]) + parameters = [jnp.concatenate([fx_nullclines[i], fy_nullclines[i]]) + for i in range(1, len(fy_nullclines))] + elif select_candidates == 'aux_rank': + assert nullcline_aux_filter > 0. 
+ candidates, parameters = self._get_fp_candidates_by_aux_rank(num_segments=num_par_segments, + num_rank=num_rank) + else: + raise ValueError + candidates, _, parameters = self._get_fixed_points(candidates, + *parameters, + tol_aux=tol_aux, + tol_unique=tol_unique, + tol_opt_candidate=tol_opt_candidate, + num_segment=num_fp_segment) + candidates = np.asarray(candidates) + parameters = np.stack(tuple(np.asarray(p) for p in parameters)).T + utils.output('I am trying to filter out duplicate fixed points ...') + final_fps = [] + final_pars = [] + for par in np.unique(parameters, axis=0): + ids = np.where(np.all(parameters == par, axis=1))[0] + fps, ids2 = utils.keep_unique(candidates[ids], tolerance=tol_unique) + final_fps.append(fps) + final_pars.append(parameters[ids[ids2]]) + final_fps = np.vstack(final_fps) # with the shape of (num_point, num_var) + final_pars = np.vstack(final_pars) # with the shape of (num_point, num_par) + jacobians = np.asarray(self.F_vmap_jacobian(jnp.asarray(final_fps), *final_pars.T)) + utils.output(f'{C.prefix}Found {len(final_fps)} fixed points.') + + # remember the fixed points for later limit cycle plotting + self._fixed_points = (final_fps, final_pars) + + if with_plot: + # bifurcation analysis of co-dimension 1 + if len(self.target_pars) == 1: + container = {c: {'p': [], self.x_var: [], self.y_var: []} + for c in stability.get_2d_stability_types()} + + # fixed point + for p, xy, J in zip(final_pars, final_fps, jacobians): + fp_type = stability.stability_analysis(J) + container[fp_type]['p'].append(p[0]) + container[fp_type][self.x_var].append(xy[0]) + container[fp_type][self.y_var].append(xy[1]) + + # visualization + for var in self.target_var_names: + pyplot.figure(var) + for fp_type, points in container.items(): + if len(points['p']): + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) + pyplot.plot(points['p'], points[var], **plot_style, label=fp_type) + pyplot.xlabel(self.target_par_names[0]) + pyplot.ylabel(var) + + scale = 
(self.lim_scale - 1) / 2 + pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) + pyplot.ylim(*utils.rescale(self.target_vars[var], scale=scale)) + + pyplot.legend() + if show: + pyplot.show() + + # bifurcation analysis of co-dimension 2 + elif len(self.target_pars) == 2: + container = {c: {'p0': [], 'p1': [], self.x_var: [], self.y_var: []} + for c in stability.get_2d_stability_types()} + + # fixed point + for p, xy, J in zip(final_pars, final_fps, jacobians): + fp_type = stability.stability_analysis(J) + container[fp_type]['p0'].append(p[0]) + container[fp_type]['p1'].append(p[1]) + container[fp_type][self.x_var].append(xy[0]) + container[fp_type][self.y_var].append(xy[1]) + + # visualization + for var in self.target_var_names: + fig = pyplot.figure(var) + ax = fig.add_subplot(projection='3d') + for fp_type, points in container.items(): + if len(points['p0']): + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) + xs = points['p0'] + ys = points['p1'] + zs = points[var] + plot_style.pop('linestyle') + plot_style['s'] = plot_style.pop('markersize', None) + ax.scatter(xs, ys, zs, **plot_style, label=fp_type) + + ax.set_xlabel(self.target_par_names[0]) + ax.set_ylabel(self.target_par_names[1]) + ax.set_zlabel(var) + scale = (self.lim_scale - 1) / 2 + ax.set_xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale)) + ax.set_ylim(*utils.rescale(self.target_pars[self.target_par_names[1]], scale=scale)) + ax.set_zlim(*utils.rescale(self.target_vars[var], scale=scale)) + ax.grid(True) + ax.legend() + if show: + pyplot.show() + + else: + raise ValueError('Unknown length of parameters.') + + if with_return: + return final_fps, final_pars, jacobians + + def plot_limit_cycle_by_sim( + self, + duration=100, + with_plot: bool = True, + with_return: bool = False, + plot_style: dict = None, + tol: float = 0.001, + show: bool = False, + dt: float = None, + offset: float = 1. 
+ ): + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am plotting the limit cycle ...') + if self._fixed_points is None: + utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.') + return + + final_fps, final_pars = self._fixed_points + dt = bm.get_dt() if dt is None else dt + traject_model = utils.TrajectModel( + initial_vars={self.x_var: final_fps[:, 0] + offset, self.y_var: final_fps[:, 1] + offset}, + integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y}, + pars={p: v for p, v in zip(self.target_par_names, final_pars.T)}, + dt=dt + ) + mon_res = traject_model.run(duration=duration) + + # find limit cycles + vs_limit_cycle = tuple({'min': [], 'max': []} for _ in self.target_var_names) + ps_limit_cycle = tuple([] for _ in self.target_par_names) + for i in range(mon_res[self.x_var].shape[1]): + data = mon_res[self.x_var][:, i] + max_index = utils.find_indexes_of_limit_cycle_max(data, tol=tol) + if max_index[0] != -1: + cycle = data[max_index[0]: max_index[1]] + vs_limit_cycle[0]['max'].append(mon_res[self.x_var][max_index[1], i]) + vs_limit_cycle[0]['min'].append(cycle.min()) + cycle = mon_res[self.y_var][max_index[0]: max_index[1], i] + vs_limit_cycle[1]['max'].append(mon_res[self.y_var][max_index[1], i]) + vs_limit_cycle[1]['min'].append(cycle.min()) + for j in range(len(self.target_par_names)): + ps_limit_cycle[j].append(final_pars[i, j]) + vs_limit_cycle = tuple({k: np.asarray(v) for k, v in lm.items()} for lm in vs_limit_cycle) + ps_limit_cycle = tuple(np.array(p) for p in ps_limit_cycle) - if with_return: - return vs_limit_cycle, ps_limit_cycle + # visualization + if with_plot: + if plot_style is None: plot_style = dict() + fmt = plot_style.pop('fmt', '*') + + if len(self.target_par_names) == 2: + if len(ps_limit_cycle[0]): + for i, var in enumerate(self.target_var_names): + pyplot.figure(var) + pyplot.plot(ps_limit_cycle[0], + ps_limit_cycle[1], + vs_limit_cycle[i]['max'], + 
**plot_style, + label='limit cycle (max)') + pyplot.plot(ps_limit_cycle[0], + ps_limit_cycle[1], + vs_limit_cycle[i]['min'], + **plot_style, + label='limit cycle (min)') + pyplot.legend() + + elif len(self.target_par_names) == 1: + if len(ps_limit_cycle[0]): + for i, var in enumerate(self.target_var_names): + pyplot.figure(var) + pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt, + **plot_style, label='limit cycle (max)') + pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt, + **plot_style, label='limit cycle (min)') + pyplot.legend() + + else: + raise errors.AnalyzerError + + if show: + pyplot.show() + + if with_return: + return vs_limit_cycle, ps_limit_cycle class FastSlow1D(Bifurcation1D): - def __init__( - self, - model, - fast_vars: dict, - slow_vars: dict, - fixed_vars: dict = None, - pars_update: dict = None, - resolutions=None, - options: dict = None - ): - super().__init__(model=model, - target_pars=slow_vars, - target_vars=fast_vars, - fixed_vars=fixed_vars, - pars_update=pars_update, - resolutions=resolutions, - options=options) - - # standard integrators - self._std_integrators = dict() - for key, intg in self.model.name2integral.items(): - wrap_x = utils.std_derivative(utils.get_args(self.model.name2derivative[key])[1], - self.target_var_names + self.target_par_names, []) - self._std_integrators[key] = partial(wrap_x(self.model.name2integral[key]), - **(self.pars_update + self.fixed_vars)) - - def plot_trajectory(self, initials, duration, plot_durations=None, - dt=None, show=False, with_plot=True, with_return=False): - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am plotting the trajectory ...') - - # check the initial values - initials = utils.check_initials(initials, self.target_var_names + self.target_par_names) - - # 2. format the running duration - assert isinstance(duration, (int, float)) - - # 3. 
format the plot duration - plot_durations = utils.check_plot_durations(plot_durations, duration, initials) - - # 5. run the network - dt = bm.get_dt() if dt is None else dt - - traject_model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt) - mon_res = traject_model.run(duration=duration) - - if with_plot: - assert len(self.target_par_names) <= 2 - - # plots - for i, initial in enumerate(zip(*list(initials.values()))): - # legend - legend = f'$traj_{i}$: ' - for j, key in enumerate(self.target_var_names): - legend += f'{key}={initial[j]}, ' - legend = legend[:-2] - - # visualization - start = int(plot_durations[i][0] / dt) - end = int(plot_durations[i][1] / dt) - p1_var = self.target_par_names[0] - if len(self.target_par_names) == 1: - lines = pyplot.plot(mon_res[self.x_var][start: end, i], - mon_res[p1_var][start: end, i], label=legend) - elif len(self.target_par_names) == 2: - p2_var = self.target_par_names[1] - lines = pyplot.plot(mon_res[self.x_var][start: end, i], - mon_res[p1_var][start: end, i], - mon_res[p2_var][start: end, i], - label=legend) - else: - raise ValueError - utils.add_arrow(lines[0]) - - # # visualization of others - # plt.xlabel(self.x_var) - # plt.ylabel(self.target_par_names[0]) - # scale = (self.lim_scale - 1.) 
/ 2 - # plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) - # plt.ylim(*utils.rescale(self.target_vars[self.target_par_names[0]], scale=scale)) - pyplot.legend() + def __init__( + self, + model, + fast_vars: dict, + slow_vars: dict, + fixed_vars: dict = None, + pars_update: dict = None, + resolutions=None, + options: dict = None + ): + super().__init__(model=model, + target_pars=slow_vars, + target_vars=fast_vars, + fixed_vars=fixed_vars, + pars_update=pars_update, + resolutions=resolutions, + options=options) + + # standard integrators + self._std_integrators = dict() + for key, intg in self.model.name2integral.items(): + wrap_x = utils.std_derivative(utils.get_args(self.model.name2derivative[key])[1], + self.target_var_names + self.target_par_names, []) + self._std_integrators[key] = partial(wrap_x(self.model.name2integral[key]), + **(self.pars_update + self.fixed_vars)) + + def plot_trajectory(self, initials, duration, plot_durations=None, + dt=None, show=False, with_plot=True, with_return=False): + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am plotting the trajectory ...') + + # check the initial values + initials = utils.check_initials(initials, self.target_var_names + self.target_par_names) + + # 2. format the running duration + assert isinstance(duration, (int, float)) + + # 3. format the plot duration + plot_durations = utils.check_plot_durations(plot_durations, duration, initials) + + # 5. 
run the network + dt = bm.get_dt() if dt is None else dt + + traject_model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt) + mon_res = traject_model.run(duration=duration) + + if with_plot: + assert len(self.target_par_names) <= 2 + + # plots + for i, initial in enumerate(zip(*list(initials.values()))): + # legend + legend = f'$traj_{i}$: ' + for j, key in enumerate(self.target_var_names): + legend += f'{key}={initial[j]}, ' + legend = legend[:-2] + + # visualization + start = int(plot_durations[i][0] / dt) + end = int(plot_durations[i][1] / dt) + p1_var = self.target_par_names[0] + if len(self.target_par_names) == 1: + lines = pyplot.plot(mon_res[self.x_var][start: end, i], + mon_res[p1_var][start: end, i], label=legend) + elif len(self.target_par_names) == 2: + p2_var = self.target_par_names[1] + lines = pyplot.plot(mon_res[self.x_var][start: end, i], + mon_res[p1_var][start: end, i], + mon_res[p2_var][start: end, i], + label=legend) + else: + raise ValueError + utils.add_arrow(lines[0]) + + # # visualization of others + # plt.xlabel(self.x_var) + # plt.ylabel(self.target_par_names[0]) + # scale = (self.lim_scale - 1.) 
/ 2 + # plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) + # plt.ylim(*utils.rescale(self.target_vars[self.target_par_names[0]], scale=scale)) + pyplot.legend() - if show: - pyplot.show() + if show: + pyplot.show() - if with_return: - return mon_res + if with_return: + return mon_res class FastSlow2D(Bifurcation2D): - def __init__( - self, - model, - fast_vars: dict, - slow_vars: dict, - fixed_vars: dict = None, - pars_update: dict = None, - resolutions=0.1, - options: dict = None - ): - super().__init__(model=model, - target_pars=slow_vars, - target_vars=fast_vars, - fixed_vars=fixed_vars, - pars_update=pars_update, - resolutions=resolutions, - options=options) - # standard integrators - self._std_integrators = dict() - for key, intg in self.model.name2integral.items(): - wrap_x = utils.std_derivative(utils.get_args(self.model.name2derivative[key])[1], - self.target_var_names + self.target_par_names, []) - self._std_integrators[key] = partial(wrap_x(self.model.name2integral[key]), - **(self.pars_update + self.fixed_vars)) - - def plot_trajectory(self, initials, duration, plot_durations=None, - dt=None, show=False, with_plot=True, with_return=False): - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am plotting the trajectory ...') - - # check the initial values - initials = utils.check_initials(initials, self.target_var_names + self.target_par_names) - - # 2. format the running duration - assert isinstance(duration, (int, float)) - - # 3. format the plot duration - plot_durations = utils.check_plot_durations(plot_durations, duration, initials) - - # 5. 
run the network - dt = bm.get_dt() if dt is None else dt - - traject_model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt) - mon_res = traject_model.run(duration=duration) - - if with_plot: - assert len(self.target_par_names) <= 1 - # plots - for i, initial in enumerate(zip(*list(initials.values()))): - # legend - legend = f'$traj_{i}$: ' - for j, key in enumerate(self.target_var_names): - legend += f'{key}={initial[j]}, ' - legend = legend[:-2] - - start = int(plot_durations[i][0] / dt) - end = int(plot_durations[i][1] / dt) + def __init__( + self, + model, + fast_vars: dict, + slow_vars: dict, + fixed_vars: dict = None, + pars_update: dict = None, + resolutions=0.1, + options: dict = None + ): + super().__init__(model=model, + target_pars=slow_vars, + target_vars=fast_vars, + fixed_vars=fixed_vars, + pars_update=pars_update, + resolutions=resolutions, + options=options) + # standard integrators + self._std_integrators = dict() + for key, intg in self.model.name2integral.items(): + wrap_x = utils.std_derivative(utils.get_args(self.model.name2derivative[key])[1], + self.target_var_names + self.target_par_names, []) + self._std_integrators[key] = partial(wrap_x(self.model.name2integral[key]), + **(self.pars_update + self.fixed_vars)) + + def plot_trajectory(self, initials, duration, plot_durations=None, + dt=None, show=False, with_plot=True, with_return=False): + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am plotting the trajectory ...') + + # check the initial values + initials = utils.check_initials(initials, self.target_var_names + self.target_par_names) + + # 2. format the running duration + assert isinstance(duration, (int, float)) + + # 3. format the plot duration + plot_durations = utils.check_plot_durations(plot_durations, duration, initials) + + # 5. 
run the network + dt = bm.get_dt() if dt is None else dt + + traject_model = utils.TrajectModel(initial_vars=initials, integrals=self._std_integrators, dt=dt) + mon_res = traject_model.run(duration=duration) + + if with_plot: + assert len(self.target_par_names) <= 1 + # plots + for i, initial in enumerate(zip(*list(initials.values()))): + # legend + legend = f'$traj_{i}$: ' + for j, key in enumerate(self.target_var_names): + legend += f'{key}={initial[j]}, ' + legend = legend[:-2] + + start = int(plot_durations[i][0] / dt) + end = int(plot_durations[i][1] / dt) + + # visualization + pyplot.figure(self.x_var) + lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i], + mon_res[self.x_var][start: end, i], + label=legend) + utils.add_arrow(lines[0]) + + pyplot.figure(self.y_var) + lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i], + mon_res[self.y_var][start: end, i], + label=legend) + utils.add_arrow(lines[0]) + + pyplot.figure(self.x_var) + pyplot.legend() + pyplot.figure(self.y_var) + pyplot.legend() - # visualization - pyplot.figure(self.x_var) - lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i], - mon_res[self.x_var][start: end, i], - label=legend) - utils.add_arrow(lines[0]) - - pyplot.figure(self.y_var) - lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i], - mon_res[self.y_var][start: end, i], - label=legend) - utils.add_arrow(lines[0]) - - pyplot.figure(self.x_var) - pyplot.legend() - pyplot.figure(self.y_var) - pyplot.legend() - - if show: - pyplot.show() - - if with_return: - return mon_res + if show: + pyplot.show() + + if with_return: + return mon_res diff --git a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py index ac4506ad5..7ccd263db 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py +++ b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- +from copy import deepcopy + import 
jax import jax.numpy as jnp import numpy as np from jax import vmap -from copy import deepcopy import brainpy.math as bm from brainpy import errors, math from brainpy._src.analysis import stability, plotstyle, constants as C, utils @@ -14,511 +15,511 @@ pyplot = None __all__ = [ - 'PhasePlane1D', - 'PhasePlane2D', + 'PhasePlane1D', + 'PhasePlane2D', ] class PhasePlane1D(Num1DAnalyzer): - """Phase plane analyzer for 1D dynamical system. + """Phase plane analyzer for 1D dynamical system. - This class can help users fast check: + This class can help users fast check: - - Vector fields - - Fixed points + - Vector fields + - Fixed points - Parameters:: + Parameters:: - model : Any - A model of the population, the integrator function, - or a list/tuple of integrator functions. - target_vars : dict - The target/dynamical variables. - fixed_vars : dict - The fixed variables. - target_pars : dict, optional - The parameters which can be dynamical varied. - pars_update : dict, optional - The parameters to update. - resolutions : float, dict - - """ - - def __init__(self, - model, - target_vars, - fixed_vars=None, - target_pars=None, - pars_update=None, - resolutions=None, - **kwargs): - if (target_pars is not None) and len(target_pars) > 0: - raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". 
' - f'While we detect "target_pars={target_pars}".') - super().__init__(model=model, - target_vars=target_vars, - fixed_vars=fixed_vars, - target_pars=target_pars, - pars_update=pars_update, - resolutions=resolutions, - **kwargs) - # utils.output(f'I am {PhasePlane1D.__name__}.') - - def plot_vector_field(self, show=False, with_plot=True, with_return=False): - """Plot the vector filed.""" - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am creating the vector field ...') - - # Nullcline of the x variable - y_val = self.F_fx(self.resolutions[self.x_var]) - y_val = np.asarray(y_val) - - # visualization - if with_plot: - label = f"d{self.x_var}dt" - x_style = dict(color='lightcoral', alpha=.7, linewidth=4) - pyplot.plot(np.asarray(self.resolutions[self.x_var]), y_val, **x_style, label=label) - pyplot.axhline(0) - pyplot.xlabel(self.x_var) - pyplot.ylabel(label) - pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=(self.lim_scale - 1.) / 2)) - pyplot.legend() - if show: pyplot.show() - # return - if with_return: - return y_val - - def plot_fixed_point(self, show=False, with_plot=True, with_return=False): - """Plot the fixed point.""" - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am searching fixed points ...') - - # fixed points and stability analysis - fps, _, pars = self._get_fixed_points(self.resolutions[self.x_var]) - container = {a: [] for a in stability.get_1d_stability_types()} - for i in range(len(fps)): - x = fps[i] - dfdx = self.F_dfxdx(x) - fp_type = stability.stability_analysis(dfdx) - utils.output(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.") - container[fp_type].append(x) - - # visualization - if with_plot: - for fp_type, points in container.items(): - if len(points): - plot_style = deepcopy(plotstyle.plot_schema[fp_type]) - pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type) - pyplot.legend() - if show: - pyplot.show() - - # return - if 
with_return: - return fps + model : Any + A model of the population, the integrator function, + or a list/tuple of integrator functions. + target_vars : dict + The target/dynamical variables. + fixed_vars : dict + The fixed variables. + target_pars : dict, optional + The parameters which can be dynamical varied. + pars_update : dict, optional + The parameters to update. + resolutions : float, dict + """ -class PhasePlane2D(Num2DAnalyzer): - """Phase plane analyzer for 2D dynamical system. - - Parameters:: - - model : Any - A model of the population, the integrator function, - or a list/tuple of integrator functions. - target_vars : dict - The target/dynamical variables. - fixed_vars : dict - The fixed variables. - target_pars : dict, optional - The parameters which can be dynamical varied. - pars_update : dict, optional - The parameters to update. - resolutions : float, dict - """ - - def __init__(self, - model, - target_vars, - fixed_vars=None, - target_pars=None, - pars_update=None, - resolutions=None, - **kwargs): - if (target_pars is not None) and len(target_pars) > 0: - raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". ' - f'While we detect "target_pars={target_pars}".') - super().__init__(model=model, - target_vars=target_vars, - fixed_vars=fixed_vars, - target_pars=target_pars, - pars_update=pars_update, - resolutions=resolutions, - **kwargs) - - @property - def F_vmap_brentq_fy(self): - if C.F_vmap_brentq_fy not in self.analyzed_results: - f_opt = jax.jit(vmap(utils.jax_brentq(self.F_fy))) - self.analyzed_results[C.F_vmap_brentq_fy] = f_opt - return self.analyzed_results[C.F_vmap_brentq_fy] - - def plot_vector_field(self, with_plot=True, with_return=False, - plot_method='streamplot', plot_style=None, show=False): - """Plot the vector field. 
+ def __init__(self, + model, + target_vars, + fixed_vars=None, + target_pars=None, + pars_update=None, + resolutions=None, + **kwargs): + if (target_pars is not None) and len(target_pars) > 0: + raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". ' + f'While we detect "target_pars={target_pars}".') + super().__init__(model=model, + target_vars=target_vars, + fixed_vars=fixed_vars, + target_pars=target_pars, + pars_update=pars_update, + resolutions=resolutions, + **kwargs) + # utils.output(f'I am {PhasePlane1D.__name__}.') + + def plot_vector_field(self, show=False, with_plot=True, with_return=False): + """Plot the vector filed.""" + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am creating the vector field ...') + + # Nullcline of the x variable + y_val = self.F_fx(self.resolutions[self.x_var]) + y_val = np.asarray(y_val) - Parameters:: + # visualization + if with_plot: + label = f"d{self.x_var}dt" + x_style = dict(color='lightcoral', alpha=.7, linewidth=4) + pyplot.plot(np.asarray(self.resolutions[self.x_var]), y_val, **x_style, label=label) + pyplot.axhline(0) + pyplot.xlabel(self.x_var) + pyplot.ylabel(label) + pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=(self.lim_scale - 1.) 
/ 2)) + pyplot.legend() + if show: pyplot.show() + # return + if with_return: + return y_val + + def plot_fixed_point(self, show=False, with_plot=True, with_return=False): + """Plot the fixed point.""" + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am searching fixed points ...') + + # fixed points and stability analysis + fps, _, pars = self._get_fixed_points(self.resolutions[self.x_var]) + container = {a: [] for a in stability.get_1d_stability_types()} + for i in range(len(fps)): + x = fps[i] + dfdx = self.F_dfxdx(x) + fp_type = stability.stability_analysis(dfdx) + utils.output(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.") + container[fp_type].append(x) + + # visualization + if with_plot: + for fp_type, points in container.items(): + if len(points): + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) + pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type) + pyplot.legend() + if show: + pyplot.show() + + # return + if with_return: + return fps - with_plot: bool - with_return : bool - show : bool - plot_method : str - The method to plot the vector filed. It can be "streamplot" or "quiver". - plot_style : dict, optional - The style for vector filed plotting. - - - For ``plot_method="streamplot"``, it can set the keywords like "density", - "linewidth", "color", "arrowsize". More settings please check - https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html. - - For ``plot_method="quiver"``, it can set the keywords like "color", - "units", "angles", "scale". More settings please check - https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html. 
- """ - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am creating the vector field ...') - - # get vector fields - xs = self.resolutions[self.x_var] - ys = self.resolutions[self.y_var] - X, Y = jnp.meshgrid(xs, ys) - dx = self.F_fx(X, Y) - dy = self.F_fy(X, Y) - X, Y = np.asarray(X), np.asarray(Y) - dx, dy = np.asarray(dx), np.asarray(dy) - - if with_plot: # plot vector fields - if plot_method == 'quiver': - if plot_style is None: - plot_style = dict(units='xy') - if (not np.isnan(dx).any()) and (not np.isnan(dy).any()): - speed = np.sqrt(dx ** 2 + dy ** 2) - dx = dx / speed - dy = dy / speed - pyplot.quiver(X, Y, dx, dy, **plot_style) - elif plot_method == 'streamplot': - if plot_style is None: - plot_style = dict(arrowsize=1.2, density=1, color='thistle') - linewidth = plot_style.get('linewidth', None) - if linewidth is None: - if (not np.isnan(dx).any()) and (not np.isnan(dy).any()): - min_width, max_width = 0.5, 5.5 - speed = np.nan_to_num(np.sqrt(dx ** 2 + dy ** 2)) - linewidth = min_width + max_width * (speed / speed.max()) - pyplot.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style) - else: - raise errors.AnalyzerError(f'Unknown plot_method "{plot_method}", ' - f'only supports "quiver" and "streamplot".') - - pyplot.xlabel(self.x_var) - pyplot.ylabel(self.y_var) - if show: - pyplot.show() - - if with_return: # return vector fields - return dx, dy - - def plot_nullcline(self, with_plot=True, with_return=False, - y_style=None, x_style=None, show=False, - coords=None, tol_nullcline=1e-7): - """Plot the nullcline.""" - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am computing fx-nullcline ...') - - if coords is None: - coords = dict() - x_coord = coords.get(self.x_var, None) - y_coord = coords.get(self.y_var, None) - - # Nullcline of the x variable - xy_values_in_fx, = self._get_fx_nullcline_points(coords=x_coord, tol=tol_nullcline) - x_values_in_fx = np.asarray(xy_values_in_fx[:, 
0]) - y_values_in_fx = np.asarray(xy_values_in_fx[:, 1]) - - if with_plot: - if x_style is None: - x_style = dict(color='cornflowerblue', alpha=.7, fmt='.') - line_args = (x_style.pop('fmt'),) if 'fmt' in x_style else tuple() - pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline") - - # Nullcline of the y variable - utils.output('I am computing fy-nullcline ...') - xy_values_in_fy, = self._get_fy_nullcline_points(coords=y_coord, tol=tol_nullcline) - x_values_in_fy = np.asarray(xy_values_in_fy[:, 0]) - y_values_in_fy = np.asarray(xy_values_in_fy[:, 1]) - - if with_plot: - if y_style is None: - y_style = dict(color='lightcoral', alpha=.7, fmt='.') - line_args = (y_style.pop('fmt'),) if 'fmt' in y_style else tuple() - pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline") - - if with_plot: - pyplot.xlabel(self.x_var) - pyplot.ylabel(self.y_var) - scale = (self.lim_scale - 1.) / 2 - pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) - pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale)) - pyplot.legend() - if show: - pyplot.show() - - if with_return: - return {self.x_var: (x_values_in_fx, y_values_in_fx), - self.y_var: (x_values_in_fy, y_values_in_fy)} - - def plot_fixed_point(self, with_plot=True, with_return=False, show=False, - tol_unique=1e-2, tol_aux=1e-8, tol_opt_screen=None, - select_candidates='fx-nullcline', num_rank=100, ): - """Plot the fixed point and analyze its stability. 
- """ - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am searching fixed points ...') - - if self._can_convert_to_one_eq(): - if self.convert_type() == C.x_by_y: - candidates = bm.as_jax(self.resolutions[self.y_var]) - else: - candidates = bm.as_jax(self.resolutions[self.x_var]) - else: - if select_candidates == 'fx-nullcline': - candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys() - if key.startswith(C.fx_nullcline_points)] - if len(candidates) == 0: - raise errors.AnalyzerError(f'No nullcline points are found, please call ' - f'".{self.plot_nullcline.__name__}()" first.') - candidates = jnp.vstack(candidates) - elif select_candidates == 'fy-nullcline': - candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys() - if key.startswith(C.fy_nullcline_points)] - if len(candidates) == 0: - raise errors.AnalyzerError(f'No nullcline points are found, please call ' - f'".{self.plot_nullcline.__name__}()" first.') - candidates = jnp.vstack(candidates) - elif select_candidates == 'nullclines': - candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys() - if key.startswith(C.fy_nullcline_points) or key.startswith(C.fy_nullcline_points)] - if len(candidates) == 0: - raise errors.AnalyzerError(f'No nullcline points are found, please call ' - f'".{self.plot_nullcline.__name__}()" first.') - candidates = jnp.vstack(candidates) - elif select_candidates == 'aux_rank': - candidates, _ = self._get_fp_candidates_by_aux_rank(num_rank=num_rank) - else: - raise ValueError - - # get fixed points - if len(candidates): - fixed_points, _, _ = self._get_fixed_points(jnp.asarray(candidates), - tol_aux=tol_aux, - tol_unique=tol_unique, - tol_opt_candidate=tol_opt_screen) - utils.output('I am trying to filter out duplicate fixed points ...') - fixed_points = np.asarray(fixed_points) - fixed_points, _ = utils.keep_unique(fixed_points, tolerance=tol_unique) - 
utils.output(f'{C.prefix}Found {len(fixed_points)} fixed points.') - else: - utils.output(f'{C.prefix}Found no fixed points.') - return - - # stability analysis - # ------------------ - container = {a: {'x': [], 'y': []} for a in stability.get_2d_stability_types()} - for i in range(len(fixed_points)): - x = fixed_points[i, 0] - y = fixed_points[i, 1] - fp_type = stability.stability_analysis(self.F_jacobian(x, y)) - utils.output(f"{C.prefix}#{i + 1} {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.") - container[fp_type]['x'].append(x) - container[fp_type]['y'].append(y) - - # visualization - # ------------- - if with_plot: - for fp_type, points in container.items(): - if len(points['x']): - plot_style = deepcopy(plotstyle.plot_schema[fp_type]) - pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type) - pyplot.legend() - if show: - pyplot.show() - - if with_return: - return fixed_points - - def plot_trajectory(self, initials, duration, plot_durations=None, axes='v-v', - dt=None, show=False, with_plot=True, with_return=False, **kwargs): - """Plot trajectories according to the settings. + +class PhasePlane2D(Num2DAnalyzer): + """Phase plane analyzer for 2D dynamical system. Parameters:: - initials : list, tuple, dict - The initial value setting of the targets. It can be a tuple/list of floats to specify - each value of dynamical variables (for example, ``(a, b)``). It can also be a - tuple/list of tuple to specify multiple initial values (for example, - ``[(a1, b1), (a2, b2)]``). - duration : int, float, tuple, list - The running duration. Same with the ``duration`` in ``NeuGroup.run()``. - - - It can be a int/float (``t_end``) to specify the same running end time, - - Or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify - the start and end simulation time. - - Or, it can be a list of tuple (``[(t1_start, t1_end), (t2_start, t2_end)]``) - to specify the specific start and end simulation time for each initial value. 
- plot_durations : tuple, list, optional - The duration to plot. It can be a tuple with ``(start, end)``. It can - also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify - the plot duration for each initial value running. - axes : str - The axes to plot. It can be: - - - 'v-v': Plot the trajectory in the 'x_var'-'y_var' axis. - - 't-v': Plot the trajectory in the 'time'-'var' axis. - show : bool - Whether show or not. + model : Any + A model of the population, the integrator function, + or a list/tuple of integrator functions. + target_vars : dict + The target/dynamical variables. + fixed_vars : dict + The fixed variables. + target_pars : dict, optional + The parameters which can be dynamical varied. + pars_update : dict, optional + The parameters to update. + resolutions : float, dict """ - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am plotting the trajectory ...') - - if axes not in ['v-v', 't-v']: - raise errors.AnalyzerError(f'Unknown axes "{axes}", only support "v-v" and "t-v".') - - # check the initial values - initials = utils.check_initials(initials, self.target_var_names) - - # 2. format the running duration - assert isinstance(duration, (int, float)) - - # 3. format the plot duration - plot_durations = utils.check_plot_durations(plot_durations, duration, initials) - - # 5. 
run the network - dt = math.get_dt() if dt is None else dt - traject_model = utils.TrajectModel( - initial_vars=initials, - integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y}, - dt=dt) - mon_res = traject_model.run(duration=duration) - - if with_plot: - # plots - for i, initial in enumerate(zip(*list(initials.values()))): - # legend - legend = f'$traj_{i}$: ' - for j, key in enumerate(self.target_var_names): - legend += f'{key}={round(float(initial[j]), 4)}, ' - legend = legend[:-2] - # visualization - start = int(plot_durations[i][0] / dt) - end = int(plot_durations[i][1] / dt) - if axes == 'v-v': - lines = pyplot.plot(mon_res[self.x_var][start: end, i], - mon_res[self.y_var][start: end, i], - label=legend, **kwargs) - utils.add_arrow(lines[0]) + def __init__(self, + model, + target_vars, + fixed_vars=None, + target_pars=None, + pars_update=None, + resolutions=None, + **kwargs): + if (target_pars is not None) and len(target_pars) > 0: + raise errors.AnalyzerError(f'Phase plane analysis does not support "target_pars". ' + f'While we detect "target_pars={target_pars}".') + super().__init__(model=model, + target_vars=target_vars, + fixed_vars=fixed_vars, + target_pars=target_pars, + pars_update=pars_update, + resolutions=resolutions, + **kwargs) + + @property + def F_vmap_brentq_fy(self): + if C.F_vmap_brentq_fy not in self.analyzed_results: + f_opt = jax.jit(vmap(utils.jax_brentq(self.F_fy))) + self.analyzed_results[C.F_vmap_brentq_fy] = f_opt + return self.analyzed_results[C.F_vmap_brentq_fy] + + def plot_vector_field(self, with_plot=True, with_return=False, + plot_method='streamplot', plot_style=None, show=False): + """Plot the vector field. + + Parameters:: + + with_plot: bool + with_return : bool + show : bool + plot_method : str + The method to plot the vector filed. It can be "streamplot" or "quiver". + plot_style : dict, optional + The style for vector filed plotting. 
+ + - For ``plot_method="streamplot"``, it can set the keywords like "density", + "linewidth", "color", "arrowsize". More settings please check + https://matplotlib.org/api/_as_gen/matplotlib.pyplot.streamplot.html. + - For ``plot_method="quiver"``, it can set the keywords like "color", + "units", "angles", "scale". More settings please check + https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html. + """ + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am creating the vector field ...') + + # get vector fields + xs = self.resolutions[self.x_var] + ys = self.resolutions[self.y_var] + X, Y = jnp.meshgrid(xs, ys) + dx = self.F_fx(X, Y) + dy = self.F_fy(X, Y) + X, Y = np.asarray(X), np.asarray(Y) + dx, dy = np.asarray(dx), np.asarray(dy) + + if with_plot: # plot vector fields + if plot_method == 'quiver': + if plot_style is None: + plot_style = dict(units='xy') + if (not np.isnan(dx).any()) and (not np.isnan(dy).any()): + speed = np.sqrt(dx ** 2 + dy ** 2) + dx = dx / speed + dy = dy / speed + pyplot.quiver(X, Y, dx, dy, **plot_style) + elif plot_method == 'streamplot': + if plot_style is None: + plot_style = dict(arrowsize=1.2, density=1, color='thistle') + linewidth = plot_style.get('linewidth', None) + if linewidth is None: + if (not np.isnan(dx).any()) and (not np.isnan(dy).any()): + min_width, max_width = 0.5, 5.5 + speed = np.nan_to_num(np.sqrt(dx ** 2 + dy ** 2)) + linewidth = min_width + max_width * (speed / speed.max()) + pyplot.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style) + else: + raise errors.AnalyzerError(f'Unknown plot_method "{plot_method}", ' + f'only supports "quiver" and "streamplot".') + + pyplot.xlabel(self.x_var) + pyplot.ylabel(self.y_var) + if show: + pyplot.show() + + if with_return: # return vector fields + return dx, dy + + def plot_nullcline(self, with_plot=True, with_return=False, + y_style=None, x_style=None, show=False, + coords=None, tol_nullcline=1e-7): + """Plot the 
nullcline.""" + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am computing fx-nullcline ...') + + if coords is None: + coords = dict() + x_coord = coords.get(self.x_var, None) + y_coord = coords.get(self.y_var, None) + + # Nullcline of the x variable + xy_values_in_fx, = self._get_fx_nullcline_points(coords=x_coord, tol=tol_nullcline) + x_values_in_fx = np.asarray(xy_values_in_fx[:, 0]) + y_values_in_fx = np.asarray(xy_values_in_fx[:, 1]) + + if with_plot: + if x_style is None: + x_style = dict(color='cornflowerblue', alpha=.7, fmt='.') + line_args = (x_style.pop('fmt'),) if 'fmt' in x_style else tuple() + pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline") + + # Nullcline of the y variable + utils.output('I am computing fy-nullcline ...') + xy_values_in_fy, = self._get_fy_nullcline_points(coords=y_coord, tol=tol_nullcline) + x_values_in_fy = np.asarray(xy_values_in_fy[:, 0]) + y_values_in_fy = np.asarray(xy_values_in_fy[:, 1]) + + if with_plot: + if y_style is None: + y_style = dict(color='lightcoral', alpha=.7, fmt='.') + line_args = (y_style.pop('fmt'),) if 'fmt' in y_style else tuple() + pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline") + + if with_plot: + pyplot.xlabel(self.x_var) + pyplot.ylabel(self.y_var) + scale = (self.lim_scale - 1.) / 2 + pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) + pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale)) + pyplot.legend() + if show: + pyplot.show() + + if with_return: + return {self.x_var: (x_values_in_fx, y_values_in_fx), + self.y_var: (x_values_in_fy, y_values_in_fy)} + + def plot_fixed_point(self, with_plot=True, with_return=False, show=False, + tol_unique=1e-2, tol_aux=1e-8, tol_opt_screen=None, + select_candidates='fx-nullcline', num_rank=100, ): + """Plot the fixed point and analyze its stability. 
+ """ + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am searching fixed points ...') + + if self._can_convert_to_one_eq(): + if self.convert_type() == C.x_by_y: + candidates = bm.as_jax(self.resolutions[self.y_var]) + else: + candidates = bm.as_jax(self.resolutions[self.x_var]) else: - pyplot.plot(mon_res.ts[start: end], - mon_res[self.x_var][start: end, i], - label=legend + f', {self.x_var}', **kwargs) - pyplot.plot(mon_res.ts[start: end], - mon_res[self.y_var][start: end, i], - label=legend + f', {self.y_var}', **kwargs) - - # visualization of others - if axes == 'v-v': + if select_candidates == 'fx-nullcline': + candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys() + if key.startswith(C.fx_nullcline_points)] + if len(candidates) == 0: + raise errors.AnalyzerError(f'No nullcline points are found, please call ' + f'".{self.plot_nullcline.__name__}()" first.') + candidates = jnp.vstack(candidates) + elif select_candidates == 'fy-nullcline': + candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys() + if key.startswith(C.fy_nullcline_points)] + if len(candidates) == 0: + raise errors.AnalyzerError(f'No nullcline points are found, please call ' + f'".{self.plot_nullcline.__name__}()" first.') + candidates = jnp.vstack(candidates) + elif select_candidates == 'nullclines': + candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys() + if key.startswith(C.fy_nullcline_points) or key.startswith(C.fy_nullcline_points)] + if len(candidates) == 0: + raise errors.AnalyzerError(f'No nullcline points are found, please call ' + f'".{self.plot_nullcline.__name__}()" first.') + candidates = jnp.vstack(candidates) + elif select_candidates == 'aux_rank': + candidates, _ = self._get_fp_candidates_by_aux_rank(num_rank=num_rank) + else: + raise ValueError + + # get fixed points + if len(candidates): + fixed_points, _, _ = 
self._get_fixed_points(jnp.asarray(candidates), + tol_aux=tol_aux, + tol_unique=tol_unique, + tol_opt_candidate=tol_opt_screen) + utils.output('I am trying to filter out duplicate fixed points ...') + fixed_points = np.asarray(fixed_points) + fixed_points, _ = utils.keep_unique(fixed_points, tolerance=tol_unique) + utils.output(f'{C.prefix}Found {len(fixed_points)} fixed points.') + else: + utils.output(f'{C.prefix}Found no fixed points.') + return + + # stability analysis + # ------------------ + container = {a: {'x': [], 'y': []} for a in stability.get_2d_stability_types()} + for i in range(len(fixed_points)): + x = fixed_points[i, 0] + y = fixed_points[i, 1] + fp_type = stability.stability_analysis(self.F_jacobian(x, y)) + utils.output(f"{C.prefix}#{i + 1} {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.") + container[fp_type]['x'].append(x) + container[fp_type]['y'].append(y) + + # visualization + # ------------- + if with_plot: + for fp_type, points in container.items(): + if len(points['x']): + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) + pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type) + pyplot.legend() + if show: + pyplot.show() + + if with_return: + return fixed_points + + def plot_trajectory(self, initials, duration, plot_durations=None, axes='v-v', + dt=None, show=False, with_plot=True, with_return=False, **kwargs): + """Plot trajectories according to the settings. + + Parameters:: + + initials : list, tuple, dict + The initial value setting of the targets. It can be a tuple/list of floats to specify + each value of dynamical variables (for example, ``(a, b)``). It can also be a + tuple/list of tuple to specify multiple initial values (for example, + ``[(a1, b1), (a2, b2)]``). + duration : int, float, tuple, list + The running duration. Same with the ``duration`` in ``NeuGroup.run()``. 
+ + - It can be a int/float (``t_end``) to specify the same running end time, + - Or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify + the start and end simulation time. + - Or, it can be a list of tuple (``[(t1_start, t1_end), (t2_start, t2_end)]``) + to specify the specific start and end simulation time for each initial value. + plot_durations : tuple, list, optional + The duration to plot. It can be a tuple with ``(start, end)``. It can + also be a list of tuple ``[(start1, end1), (start2, end2)]`` to specify + the plot duration for each initial value running. + axes : str + The axes to plot. It can be: + + - 'v-v': Plot the trajectory in the 'x_var'-'y_var' axis. + - 't-v': Plot the trajectory in the 'time'-'var' axis. + show : bool + Whether show or not. + """ + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am plotting the trajectory ...') + + if axes not in ['v-v', 't-v']: + raise errors.AnalyzerError(f'Unknown axes "{axes}", only support "v-v" and "t-v".') + + # check the initial values + initials = utils.check_initials(initials, self.target_var_names) + + # 2. format the running duration + assert isinstance(duration, (int, float)) + + # 3. format the plot duration + plot_durations = utils.check_plot_durations(plot_durations, duration, initials) + + # 5. 
run the network + dt = math.get_dt() if dt is None else dt + traject_model = utils.TrajectModel( + initial_vars=initials, + integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y}, + dt=dt) + mon_res = traject_model.run(duration=duration) + + if with_plot: + # plots + for i, initial in enumerate(zip(*list(initials.values()))): + # legend + legend = f'$traj_{i}$: ' + for j, key in enumerate(self.target_var_names): + legend += f'{key}={round(float(initial[j]), 4)}, ' + legend = legend[:-2] + + # visualization + start = int(plot_durations[i][0] / dt) + end = int(plot_durations[i][1] / dt) + if axes == 'v-v': + lines = pyplot.plot(mon_res[self.x_var][start: end, i], + mon_res[self.y_var][start: end, i], + label=legend, **kwargs) + utils.add_arrow(lines[0]) + else: + pyplot.plot(mon_res.ts[start: end], + mon_res[self.x_var][start: end, i], + label=legend + f', {self.x_var}', **kwargs) + pyplot.plot(mon_res.ts[start: end], + mon_res[self.y_var][start: end, i], + label=legend + f', {self.y_var}', **kwargs) + + # visualization of others + if axes == 'v-v': + pyplot.xlabel(self.x_var) + pyplot.ylabel(self.y_var) + scale = (self.lim_scale - 1.) / 2 + pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) + pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale)) + pyplot.legend() + else: + pyplot.legend(title='Initial values') + + if show: + pyplot.show() + + if with_return: + return mon_res + + def plot_limit_cycle_by_sim(self, initials, duration, tol=0.01, show=False, dt=None): + """Plot trajectories according to the settings. + + Parameters:: + + initials : list, tuple + The initial value setting of the targets. + + - It can be a tuple/list of floats to specify each value of dynamical variables + (for example, ``(a, b)``). + - It can also be a tuple/list of tuple to specify multiple initial values (for + example, ``[(a1, b1), (a2, b2)]``). + duration : int, float, tuple, list + The running duration. 
Same with the ``duration`` in ``NeuGroup.run()``. + + - It can be a int/float (``t_end``) to specify the same running end time, + - Or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify + the start and end simulation time. + - Or, it can be a list of tuple (``[(t1_start, t1_end), (t2_start, t2_end)]``) + to specify the specific start and end simulation time for each initial value. + show : bool + Whether show or not. + """ + global pyplot + if pyplot is None: from matplotlib import pyplot + utils.output('I am plotting the limit cycle ...') + + # 1. format the initial values + initials = utils.check_initials(initials, self.target_var_names) + + # 2. format the running duration + assert isinstance(duration, (int, float)) + + dt = math.get_dt() if dt is None else dt + traject_model = utils.TrajectModel( + initial_vars=initials, + integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y}, + dt=dt) + mon_res = traject_model.run(duration=duration) + + # 5. run the network + for init_i, initial in enumerate(zip(*list(initials.values()))): + # 5.2 run the model + x_data = mon_res[self.x_var][:, init_i] + y_data = mon_res[self.y_var][:, init_i] + max_index = utils.find_indexes_of_limit_cycle_max(x_data, tol=tol) + if max_index[0] != -1: + x_cycle = x_data[max_index[0]: max_index[1]] + y_cycle = y_data[max_index[0]: max_index[1]] + # 5.5 visualization + lines = pyplot.plot(x_cycle, y_cycle, label='limit cycle') + utils.add_arrow(lines[0]) + else: + utils.output(f'No limit cycle found for initial value {initial}') + + # 6. visualization pyplot.xlabel(self.x_var) pyplot.ylabel(self.y_var) scale = (self.lim_scale - 1.) 
/ 2 pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale)) pyplot.legend() - else: - pyplot.legend(title='Initial values') - - if show: - pyplot.show() - - if with_return: - return mon_res - def plot_limit_cycle_by_sim(self, initials, duration, tol=0.01, show=False, dt=None): - """Plot trajectories according to the settings. - - Parameters:: - - initials : list, tuple - The initial value setting of the targets. - - - It can be a tuple/list of floats to specify each value of dynamical variables - (for example, ``(a, b)``). - - It can also be a tuple/list of tuple to specify multiple initial values (for - example, ``[(a1, b1), (a2, b2)]``). - duration : int, float, tuple, list - The running duration. Same with the ``duration`` in ``NeuGroup.run()``. - - - It can be a int/float (``t_end``) to specify the same running end time, - - Or it can be a tuple/list of int/float (``(t_start, t_end)``) to specify - the start and end simulation time. - - Or, it can be a list of tuple (``[(t1_start, t1_end), (t2_start, t2_end)]``) - to specify the specific start and end simulation time for each initial value. - show : bool - Whether show or not. - """ - global pyplot - if pyplot is None: from matplotlib import pyplot - utils.output('I am plotting the limit cycle ...') - - # 1. format the initial values - initials = utils.check_initials(initials, self.target_var_names) - - # 2. format the running duration - assert isinstance(duration, (int, float)) - - dt = math.get_dt() if dt is None else dt - traject_model = utils.TrajectModel( - initial_vars=initials, - integrals={self.x_var: self.F_int_x, self.y_var: self.F_int_y}, - dt=dt) - mon_res = traject_model.run(duration=duration) - - # 5. 
run the network - for init_i, initial in enumerate(zip(*list(initials.values()))): - # 5.2 run the model - x_data = mon_res[self.x_var][:, init_i] - y_data = mon_res[self.y_var][:, init_i] - max_index = utils.find_indexes_of_limit_cycle_max(x_data, tol=tol) - if max_index[0] != -1: - x_cycle = x_data[max_index[0]: max_index[1]] - y_cycle = y_data[max_index[0]: max_index[1]] - # 5.5 visualization - lines = pyplot.plot(x_cycle, y_cycle, label='limit cycle') - utils.add_arrow(lines[0]) - else: - utils.output(f'No limit cycle found for initial value {initial}') - - # 6. visualization - pyplot.xlabel(self.x_var) - pyplot.ylabel(self.y_var) - scale = (self.lim_scale - 1.) / 2 - pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale)) - pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale)) - pyplot.legend() - - if show: - pyplot.show() + if show: + pyplot.show() diff --git a/brainpy/_src/analysis/lowdim/tests/test_bifurcation.py b/brainpy/_src/analysis/lowdim/tests/test_bifurcation.py index 5eaf1631b..43acca98c 100644 --- a/brainpy/_src/analysis/lowdim/tests/test_bifurcation.py +++ b/brainpy/_src/analysis/lowdim/tests/test_bifurcation.py @@ -2,6 +2,7 @@ import pytest + pytest.skip('Test cannot pass in github action.', allow_module_level=True) import unittest @@ -13,77 +14,77 @@ class FitzHughNagumoModel(bp.DynamicalSystem): - def __init__(self, method='exp_auto'): - super(FitzHughNagumoModel, self).__init__() + def __init__(self, method='exp_auto'): + super(FitzHughNagumoModel, self).__init__() - # parameters - self.a = 0.7 - self.b = 0.8 - self.tau = 12.5 + # parameters + self.a = 0.7 + self.b = 0.8 + self.tau = 12.5 - # variables - self.V = bm.Variable(bm.zeros(1)) - self.w = bm.Variable(bm.zeros(1)) - self.Iext = bm.Variable(bm.zeros(1)) + # variables + self.V = bm.Variable(bm.zeros(1)) + self.w = bm.Variable(bm.zeros(1)) + self.Iext = bm.Variable(bm.zeros(1)) - # functions - def dV(V, t, w, Iext=0.): - dV = V - V * V * V / 3 - w + Iext - 
return dV + # functions + def dV(V, t, w, Iext=0.): + dV = V - V * V * V / 3 - w + Iext + return dV - def dw(w, t, V, a=0.7, b=0.8): - dw = (V + a - b * w) / self.tau - return dw + def dw(w, t, V, a=0.7, b=0.8): + dw = (V + a - b * w) / self.tau + return dw - self.int_V = bp.odeint(dV, method=method) - self.int_w = bp.odeint(dw, method=method) + self.int_V = bp.odeint(dV, method=method) + self.int_w = bp.odeint(dw, method=method) - def update(self, tdi): - t, dt = tdi['t'], tdi['dt'] - self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt) - self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt) - self.Iext[:] = 0. + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] + self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt) + self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt) + self.Iext[:] = 0. class TestBifurcation1D(unittest.TestCase): - def test_bifurcation_1d(self): - bp.math.enable_x64() - - @bp.odeint - def int_x(x, t, a=1., b=1.): - return bp.math.sin(a * x) + bp.math.cos(b * x) - - pp = bp.analysis.PhasePlane1D( - model=int_x, - target_vars={'x': [-bp.math.pi, bp.math.pi]}, - resolutions=0.1 - ) - pp.plot_vector_field() - pp.plot_fixed_point(show=True) - - bf = bp.analysis.Bifurcation1D( - model=int_x, - target_vars={'x': [-bp.math.pi, bp.math.pi]}, - target_pars={'a': [0.5, 1.5], 'b': [0.5, 1.5]}, - resolutions={'a': 0.1, 'b': 0.1} - ) - bf.plot_bifurcation(show=False) - plt.show(block=block) - plt.close() - bp.math.disable_x64() - - def test_bifurcation_2d(self): - bp.math.enable_x64() - - model = FitzHughNagumoModel() - bif = bp.analysis.Bifurcation2D( - model=model, - target_vars={'V': [-3., 3.], 'w': [-1, 3.]}, - target_pars={'Iext': [0., 1.]}, - resolutions={'Iext': 0.1} - ) - bif.plot_bifurcation() - bif.plot_limit_cycle_by_sim() - plt.show(block=block) - - # bp.math.disable_x64() + def test_bifurcation_1d(self): + bp.math.enable_x64() + + @bp.odeint + def int_x(x, t, a=1., b=1.): + return bp.math.sin(a * x) + 
bp.math.cos(b * x) + + pp = bp.analysis.PhasePlane1D( + model=int_x, + target_vars={'x': [-bp.math.pi, bp.math.pi]}, + resolutions=0.1 + ) + pp.plot_vector_field() + pp.plot_fixed_point(show=True) + + bf = bp.analysis.Bifurcation1D( + model=int_x, + target_vars={'x': [-bp.math.pi, bp.math.pi]}, + target_pars={'a': [0.5, 1.5], 'b': [0.5, 1.5]}, + resolutions={'a': 0.1, 'b': 0.1} + ) + bf.plot_bifurcation(show=False) + plt.show(block=block) + plt.close() + bp.math.disable_x64() + + def test_bifurcation_2d(self): + bp.math.enable_x64() + + model = FitzHughNagumoModel() + bif = bp.analysis.Bifurcation2D( + model=model, + target_vars={'V': [-3., 3.], 'w': [-1, 3.]}, + target_pars={'Iext': [0., 1.]}, + resolutions={'Iext': 0.1} + ) + bif.plot_bifurcation() + bif.plot_limit_cycle_by_sim() + plt.show(block=block) + + # bp.math.disable_x64() diff --git a/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py b/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py index 5b666fbbe..df9d46d65 100644 --- a/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py +++ b/brainpy/_src/analysis/lowdim/tests/test_phase_plane.py @@ -2,80 +2,80 @@ import unittest -import brainpy as bp -import matplotlib.pyplot as plt import jax.numpy as jnp +import matplotlib.pyplot as plt +import brainpy as bp show = False class TestPhasePlane(unittest.TestCase): - def test_1d(self): - bp.math.enable_x64() + def test_1d(self): + bp.math.enable_x64() - @bp.odeint - def int_x(x, t, Iext): - dx = x ** 3 - x + Iext - return dx + @bp.odeint + def int_x(x, t, Iext): + dx = x ** 3 - x + Iext + return dx - analyzer = bp.analysis.PhasePlane1D(model=int_x, - target_vars={'x': [-2, 2]}, - pars_update={'Iext': 0.}, - resolutions=0.01) + analyzer = bp.analysis.PhasePlane1D(model=int_x, + target_vars={'x': [-2, 2]}, + pars_update={'Iext': 0.}, + resolutions=0.01) - plt.ion() - analyzer.plot_vector_field() - analyzer.plot_fixed_point() - if show: - plt.show() - plt.close() - bp.math.disable_x64() + plt.ion() + 
analyzer.plot_vector_field() + analyzer.plot_fixed_point() + if show: + plt.show() + plt.close() + bp.math.disable_x64() - def test_2d_decision_making_model(self): - bp.math.enable_x64() - gamma = 0.641 # Saturation factor for gating variable - tau = 0.06 # Synaptic time constant [sec] - tau0 = 0.002 # Noise time constant [sec] - a = 270. - b = 108. - d = 0.154 + def test_2d_decision_making_model(self): + bp.math.enable_x64() + gamma = 0.641 # Saturation factor for gating variable + tau = 0.06 # Synaptic time constant [sec] + tau0 = 0.002 # Noise time constant [sec] + a = 270. + b = 108. + d = 0.154 - I0 = 0.3255 # background current [nA] - JE = 0.3725 # self-coupling strength [nA] - JI = -0.1137 # cross-coupling strength [nA] - JAext = 0.00117 # Stimulus input strength [nA] - sigma = 1.02 # nA + I0 = 0.3255 # background current [nA] + JE = 0.3725 # self-coupling strength [nA] + JI = -0.1137 # cross-coupling strength [nA] + JAext = 0.00117 # Stimulus input strength [nA] + sigma = 1.02 # nA - mu0 = 40. # Stimulus firing rate [spikes/sec] - coh = 0.5 # # Stimulus coherence [%] - Ib1 = 0.3297 - Ib2 = 0.3297 + mu0 = 40. # Stimulus firing rate [spikes/sec] + coh = 0.5 # # Stimulus coherence [%] + Ib1 = 0.3297 + Ib2 = 0.3297 - @bp.odeint - def int_s1(s1, t, s2, gamma=0.641): - I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu0 * (1. + coh) - r1 = (a * I1 - b) / (1. - jnp.exp(-d * (a * I1 - b))) - ds1dt = - s1 / tau + (1. - s1) * gamma * r1 - return ds1dt + @bp.odeint + def int_s1(s1, t, s2, gamma=0.641): + I1 = JE * s1 + JI * s2 + Ib1 + JAext * mu0 * (1. + coh) + r1 = (a * I1 - b) / (1. - jnp.exp(-d * (a * I1 - b))) + ds1dt = - s1 / tau + (1. - s1) * gamma * r1 + return ds1dt - @bp.odeint - def int_s2(s2, t, s1): - I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu0 * (1. - coh) - r2 = (a * I2 - b) / (1. - jnp.exp(-d * (a * I2 - b))) - ds2dt = - s2 / tau + (1. - s2) * gamma * r2 - return ds2dt + @bp.odeint + def int_s2(s2, t, s1): + I2 = JE * s2 + JI * s1 + Ib2 + JAext * mu0 * (1. 
- coh) + r2 = (a * I2 - b) / (1. - jnp.exp(-d * (a * I2 - b))) + ds2dt = - s2 / tau + (1. - s2) * gamma * r2 + return ds2dt - analyzer = bp.analysis.PhasePlane2D( - model=[int_s1, int_s2], - target_vars={'s1': [0, 1], 's2': [0, 1]}, - resolutions=0.001 - ) - plt.ion() - analyzer.plot_vector_field() - analyzer.plot_nullcline(coords=dict(s2='s2-s1')) - analyzer.plot_fixed_point() - if show: - plt.show() - plt.close() - bp.math.disable_x64() + analyzer = bp.analysis.PhasePlane2D( + model=[int_s1, int_s2], + target_vars={'s1': [0, 1], 's2': [0, 1]}, + resolutions=0.001 + ) + plt.ion() + analyzer.plot_vector_field() + analyzer.plot_nullcline(coords=dict(s2='s2-s1')) + analyzer.plot_fixed_point() + if show: + plt.show() + plt.close() + bp.math.disable_x64() diff --git a/brainpy/_src/analysis/plotstyle.py b/brainpy/_src/analysis/plotstyle.py index e59ea6435..6c99466d2 100644 --- a/brainpy/_src/analysis/plotstyle.py +++ b/brainpy/_src/analysis/plotstyle.py @@ -2,9 +2,9 @@ __all__ = [ - 'plot_schema', - 'set_plot_schema', - 'set_markersize', + 'plot_schema', + 'set_plot_schema', + 'set_markersize', ] from .stability import (CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D, @@ -12,11 +12,10 @@ STABLE_FOCUS_2D, STABLE_STAR_2D, STABLE_DEGENERATE_2D, UNSTABLE_NODE_2D, UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D, UNSTABLE_DEGENERATE_2D, UNSTABLE_LINE_2D, - STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D, - UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, STABLE_FOCUS_3D, + STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D, + UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, STABLE_FOCUS_3D, UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D) - _markersize = 10 plot_schema = {} @@ -28,46 +27,43 @@ plot_schema[UNSTABLE_POINT_1D] = {"color": 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'} plot_schema.update({ - CENTER_2D: {'color': 'lime', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - STABLE_NODE_2D: {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 
'None', 'marker': '.'}, - STABLE_FOCUS_2D: {"color": 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - STABLE_STAR_2D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - STABLE_DEGENERATE_2D: {'color': 'blueviolet', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_NODE_2D: {"color": 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_FOCUS_2D: {"color": 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_STAR_2D: {'color': 'green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_DEGENERATE_2D: {'color': 'springgreen', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_LINE_2D: {'color': 'dodgerblue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + CENTER_2D: {'color': 'lime', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_NODE_2D: {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_FOCUS_2D: {"color": 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_STAR_2D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_DEGENERATE_2D: {'color': 'blueviolet', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_NODE_2D: {"color": 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_FOCUS_2D: {"color": 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_STAR_2D: {'color': 'green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_DEGENERATE_2D: {'color': 'springgreen', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_LINE_2D: {'color': 'dodgerblue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, }) - plot_schema.update({ - STABLE_POINT_3D: 
{'color': 'tab:gray', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_POINT_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - STABLE_NODE_3D: {'color': 'tab:green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_SADDLE_3D: {'color': 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_FOCUS_3D: {'color': 'tab:pink', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - STABLE_FOCUS_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_NODE_3D: {'color': 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNSTABLE_CENTER_3D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, - UNKNOWN_3D: {'color': 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_POINT_3D: {'color': 'tab:gray', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_POINT_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_NODE_3D: {'color': 'tab:green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_SADDLE_3D: {'color': 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_FOCUS_3D: {'color': 'tab:pink', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + STABLE_FOCUS_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_NODE_3D: {'color': 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNSTABLE_CENTER_3D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, + UNKNOWN_3D: {'color': 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}, }) def set_plot_schema(fixed_point: str, **schema): - if not isinstance(fixed_point, str): - raise 
TypeError(f'Must instance of string, but we got {type(fixed_point)}: {fixed_point}') - if fixed_point not in plot_schema: - raise KeyError(f'Fixed point type {fixed_point} does not found in the built-in types. ') - plot_schema[fixed_point].update(**schema) + if not isinstance(fixed_point, str): + raise TypeError(f'Must instance of string, but we got {type(fixed_point)}: {fixed_point}') + if fixed_point not in plot_schema: + raise KeyError(f'Fixed point type {fixed_point} does not found in the built-in types. ') + plot_schema[fixed_point].update(**schema) def set_markersize(markersize): - if not isinstance(markersize, int): - raise TypeError(f"Must be an integer, but got {type(markersize)}: {markersize}") - global _markersize - __markersize = markersize - for key in tuple(plot_schema.keys()): - plot_schema[key]['markersize'] = markersize - - + if not isinstance(markersize, int): + raise TypeError(f"Must be an integer, but got {type(markersize)}: {markersize}") + global _markersize + __markersize = markersize + for key in tuple(plot_schema.keys()): + plot_schema[key]['markersize'] = markersize diff --git a/brainpy/_src/analysis/stability.py b/brainpy/_src/analysis/stability.py index 532c8d9c6..a4f98c456 100644 --- a/brainpy/_src/analysis/stability.py +++ b/brainpy/_src/analysis/stability.py @@ -3,32 +3,29 @@ import numpy as np __all__ = [ - 'get_1d_stability_types', - 'get_2d_stability_types', - 'get_3d_stability_types', - - - 'stability_analysis', - - 'CENTER_MANIFOLD', - 'SADDLE_NODE', - 'STABLE_POINT_1D', - 'UNSTABLE_POINT_1D', - - 'CENTER_2D', - 'STABLE_NODE_2D', - 'STABLE_FOCUS_2D', - 'STABLE_STAR_2D', - 'STABLE_DEGENERATE_2D', - 'UNSTABLE_NODE_2D', - 'UNSTABLE_FOCUS_2D', - 'UNSTABLE_STAR_2D', - 'UNSTABLE_DEGENERATE_2D', - 'UNSTABLE_LINE_2D', + 'get_1d_stability_types', + 'get_2d_stability_types', + 'get_3d_stability_types', + + 'stability_analysis', + + 'CENTER_MANIFOLD', + 'SADDLE_NODE', + 'STABLE_POINT_1D', + 'UNSTABLE_POINT_1D', + + 'CENTER_2D', + 
'STABLE_NODE_2D', + 'STABLE_FOCUS_2D', + 'STABLE_STAR_2D', + 'STABLE_DEGENERATE_2D', + 'UNSTABLE_NODE_2D', + 'UNSTABLE_FOCUS_2D', + 'UNSTABLE_STAR_2D', + 'UNSTABLE_DEGENERATE_2D', + 'UNSTABLE_LINE_2D', ] - - SADDLE_NODE = 'saddle node' CENTER_MANIFOLD = 'center manifold' @@ -46,7 +43,6 @@ UNSTABLE_DEGENERATE_2D = 'unstable degenerate' UNSTABLE_LINE_2D = 'unstable line' - STABLE_POINT_3D = 'unclassified stable point' UNSTABLE_POINT_3D = 'unclassified unstable point' STABLE_NODE_3D = 'stable node' @@ -59,157 +55,157 @@ def get_1d_stability_types(): - """Get the stability types of 1D system.""" - return [SADDLE_NODE, STABLE_POINT_1D, UNSTABLE_POINT_1D] + """Get the stability types of 1D system.""" + return [SADDLE_NODE, STABLE_POINT_1D, UNSTABLE_POINT_1D] def get_2d_stability_types(): - """Get the stability types of 2D system.""" - return [SADDLE_NODE, CENTER_2D, STABLE_NODE_2D, STABLE_FOCUS_2D, - STABLE_STAR_2D, CENTER_MANIFOLD, UNSTABLE_NODE_2D, - UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D, UNSTABLE_LINE_2D, - STABLE_DEGENERATE_2D, UNSTABLE_DEGENERATE_2D] + """Get the stability types of 2D system.""" + return [SADDLE_NODE, CENTER_2D, STABLE_NODE_2D, STABLE_FOCUS_2D, + STABLE_STAR_2D, CENTER_MANIFOLD, UNSTABLE_NODE_2D, + UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D, UNSTABLE_LINE_2D, + STABLE_DEGENERATE_2D, UNSTABLE_DEGENERATE_2D] def get_3d_stability_types(): - """Get the stability types of 3D system.""" - return [STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D, - UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, SADDLE_NODE, - STABLE_FOCUS_3D, UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D] + """Get the stability types of 3D system.""" + return [STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D, + UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, SADDLE_NODE, + STABLE_FOCUS_3D, UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D] def stability_analysis(derivatives): - """Stability analysis of fixed points for low-dimensional system. + """Stability analysis of fixed points for low-dimensional system. 
- The analysis is referred to [1]_. + The analysis is referred to [1]_. - Parameters:: + Parameters:: - derivatives : float, tuple, list, np.ndarray - The derivative of the f. + derivatives : float, tuple, list, np.ndarray + The derivative of the f. - Returns:: + Returns:: - fp_type : str - The type of the fixed point. + fp_type : str + The type of the fixed point. - References:: + References:: - .. [1] http://www.egwald.ca/nonlineardynamics/twodimensionaldynamics.php + .. [1] http://www.egwald.ca/nonlineardynamics/twodimensionaldynamics.php - """ - if np.size(derivatives) == 1: # 1D dynamical system - if derivatives == 0: - return SADDLE_NODE - elif derivatives > 0: - return UNSTABLE_POINT_1D - else: - return STABLE_POINT_1D - - elif np.size(derivatives) == 4: # 2D dynamical system - a = derivatives[0][0] - b = derivatives[0][1] - c = derivatives[1][0] - d = derivatives[1][1] - - # trace - p = a + d - # det - q = a * d - b * c - - # judgement - if q < 0: - return SADDLE_NODE - elif q == 0: - if p <= 0: - return CENTER_MANIFOLD - else: - return UNSTABLE_LINE_2D - else: - # parabola - e = p * p - 4 * q - if p == 0: - return CENTER_2D - elif p > 0: - if e < 0: - return UNSTABLE_FOCUS_2D - elif e > 0: - return UNSTABLE_NODE_2D + """ + if np.size(derivatives) == 1: # 1D dynamical system + if derivatives == 0: + return SADDLE_NODE + elif derivatives > 0: + return UNSTABLE_POINT_1D else: - w = np.linalg.eigvals(derivatives) - if w[0] == w[1]: - return UNSTABLE_DEGENERATE_2D - else: - return UNSTABLE_STAR_2D - else: - if e < 0: - return STABLE_FOCUS_2D - elif e > 0: - return STABLE_NODE_2D + return STABLE_POINT_1D + + elif np.size(derivatives) == 4: # 2D dynamical system + a = derivatives[0][0] + b = derivatives[0][1] + c = derivatives[1][0] + d = derivatives[1][1] + + # trace + p = a + d + # det + q = a * d - b * c + + # judgement + if q < 0: + return SADDLE_NODE + elif q == 0: + if p <= 0: + return CENTER_MANIFOLD + else: + return UNSTABLE_LINE_2D else: - w = 
np.linalg.eigvals(derivatives) - if w[0] == w[1]: - return STABLE_DEGENERATE_2D - else: - return STABLE_STAR_2D - - elif np.size(derivatives) == 9: # 3D dynamical system - eigenvalues = np.linalg.eigvals(np.array(derivatives)) - is_real = np.isreal(eigenvalues) - if is_real.all(): - eigenvalues = np.sort(eigenvalues) - if eigenvalues[2] < 0: - return STABLE_NODE_3D - elif eigenvalues[2] == 0: - return UNKNOWN_3D - else: - if eigenvalues[0] > 0: - return UNSTABLE_NODE_3D - elif eigenvalues[0] == 0: - return UNKNOWN_3D + # parabola + e = p * p - 4 * q + if p == 0: + return CENTER_2D + elif p > 0: + if e < 0: + return UNSTABLE_FOCUS_2D + elif e > 0: + return UNSTABLE_NODE_2D + else: + w = np.linalg.eigvals(derivatives) + if w[0] == w[1]: + return UNSTABLE_DEGENERATE_2D + else: + return UNSTABLE_STAR_2D + else: + if e < 0: + return STABLE_FOCUS_2D + elif e > 0: + return STABLE_NODE_2D + else: + w = np.linalg.eigvals(derivatives) + if w[0] == w[1]: + return STABLE_DEGENERATE_2D + else: + return STABLE_STAR_2D + + elif np.size(derivatives) == 9: # 3D dynamical system + eigenvalues = np.linalg.eigvals(np.array(derivatives)) + is_real = np.isreal(eigenvalues) + if is_real.all(): + eigenvalues = np.sort(eigenvalues) + if eigenvalues[2] < 0: + return STABLE_NODE_3D + elif eigenvalues[2] == 0: + return UNKNOWN_3D + else: + if eigenvalues[0] > 0: + return UNSTABLE_NODE_3D + elif eigenvalues[0] == 0: + return UNKNOWN_3D + else: + if eigenvalues[1] < 0: + return SADDLE_NODE + elif eigenvalues[1] == 0: + return UNKNOWN_3D + else: + return UNSTABLE_SADDLE_3D else: - if eigenvalues[1] < 0: - return SADDLE_NODE - elif eigenvalues[1] == 0: - return UNKNOWN_3D - else: - return UNSTABLE_SADDLE_3D - else: - if is_real.sum() == 1: - real_id = np.where(is_real)[0] - non_real_id = np.where(np.logical_not(is_real))[0] - v0 = eigenvalues[real_id] - v1 = eigenvalues[non_real_id[0]] - v2 = eigenvalues[non_real_id[1]] - v1_real = np.real(v1) - assert np.conj(v1) == v2 - if v0 < 0: - if v1_real 
< 0: - return STABLE_FOCUS_3D - elif v1_real == 0: # 零实部 - return UNKNOWN_3D - else: - return UNSTABLE_FOCUS_3D - elif v0 == 0: - if v1_real <= 0: - return UNKNOWN_3D # 零实部 - else: - return UNSTABLE_POINT_3D # TODO + if is_real.sum() == 1: + real_id = np.where(is_real)[0] + non_real_id = np.where(np.logical_not(is_real))[0] + v0 = eigenvalues[real_id] + v1 = eigenvalues[non_real_id[0]] + v2 = eigenvalues[non_real_id[1]] + v1_real = np.real(v1) + assert np.conj(v1) == v2 + if v0 < 0: + if v1_real < 0: + return STABLE_FOCUS_3D + elif v1_real == 0: # 零实部 + return UNKNOWN_3D + else: + return UNSTABLE_FOCUS_3D + elif v0 == 0: + if v1_real <= 0: + return UNKNOWN_3D # 零实部 + else: + return UNSTABLE_POINT_3D # TODO + else: + if v1_real < 0: + return UNSTABLE_FOCUS_3D + elif v1_real == 0: + return UNSTABLE_CENTER_3D + else: + return UNSTABLE_POINT_3D # TODO + # else: + # raise ValueError() + + eigenvalues = np.real(eigenvalues) + if np.all(eigenvalues < 0): + return STABLE_POINT_3D # TODO else: - if v1_real < 0: - return UNSTABLE_FOCUS_3D - elif v1_real == 0: - return UNSTABLE_CENTER_3D - else: return UNSTABLE_POINT_3D # TODO - # else: - # raise ValueError() - - eigenvalues = np.real(eigenvalues) - if np.all(eigenvalues < 0): - return STABLE_POINT_3D # TODO else: - return UNSTABLE_POINT_3D # TODO - else: - raise ValueError('Unknown derivatives, only supports the jacobian ' - 'matrix with the shape of (1), (2, 2), or (3, 3).') + raise ValueError('Unknown derivatives, only supports the jacobian ' + 'matrix with the shape of (1), (2, 2), or (3, 3).') diff --git a/brainpy/_src/analysis/tests/test_stability.py b/brainpy/_src/analysis/tests/test_stability.py index 62f283325..91cf9f34f 100644 --- a/brainpy/_src/analysis/tests/test_stability.py +++ b/brainpy/_src/analysis/tests/test_stability.py @@ -4,8 +4,6 @@ def test_d1(): - assert stability_analysis(1.) == UNSTABLE_POINT_1D - assert stability_analysis(-1.) == STABLE_POINT_1D - assert stability_analysis(0.) 
== SADDLE_NODE - - + assert stability_analysis(1.) == UNSTABLE_POINT_1D + assert stability_analysis(-1.) == STABLE_POINT_1D + assert stability_analysis(0.) == SADDLE_NODE diff --git a/brainpy/_src/analysis/utils/function.py b/brainpy/_src/analysis/utils/function.py index 1843286d3..7f77ebf62 100644 --- a/brainpy/_src/analysis/utils/function.py +++ b/brainpy/_src/analysis/utils/function.py @@ -6,99 +6,99 @@ from brainpy import errors __all__ = [ - 'f_without_jaxarray_return', - 'remove_return_shape', - 'get_args', - 'std_derivative', - 'std_func', + 'f_without_jaxarray_return', + 'remove_return_shape', + 'get_args', + 'std_derivative', + 'std_func', ] def f_without_jaxarray_return(f): - def f2(*args, **kwargs): - r = f(*args, **kwargs) - return r.value if isinstance(r, bm.BaseArray) else r + def f2(*args, **kwargs): + r = f(*args, **kwargs) + return r.value if isinstance(r, bm.BaseArray) else r - return f2 + return f2 def remove_return_shape(f): - def f2(*args, **kwargs): - r = f(*args, **kwargs) - if r.shape == (1,):r = r[0] - return r - return f2 + def f2(*args, **kwargs): + r = f(*args, **kwargs) + if r.shape == (1,): r = r[0] + return r + return f2 -def get_args(f, gather_var=True): - reduced_args = [] - for name, par in inspect.signature(f).parameters.items(): - if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: - reduced_args.append(par.name) - elif par.kind is inspect.Parameter.KEYWORD_ONLY: - reduced_args.append(par.name) - elif par.kind is inspect.Parameter.VAR_POSITIONAL: - raise errors.DiffEqError('Don not support positional only parameters, e.g., /') - elif par.kind is inspect.Parameter.POSITIONAL_ONLY: - raise errors.DiffEqError('Don not support positional only parameters, e.g., /') - elif par.kind is inspect.Parameter.VAR_KEYWORD: - raise errors.DiffEqError(f'Don not support dict of keyword arguments: {str(par)}') - else: - raise errors.DiffEqError(f'Unknown argument type: {par.kind}') - if gather_var: - var_names = [] - for a in reduced_args: - 
if a == 't': break - var_names.append(a) +def get_args(f, gather_var=True): + reduced_args = [] + for name, par in inspect.signature(f).parameters.items(): + if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + reduced_args.append(par.name) + elif par.kind is inspect.Parameter.KEYWORD_ONLY: + reduced_args.append(par.name) + elif par.kind is inspect.Parameter.VAR_POSITIONAL: + raise errors.DiffEqError('Don not support positional only parameters, e.g., /') + elif par.kind is inspect.Parameter.POSITIONAL_ONLY: + raise errors.DiffEqError('Don not support positional only parameters, e.g., /') + elif par.kind is inspect.Parameter.VAR_KEYWORD: + raise errors.DiffEqError(f'Don not support dict of keyword arguments: {str(par)}') + else: + raise errors.DiffEqError(f'Unknown argument type: {par.kind}') + + if gather_var: + var_names = [] + for a in reduced_args: + if a == 't': break + var_names.append(a) + else: + raise ValueError('Do not find time variable "t".') + return var_names, reduced_args else: - raise ValueError('Do not find time variable "t".') - return var_names, reduced_args - else: - return reduced_args + return reduced_args def std_derivative(original_fargs, target_vars, target_pars): - var = original_fargs[0] - num_vars = len(target_vars) + var = original_fargs[0] + num_vars = len(target_vars) - def inner(f): - def call(*dyn_vars_and_pars, **fixed_vars_and_pars): - params = dict() - for i, v in enumerate(target_vars): - if (v != var) and (v in original_fargs): - params[v] = dyn_vars_and_pars[i] - for j, p in enumerate(target_pars): - if p in original_fargs: - params[p] = dyn_vars_and_pars[num_vars + j] - for k, v in fixed_vars_and_pars.items(): - if k in original_fargs: - params[k] = v - return f(dyn_vars_and_pars[target_vars.index(var)], 0., **params) + def inner(f): + def call(*dyn_vars_and_pars, **fixed_vars_and_pars): + params = dict() + for i, v in enumerate(target_vars): + if (v != var) and (v in original_fargs): + params[v] = dyn_vars_and_pars[i] + 
for j, p in enumerate(target_pars): + if p in original_fargs: + params[p] = dyn_vars_and_pars[num_vars + j] + for k, v in fixed_vars_and_pars.items(): + if k in original_fargs: + params[k] = v + return f(dyn_vars_and_pars[target_vars.index(var)], 0., **params) - return call + return call - return inner + return inner def std_func(original_fargs, target_vars, target_pars): - num_vars = len(target_vars) - - def inner(f): - def call(*dyn_vars_and_pars, **fixed_vars_and_pars): - params = dict() - for i, v in enumerate(target_vars): - if v in original_fargs: - params[v] = dyn_vars_and_pars[i] - for j, p in enumerate(target_pars): - if p in original_fargs: - params[p] = dyn_vars_and_pars[num_vars + j] - for k, v in fixed_vars_and_pars.items(): - if k in original_fargs: - params[k] = v - return f(**params) - - return call - - return inner - + num_vars = len(target_vars) + + def inner(f): + def call(*dyn_vars_and_pars, **fixed_vars_and_pars): + params = dict() + for i, v in enumerate(target_vars): + if v in original_fargs: + params[v] = dyn_vars_and_pars[i] + for j, p in enumerate(target_pars): + if p in original_fargs: + params[p] = dyn_vars_and_pars[num_vars + j] + for k, v in fixed_vars_and_pars.items(): + if k in original_fargs: + params[k] = v + return f(**params) + + return call + + return inner diff --git a/brainpy/_src/analysis/utils/measurement.py b/brainpy/_src/analysis/utils/measurement.py index be8223c8c..07d079c9f 100644 --- a/brainpy/_src/analysis/utils/measurement.py +++ b/brainpy/_src/analysis/utils/measurement.py @@ -12,108 +12,107 @@ from brainpy.tools import numba_jit __all__ = [ - 'find_indexes_of_limit_cycle_max', - 'euclidean_distance', - 'euclidean_distance_jax', + 'find_indexes_of_limit_cycle_max', + 'euclidean_distance', + 'euclidean_distance_jax', ] @numba_jit def _f1(arr, grad, tol): - condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0) - indexes = np.where(condition)[0] - if len(indexes) >= 2: - data = arr[indexes[-2]: 
indexes[-1]] - length = np.max(data) - np.min(data) - a = arr[indexes[-2]] - b = arr[indexes[-1]] - # TODO: how to choose length threshold, 1e-3? - if length > 1e-3 and np.abs(a - b) <= tol * length: - return indexes[-2:] - return np.array([-1, -1]) + condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0) + indexes = np.where(condition)[0] + if len(indexes) >= 2: + data = arr[indexes[-2]: indexes[-1]] + length = np.max(data) - np.min(data) + a = arr[indexes[-2]] + b = arr[indexes[-1]] + # TODO: how to choose length threshold, 1e-3? + if length > 1e-3 and np.abs(a - b) <= tol * length: + return indexes[-2:] + return np.array([-1, -1]) def find_indexes_of_limit_cycle_max(arr, tol=0.001): - grad = np.gradient(arr) - return _f1(arr, grad, tol) + grad = np.gradient(arr) + return _f1(arr, grad, tol) @numba_jit def euclidean_distance(points: np.ndarray, num_point=None): - """Get the distance matrix. + """Get the distance matrix. - Equivalent to: + Equivalent to: - >>> from scipy.spatial.distance import squareform, pdist - >>> f = lambda points: squareform(pdist(points, metric="euclidean")) + >>> from scipy.spatial.distance import squareform, pdist + >>> f = lambda points: squareform(pdist(points, metric="euclidean")) - Parameters:: + Parameters:: - points: ArrayType - The points. + points: ArrayType + The points. - Returns:: + Returns:: - dist_matrix: jnp.ndarray - The distance matrix. - """ + dist_matrix: jnp.ndarray + The distance matrix. 
+ """ - if isinstance(points, dict): - if num_point is None: - raise ValueError('Please provide num_point') - indices = np.triu_indices(num_point) - dist_mat = np.zeros((num_point, num_point)) - for idx in range(len(indices[0])): - i = indices[0][idx] - j = indices[1][idx] - dist_mat[i, j] = np.sqrt(np.sum([np.sum((value[i] - value[j]) ** 2) for value in points.values()])) - else: - num_point = points.shape[0] - indices = np.triu_indices(num_point) - dist_mat = np.zeros((num_point, num_point)) - for idx in range(len(indices[0])): - i = indices[0][idx] - j = indices[1][idx] - dist_mat[i, j] = np.linalg.norm(points[i] - points[j]) - dist_mat = np.maximum(dist_mat, dist_mat.T) - return dist_mat + if isinstance(points, dict): + if num_point is None: + raise ValueError('Please provide num_point') + indices = np.triu_indices(num_point) + dist_mat = np.zeros((num_point, num_point)) + for idx in range(len(indices[0])): + i = indices[0][idx] + j = indices[1][idx] + dist_mat[i, j] = np.sqrt(np.sum([np.sum((value[i] - value[j]) ** 2) for value in points.values()])) + else: + num_point = points.shape[0] + indices = np.triu_indices(num_point) + dist_mat = np.zeros((num_point, num_point)) + for idx in range(len(indices[0])): + i = indices[0][idx] + j = indices[1][idx] + dist_mat[i, j] = np.linalg.norm(points[i] - points[j]) + dist_mat = np.maximum(dist_mat, dist_mat.T) + return dist_mat @jax.jit @partial(jax.vmap, in_axes=[0, 0, None]) def _ed(i, j, leaves): - squares = jnp.asarray([((leaf[i] - leaf[j]) ** 2).sum() for leaf in leaves]) - return jnp.sqrt(jnp.sum(squares)) + squares = jnp.asarray([((leaf[i] - leaf[j]) ** 2).sum() for leaf in leaves]) + return jnp.sqrt(jnp.sum(squares)) def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=None): - """Get the distance matrix. 
- - Equivalent to: - - >>> from scipy.spatial.distance import squareform, pdist - >>> f = lambda points: squareform(pdist(points, metric="euclidean")) - - Parameters:: - - points: ArrayType - The points. - num_point: int - - Returns:: - - dist_matrix: ArrayType - The distance matrix. - """ - if isinstance(points, dict): - if num_point is None: - raise ValueError('Please provide num_point') - else: - num_point = points.shape[0] - indices = jnp.triu_indices(num_point) - dist_mat = bm.zeros((num_point, num_point)) - leaves, _ = tree_flatten(points) - dist_mat[indices] = _ed(*indices, leaves) - dist_mat = jnp.maximum(dist_mat.value, dist_mat.value.T) - return dist_mat - + """Get the distance matrix. + + Equivalent to: + + >>> from scipy.spatial.distance import squareform, pdist + >>> f = lambda points: squareform(pdist(points, metric="euclidean")) + + Parameters:: + + points: ArrayType + The points. + num_point: int + + Returns:: + + dist_matrix: ArrayType + The distance matrix. + """ + if isinstance(points, dict): + if num_point is None: + raise ValueError('Please provide num_point') + else: + num_point = points.shape[0] + indices = jnp.triu_indices(num_point) + dist_mat = bm.zeros((num_point, num_point)) + leaves, _ = tree_flatten(points) + dist_mat[indices] = _ed(*indices, leaves) + dist_mat = jnp.maximum(dist_mat.value, dist_mat.value.T) + return dist_mat diff --git a/brainpy/_src/analysis/utils/model.py b/brainpy/_src/analysis/utils/model.py index e51d392e1..43ca65708 100644 --- a/brainpy/_src/analysis/utils/model.py +++ b/brainpy/_src/analysis/utils/model.py @@ -1,146 +1,146 @@ # -*- coding: utf-8 -*- -from brainpy._src.math.object_transform import Variable -from brainpy._src.math.environment import get_float -from brainpy._src.math.interoperability import as_jax -from brainpy._src.dynsys import DynamicalSystem from brainpy._src.context import share -from brainpy._src.runners import DSRunner +from brainpy._src.dynsys import DynamicalSystem from 
brainpy._src.integrators.base import Integrator from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.base import ODEIntegrator from brainpy._src.integrators.ode.generic import odeint +from brainpy._src.math.environment import get_float +from brainpy._src.math.interoperability import as_jax +from brainpy._src.math.object_transform import Variable +from brainpy._src.runners import DSRunner from brainpy.errors import AnalyzerError, UnsupportedError __all__ = [ - 'model_transform', - 'NumDSWrapper', - 'TrajectModel', + 'model_transform', + 'NumDSWrapper', + 'TrajectModel', ] def _check_model(model): - if isinstance(model, Integrator): - if not isinstance(model, ODEIntegrator): - raise AnalyzerError(f'Must be the instance of {ODEIntegrator.__name__}, but got {model}.') - elif callable(model): - model = odeint(model) - else: - raise ValueError(f'Please provide derivative function or integral function. But we got {model}') - if isinstance(model.f, JointEq): - return [type(model)(eq, var_type=model.var_type, dt=model.dt) for eq in model.f.eqs] - else: - return [model] + if isinstance(model, Integrator): + if not isinstance(model, ODEIntegrator): + raise AnalyzerError(f'Must be the instance of {ODEIntegrator.__name__}, but got {model}.') + elif callable(model): + model = odeint(model) + else: + raise ValueError(f'Please provide derivative function or integral function. 
But we got {model}') + if isinstance(model.f, JointEq): + return [type(model)(eq, var_type=model.var_type, dt=model.dt) for eq in model.f.eqs] + else: + return [model] def model_transform(model): - # check model - if isinstance(model, DynamicalSystem): - model = tuple(model.nodes(level=-1).subset(ODEIntegrator).unique().values()) - elif isinstance(model, NumDSWrapper): - return model - elif isinstance(model, ODEIntegrator): # - model = [model] - elif callable(model): - model = [model] - all_models = [] - if isinstance(model, (list, tuple)): - if len(model) == 0: - raise AnalyzerError(f'Found no derivative/integral functions: {model}') - for fun in tuple(model): - all_models.extend(_check_model(fun)) - elif isinstance(model, dict): - if len(model) == 0: - raise AnalyzerError(f'Found no derivative/integral functions: {model}') - for fun in tuple(model.values()): - all_models.extend(_check_model(fun)) - else: - raise UnsupportedError(f'Dynamics analysis by symbolic approach only supports ' - f'derivative/integral functions or {DynamicalSystem.__name__}, ' - f'but we got: {type(model)}: {str(model)}') - - # pars to update - pars_update = set() - for fun in all_models: - pars_update.update(fun.parameters[1:]) - - # variables and parameters - all_variables = set() - all_parameters = set() - for integral in all_models: - # variable - if len(integral.variables) != 1: - raise AnalyzerError(f'Only supports one {ODEIntegrator.__name__} one variable, ' - f'but we got {len(integral.variables)} variables in {integral}.') - var = integral.variables[0] - if var in all_variables: - raise AnalyzerError(f'Variable name {var} has been defined before. 
' - f'Please change another name.') - all_variables.add(var) - # parameter - all_parameters.update(integral.parameters[1:]) - - # form a dynamic model - return NumDSWrapper(integrals=all_models, - variables=list(all_variables), - parameters=list(all_parameters), - pars_update=pars_update) + # check model + if isinstance(model, DynamicalSystem): + model = tuple(model.nodes(level=-1).subset(ODEIntegrator).unique().values()) + elif isinstance(model, NumDSWrapper): + return model + elif isinstance(model, ODEIntegrator): # + model = [model] + elif callable(model): + model = [model] + all_models = [] + if isinstance(model, (list, tuple)): + if len(model) == 0: + raise AnalyzerError(f'Found no derivative/integral functions: {model}') + for fun in tuple(model): + all_models.extend(_check_model(fun)) + elif isinstance(model, dict): + if len(model) == 0: + raise AnalyzerError(f'Found no derivative/integral functions: {model}') + for fun in tuple(model.values()): + all_models.extend(_check_model(fun)) + else: + raise UnsupportedError(f'Dynamics analysis by symbolic approach only supports ' + f'derivative/integral functions or {DynamicalSystem.__name__}, ' + f'but we got: {type(model)}: {str(model)}') + + # pars to update + pars_update = set() + for fun in all_models: + pars_update.update(fun.parameters[1:]) + + # variables and parameters + all_variables = set() + all_parameters = set() + for integral in all_models: + # variable + if len(integral.variables) != 1: + raise AnalyzerError(f'Only supports one {ODEIntegrator.__name__} one variable, ' + f'but we got {len(integral.variables)} variables in {integral}.') + var = integral.variables[0] + if var in all_variables: + raise AnalyzerError(f'Variable name {var} has been defined before. 
' + f'Please change another name.') + all_variables.add(var) + # parameter + all_parameters.update(integral.parameters[1:]) + + # form a dynamic model + return NumDSWrapper(integrals=all_models, + variables=list(all_variables), + parameters=list(all_parameters), + pars_update=pars_update) class NumDSWrapper(object): - """The wrapper of a dynamical model.""" + """The wrapper of a dynamical model.""" - def __init__(self, - integrals, - variables, - parameters, - pars_update=None): - self.f_integrals = integrals # all integrators - self.f_derivatives = {intg.variables[0]: intg.f for intg in integrals} # all integrators - self.variables = variables # all variables - self.parameters = parameters # all parameters - self.pars_update = pars_update # the parameters to update - self.name2integral = {intg.variables[0]: intg for intg in integrals} - self.name2derivative = {intg.variables[0]: intg.f for intg in integrals} + def __init__(self, + integrals, + variables, + parameters, + pars_update=None): + self.f_integrals = integrals # all integrators + self.f_derivatives = {intg.variables[0]: intg.f for intg in integrals} # all integrators + self.variables = variables # all variables + self.parameters = parameters # all parameters + self.pars_update = pars_update # the parameters to update + self.name2integral = {intg.variables[0]: intg for intg in integrals} + self.name2derivative = {intg.variables[0]: intg.f for intg in integrals} - def __repr__(self): - return f'{self.__class__.__name__}(variables={self.variables}, parameters={self.parameters})' + def __repr__(self): + return f'{self.__class__.__name__}(variables={self.variables}, parameters={self.parameters})' class TrajectModel(DynamicalSystem): - def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None): - super(TrajectModel, self).__init__() - - # variables - assert isinstance(initial_vars, dict) - initial_vars = {k: Variable(as_jax(v, dtype=get_float())) - for k, v in initial_vars.items()} - 
self.register_implicit_vars(initial_vars) - - # parameters - pars = dict() if pars is None else pars - assert isinstance(pars, dict) - self.pars = [as_jax(v, dtype=get_float()) for k, v in pars.items()] - - # integrals - self.integrals = integrals - - # runner - self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False) - - def update(self): - all_vars = list(self.implicit_vars.values()) - for key, intg in self.integrals.items(): - self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt'])) - - def __getattr__(self, item): - child_vars = super().__getattribute__('implicit_vars') - if item in child_vars: - return child_vars[item] - else: - return super().__getattribute__(item) - - def run(self, duration): - self.runner.run(duration) - return self.runner.mon + def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None): + super(TrajectModel, self).__init__() + + # variables + assert isinstance(initial_vars, dict) + initial_vars = {k: Variable(as_jax(v, dtype=get_float())) + for k, v in initial_vars.items()} + self.register_implicit_vars(initial_vars) + + # parameters + pars = dict() if pars is None else pars + assert isinstance(pars, dict) + self.pars = [as_jax(v, dtype=get_float()) for k, v in pars.items()] + + # integrals + self.integrals = integrals + + # runner + self.runner = DSRunner(self, monitors=list(initial_vars.keys()), dt=dt, progress_bar=False) + + def update(self): + all_vars = list(self.implicit_vars.values()) + for key, intg in self.integrals.items(): + self.implicit_vars[key].update(intg(*all_vars, *self.pars, dt=share['dt'])) + + def __getattr__(self, item): + child_vars = super().__getattribute__('implicit_vars') + if item in child_vars: + return child_vars[item] + else: + return super().__getattribute__(item) + + def run(self, duration): + self.runner.run(duration) + return self.runner.mon diff --git a/brainpy/_src/analysis/utils/optimization.py 
b/brainpy/_src/analysis/utils/optimization.py index 13b893e82..1aded84a9 100644 --- a/brainpy/_src/analysis/utils/optimization.py +++ b/brainpy/_src/analysis/utils/optimization.py @@ -7,26 +7,26 @@ from jax import grad, jit, vmap from jax.flatten_util import ravel_pytree -from brainpy import errors import brainpy._src.math as bm +from brainpy import errors from . import f_without_jaxarray_return try: - import scipy.optimize as soptimize + import scipy.optimize as soptimize except (ModuleNotFoundError, ImportError): - soptimize = None + soptimize = None __all__ = [ - 'ECONVERGED', 'ECONVERR', - - 'jax_brentq', - 'get_brentq_candidates', - 'brentq_candidates', - 'brentq_roots', - 'brentq_roots2', - 'scipy_minimize_with_jax', - 'roots_of_1d_by_x', - 'roots_of_1d_by_xy', + 'ECONVERGED', 'ECONVERR', + + 'jax_brentq', + 'get_brentq_candidates', + 'brentq_candidates', + 'brentq_roots', + 'brentq_roots2', + 'scipy_minimize_with_jax', + 'roots_of_1d_by_x', + 'roots_of_1d_by_xy', ] ECONVERGED = 0 @@ -34,188 +34,188 @@ def _logical_or(a, b): - a = a.value if isinstance(a, bm.BaseArray) else a - b = b.value if isinstance(b, bm.BaseArray) else b - return jnp.logical_or(a, b) + a = a.value if isinstance(a, bm.BaseArray) else a + b = b.value if isinstance(b, bm.BaseArray) else b + return jnp.logical_or(a, b) def _logical_and(a, b): - a = a.value if isinstance(a, bm.BaseArray) else a - b = b.value if isinstance(b, bm.BaseArray) else b - return jnp.logical_and(a, b) + a = a.value if isinstance(a, bm.BaseArray) else a + b = b.value if isinstance(b, bm.BaseArray) else b + return jnp.logical_and(a, b) def _where(p, a, b): - p = p.value if isinstance(p, bm.BaseArray) else p - a = a.value if isinstance(a, bm.BaseArray) else a - b = b.value if isinstance(b, bm.BaseArray) else b - return jnp.where(p, a, b) + p = p.value if isinstance(p, bm.BaseArray) else p + a = a.value if isinstance(a, bm.BaseArray) else a + b = b.value if isinstance(b, bm.BaseArray) else b + return jnp.where(p, a, b) 
def jax_brentq(fun): - f = f_without_jaxarray_return(fun) - assert jax.config.read('jax_enable_x64'), ('Brentq optimization need x64 support. ' - 'Please enable x64 with "brainpy.math.enable_x64()"') - rtol = 4 * jnp.finfo(jnp.float64).eps - - # if jax.config.read('jax_enable_x64'): - # rtol = 4 * jnp.finfo(jnp.float64).eps - # else: - # rtol = 1.5 * jnp.finfo(jnp.float32).eps - - def x(a, b, args=(), xtol=2e-14, maxiter=200): - # Convert to float - xpre = a * 1.0 - xcur = b * 1.0 - - # Conditional checks for intervals in methods involving bisection - fpre = f(xpre, *args) - fcur = f(xcur, *args) - - # Root found at either end of [a,b] - root = _where(fpre == 0, xpre, 0.) - status = _where(fpre == 0, ECONVERGED, ECONVERR) - root = _where(fcur == 0, xcur, root) - status = _where(fcur == 0, ECONVERGED, status) - - # Check for sign error and early termination - # Perform Brent's method - def _f1(x): - x['xblk'] = x['xpre'] - x['fblk'] = x['fpre'] - x['spre'] = x['xcur'] - x['xpre'] - x['scur'] = x['xcur'] - x['xpre'] - return x - - def _f2(x): - x['xpre'] = x['xcur'] - x['xcur'] = x['xblk'] - x['xblk'] = x['xpre'] - x['fpre'] = x['fcur'] - x['fcur'] = x['fblk'] - x['fblk'] = x['fpre'] - return x - - def _f5(x): - x['stry'] = -x['fcur'] * (x['xcur'] - x['xpre']) / (x['fcur'] - x['fpre']) - return x - - def _f6(x): - x['dpre'] = (x['fpre'] - x['fcur']) / (x['xpre'] - x['xcur']) - dblk = (x['fblk'] - x['fcur']) / (x['xblk'] - x['xcur']) - _tmp = dblk * x['dpre'] * (x['fblk'] - x['fpre']) - x['stry'] = -x['fcur'] * (x['fblk'] * dblk - x['fpre'] * x['dpre']) / _tmp - return x - - def _f3(x): - x = jax.lax.cond(x['xpre'] == x['xblk'], _f5, _f6, x) - k = jnp.min(jnp.array([abs(x['spre']), 3 * abs(x['sbis']) - x['delta']])) - j = 2 * abs(x['stry']) < k - x['spre'] = _where(j, x['scur'], x['sbis']) - x['scur'] = _where(j, x['stry'], x['sbis']) - return x - - def _f4(x): # bisect - x['spre'] = x['sbis'] - x['scur'] = x['sbis'] - return x - - def body_fun(x): - x['itr'] += 1 - x 
= jax.lax.cond(x['fpre'] * x['fcur'] < 0, _f1, lambda a: a, x) - x = jax.lax.cond(abs(x['fblk']) < abs(x['fcur']), _f2, lambda a: a, x) - x['delta'] = (xtol + rtol * abs(x['xcur'])) / 2 - x['sbis'] = (x['xblk'] - x['xcur']) / 2 - # Root found - j = _logical_or(x['fcur'] == 0, abs(x['sbis']) < x['delta']) - x['status'] = _where(j, ECONVERGED, x['status']) - x['root'] = _where(j, x['xcur'], x['root']) - x = jax.lax.cond(_logical_and(abs(x['spre']) > x['delta'], abs(x['fcur']) < abs(x['fpre'])), - _f3, _f4, x) - x['xpre'] = x['xcur'] - x['fpre'] = x['fcur'] - x['xcur'] += _where(abs(x['scur']) > x['delta'], - x['scur'], _where(x['sbis'] > 0, x['delta'], -x['delta'])) - x['fcur'] = f(x['xcur'], *args) - x['funcalls'] += 1 - return x - - def cond_fun(R): - return jnp.logical_and(R['status'] != ECONVERGED, R['itr'] <= maxiter) - - R = dict(root=root, status=status, xpre=xpre, xcur=xcur, fpre=fpre, fcur=fcur, - itr=0, funcalls=2, xblk=xpre, fblk=fpre, - sbis=(xpre - xcur) / 2, - delta=(xtol + rtol * abs(xcur)) / 2, - stry=-fcur * (xcur - xpre) / (fcur - fpre), - scur=xcur - xpre, spre=xcur - xpre, - dpre=(fpre - fcur) / (xpre - xcur)) - R = jax.lax.cond(status == ECONVERGED, - lambda x: x, - lambda x: jax.lax.while_loop(cond_fun, body_fun, x), - R) - return dict(root=R['root'], funcalls=R['funcalls'], itr=R['itr'], status=R['status']) - - return x + f = f_without_jaxarray_return(fun) + assert jax.config.read('jax_enable_x64'), ('Brentq optimization need x64 support. 
' + 'Please enable x64 with "brainpy.math.enable_x64()"') + rtol = 4 * jnp.finfo(jnp.float64).eps + + # if jax.config.read('jax_enable_x64'): + # rtol = 4 * jnp.finfo(jnp.float64).eps + # else: + # rtol = 1.5 * jnp.finfo(jnp.float32).eps + + def x(a, b, args=(), xtol=2e-14, maxiter=200): + # Convert to float + xpre = a * 1.0 + xcur = b * 1.0 + + # Conditional checks for intervals in methods involving bisection + fpre = f(xpre, *args) + fcur = f(xcur, *args) + + # Root found at either end of [a,b] + root = _where(fpre == 0, xpre, 0.) + status = _where(fpre == 0, ECONVERGED, ECONVERR) + root = _where(fcur == 0, xcur, root) + status = _where(fcur == 0, ECONVERGED, status) + + # Check for sign error and early termination + # Perform Brent's method + def _f1(x): + x['xblk'] = x['xpre'] + x['fblk'] = x['fpre'] + x['spre'] = x['xcur'] - x['xpre'] + x['scur'] = x['xcur'] - x['xpre'] + return x + + def _f2(x): + x['xpre'] = x['xcur'] + x['xcur'] = x['xblk'] + x['xblk'] = x['xpre'] + x['fpre'] = x['fcur'] + x['fcur'] = x['fblk'] + x['fblk'] = x['fpre'] + return x + + def _f5(x): + x['stry'] = -x['fcur'] * (x['xcur'] - x['xpre']) / (x['fcur'] - x['fpre']) + return x + + def _f6(x): + x['dpre'] = (x['fpre'] - x['fcur']) / (x['xpre'] - x['xcur']) + dblk = (x['fblk'] - x['fcur']) / (x['xblk'] - x['xcur']) + _tmp = dblk * x['dpre'] * (x['fblk'] - x['fpre']) + x['stry'] = -x['fcur'] * (x['fblk'] * dblk - x['fpre'] * x['dpre']) / _tmp + return x + + def _f3(x): + x = jax.lax.cond(x['xpre'] == x['xblk'], _f5, _f6, x) + k = jnp.min(jnp.array([abs(x['spre']), 3 * abs(x['sbis']) - x['delta']])) + j = 2 * abs(x['stry']) < k + x['spre'] = _where(j, x['scur'], x['sbis']) + x['scur'] = _where(j, x['stry'], x['sbis']) + return x + + def _f4(x): # bisect + x['spre'] = x['sbis'] + x['scur'] = x['sbis'] + return x + + def body_fun(x): + x['itr'] += 1 + x = jax.lax.cond(x['fpre'] * x['fcur'] < 0, _f1, lambda a: a, x) + x = jax.lax.cond(abs(x['fblk']) < abs(x['fcur']), _f2, lambda a: a, x) + 
x['delta'] = (xtol + rtol * abs(x['xcur'])) / 2 + x['sbis'] = (x['xblk'] - x['xcur']) / 2 + # Root found + j = _logical_or(x['fcur'] == 0, abs(x['sbis']) < x['delta']) + x['status'] = _where(j, ECONVERGED, x['status']) + x['root'] = _where(j, x['xcur'], x['root']) + x = jax.lax.cond(_logical_and(abs(x['spre']) > x['delta'], abs(x['fcur']) < abs(x['fpre'])), + _f3, _f4, x) + x['xpre'] = x['xcur'] + x['fpre'] = x['fcur'] + x['xcur'] += _where(abs(x['scur']) > x['delta'], + x['scur'], _where(x['sbis'] > 0, x['delta'], -x['delta'])) + x['fcur'] = f(x['xcur'], *args) + x['funcalls'] += 1 + return x + + def cond_fun(R): + return jnp.logical_and(R['status'] != ECONVERGED, R['itr'] <= maxiter) + + R = dict(root=root, status=status, xpre=xpre, xcur=xcur, fpre=fpre, fcur=fcur, + itr=0, funcalls=2, xblk=xpre, fblk=fpre, + sbis=(xpre - xcur) / 2, + delta=(xtol + rtol * abs(xcur)) / 2, + stry=-fcur * (xcur - xpre) / (fcur - fpre), + scur=xcur - xpre, spre=xcur - xpre, + dpre=(fpre - fcur) / (xpre - xcur)) + R = jax.lax.cond(status == ECONVERGED, + lambda x: x, + lambda x: jax.lax.while_loop(cond_fun, body_fun, x), + R) + return dict(root=R['root'], funcalls=R['funcalls'], itr=R['itr'], status=R['status']) + + return x def get_brentq_candidates(f, xs, ys): - f = f_without_jaxarray_return(f) - xs = bm.as_jax(xs) - ys = bm.as_jax(ys) - Y, X = jnp.meshgrid(ys, xs) - vals = f(X, Y) - signs = jnp.sign(vals) - x_ids, y_ids = jnp.where(signs[:-1] * signs[1:] <= 0) - starts = xs[x_ids] - ends = xs[x_ids + 1] - args = ys[y_ids] - return starts, ends, args + f = f_without_jaxarray_return(f) + xs = bm.as_jax(xs) + ys = bm.as_jax(ys) + Y, X = jnp.meshgrid(ys, xs) + vals = f(X, Y) + signs = jnp.sign(vals) + x_ids, y_ids = jnp.where(signs[:-1] * signs[1:] <= 0) + starts = xs[x_ids] + ends = xs[x_ids + 1] + args = ys[y_ids] + return starts, ends, args def brentq_candidates(vmap_f, *values, args=()): - # change the position of meshgrid values - values = tuple((v.value if isinstance(v, 
bm.BaseArray) else v) for v in values) - xs = values[0] - mesh_values = jnp.meshgrid(*values) - if jnp.ndim(mesh_values[0]) > 1: - mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) - mesh_values = tuple(m.flatten() for m in mesh_values) - # function outputs - signs = jnp.sign(vmap_f(*(mesh_values + args))) - # compute the selected values - signs = signs.reshape((xs.shape[0], -1)) - par_len = signs.shape[1] - signs1 = signs.at[-1].set(1) # discard the final row - signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) # discard the first row - ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] - x_starts = mesh_values[0][ids] - x_ends = mesh_values[0][ids + par_len] - other_vals = tuple(v[ids] for v in mesh_values[1:]) - return x_starts, x_ends, other_vals + # change the position of meshgrid values + values = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in values) + xs = values[0] + mesh_values = jnp.meshgrid(*values) + if jnp.ndim(mesh_values[0]) > 1: + mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values) + mesh_values = tuple(m.flatten() for m in mesh_values) + # function outputs + signs = jnp.sign(vmap_f(*(mesh_values + args))) + # compute the selected values + signs = signs.reshape((xs.shape[0], -1)) + par_len = signs.shape[1] + signs1 = signs.at[-1].set(1) # discard the final row + signs2 = jnp.vstack((signs[1:], signs[:1])).at[-1].set(1) # discard the first row + ids = jnp.where((signs1 * signs2).flatten() <= 0)[0] + x_starts = mesh_values[0][ids] + x_ends = mesh_values[0][ids + par_len] + other_vals = tuple(v[ids] for v in mesh_values[1:]) + return x_starts, x_ends, other_vals def brentq_roots(f, starts, ends, *vmap_args, args=()): - in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args))) - vmap_f_opt = jax.jit(vmap(jax_brentq(f), in_axes=in_axes)) - all_args = vmap_args + args - if len(all_args): - res = vmap_f_opt(starts, ends, all_args) - else: - res = vmap_f_opt(starts, ends, ) - valid_idx = 
jnp.where(res['status'] == ECONVERGED)[0] - roots = res['root'][valid_idx] - vmap_args = tuple(a[valid_idx] for a in vmap_args) - return roots, vmap_args + in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args))) + vmap_f_opt = jax.jit(vmap(jax_brentq(f), in_axes=in_axes)) + all_args = vmap_args + args + if len(all_args): + res = vmap_f_opt(starts, ends, all_args) + else: + res = vmap_f_opt(starts, ends, ) + valid_idx = jnp.where(res['status'] == ECONVERGED)[0] + roots = res['root'][valid_idx] + vmap_args = tuple(a[valid_idx] for a in vmap_args) + return roots, vmap_args def brentq_roots2(vmap_f, starts, ends, *vmap_args, args=()): - all_args = vmap_args + args - res = vmap_f(starts, ends, all_args) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - roots = res['root'][valid_idx] - vmap_args = tuple(a[valid_idx] for a in vmap_args) - return roots, vmap_args + all_args = vmap_args + args + res = vmap_f(starts, ends, all_args) + valid_idx = jnp.where(res['status'] == ECONVERGED)[0] + roots = res['root'][valid_idx] + vmap_args = tuple(a[valid_idx] for a in vmap_args) + return roots, vmap_args def scipy_minimize_with_jax(fun, x0, @@ -226,366 +226,366 @@ def scipy_minimize_with_jax(fun, x0, tol=None, callback=None, options=None): - """ - A simple wrapper for scipy.optimize.minimize using JAX. - - Parameters:: - - fun: function - The objective function to be minimized, written in JAX code - so that it is automatically differentiable. It is of type, - ```fun: x, *args -> float``` where `x` is a PyTree and args - is a tuple of the fixed parameters needed to completely specify the function. - - x0: jnp.ndarray - Initial guess represented as a JAX PyTree. - - args: tuple, optional. - Extra arguments passed to the objective function - and its derivative. Must consist of valid JAX types; e.g. the leaves - of the PyTree must be floats. - - method : str or callable, optional - Type of solver. 
Should be one of - - 'Nelder-Mead' :ref:`(see here) ` - - 'Powell' :ref:`(see here) ` - - 'CG' :ref:`(see here) ` - - 'BFGS' :ref:`(see here) ` - - 'Newton-CG' :ref:`(see here) ` - - 'L-BFGS-B' :ref:`(see here) ` - - 'TNC' :ref:`(see here) ` - - 'COBYLA' :ref:`(see here) ` - - 'SLSQP' :ref:`(see here) ` - - 'trust-constr':ref:`(see here) ` - - 'dogleg' :ref:`(see here) ` - - 'trust-ncg' :ref:`(see here) ` - - 'trust-exact' :ref:`(see here) ` - - 'trust-krylov' :ref:`(see here) ` - - custom - a callable object (added in version 0.14.0), - see below for description. - If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``, - depending on if the problem has constraints or bounds. - - bounds : sequence or `Bounds`, optional - Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and - trust-constr methods. There are two ways to specify the bounds: - 1. Instance of `Bounds` class. - 2. Sequence of ``(min, max)`` pairs for each element in `x`. None - is used to specify no bound. - Note that in order to use `bounds` you will need to manually flatten - them in the same order as your inputs `x0`. - - constraints : {Constraint, dict} or List of {Constraint, dict}, optional - Constraints definition (only for COBYLA, SLSQP and trust-constr). - Constraints for 'trust-constr' are defined as a single object or a - list of objects specifying constraints to the optimization problem. - Available constraints are: - - `LinearConstraint` - - `NonlinearConstraint` - Constraints for COBYLA, SLSQP are defined as a list of dictionaries. - Each dictionary with fields: - type : str - Constraint type: 'eq' for equality, 'ineq' for inequality. - fun : callable - The function defining the constraint. - jac : callable, optional - The Jacobian of `fun` (only for SLSQP). - args : sequence, optional - Extra arguments to be passed to the function and Jacobian. 
- Equality constraint means that the constraint function result is to - be zero whereas inequality means that it is to be non-negative. - Note that COBYLA only supports inequality constraints. - - Note that in order to use `constraints` you will need to manually flatten - them in the same order as your inputs `x0`. - - tol : float, optional - Tolerance for termination. For detailed control, use solver-specific - options. - - options : dict, optional - A dictionary of solver options. All methods accept the following - generic options: - maxiter : int - Maximum number of iterations to perform. Depending on the - method each iteration may use several function evaluations. - disp : bool - Set to True to print convergence messages. - For method-specific options, see :func:`show_options()`. - - callback : callable, optional - Called after each iteration. For 'trust-constr' it is a callable with - the signature: - ``callback(xk, OptimizeResult state) -> bool`` - where ``xk`` is the current parameter vector represented as a PyTree, - and ``state`` is an `OptimizeResult` object, with the same fields - as the ones from the return. If callback returns True the algorithm - execution is terminated. - - For all the other methods, the signature is: - ```callback(xk)``` - where `xk` is the current parameter vector, represented as a PyTree. - - Returns:: - - res : The optimization result represented as a ``OptimizeResult`` object. - Important attributes are: - ``x``: the solution array, represented as a JAX PyTree - ``success``: a Boolean flag indicating if the optimizer exited successfully - ``message``: describes the cause of the termination. - See `scipy.optimize.OptimizeResult` for a description of other attributes. 
- - """ - if soptimize is None: - raise errors.PackageMissingError(f'"scipy" must be installed when user want to use ' - f'function: {scipy_minimize_with_jax}') - - # Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays - x0_flat, unravel = ravel_pytree(x0) - - # Wrap the objective function to consume flat _original_ - # numpy arrays and produce scalar outputs. - def fun_wrapper(x_flat, *args): - x = unravel(x_flat) - r = fun(x, *args) - r = r.value if isinstance(r, bm.BaseArray) else r - return float(r) - - # Wrap the gradient in a similar manner - jac = jit(grad(fun)) - - def jac_wrapper(x_flat, *args): - x = unravel(x_flat) - g_flat, _ = ravel_pytree(jac(x, *args)) - return np.array(g_flat) - - # Wrap the callback to consume a pytree - def callback_wrapper(x_flat, *args): - if callback is not None: - x = unravel(x_flat) - return callback(x, *args) - - # Minimize with scipy - results = soptimize.minimize(fun_wrapper, - x0_flat, - args=args, - method=method, - jac=jac_wrapper, - callback=callback_wrapper, - bounds=bounds, - constraints=constraints, - tol=tol, - options=options) - - # pack the output back into a PyTree - results["x"] = unravel(results["x"]) - return results + """ + A simple wrapper for scipy.optimize.minimize using JAX. + + Parameters:: + + fun: function + The objective function to be minimized, written in JAX code + so that it is automatically differentiable. It is of type, + ```fun: x, *args -> float``` where `x` is a PyTree and args + is a tuple of the fixed parameters needed to completely specify the function. + + x0: jnp.ndarray + Initial guess represented as a JAX PyTree. + + args: tuple, optional. + Extra arguments passed to the objective function + and its derivative. Must consist of valid JAX types; e.g. the leaves + of the PyTree must be floats. + + method : str or callable, optional + Type of solver. 
Should be one of + - 'Nelder-Mead' :ref:`(see here) ` + - 'Powell' :ref:`(see here) ` + - 'CG' :ref:`(see here) ` + - 'BFGS' :ref:`(see here) ` + - 'Newton-CG' :ref:`(see here) ` + - 'L-BFGS-B' :ref:`(see here) ` + - 'TNC' :ref:`(see here) ` + - 'COBYLA' :ref:`(see here) ` + - 'SLSQP' :ref:`(see here) ` + - 'trust-constr':ref:`(see here) ` + - 'dogleg' :ref:`(see here) ` + - 'trust-ncg' :ref:`(see here) ` + - 'trust-exact' :ref:`(see here) ` + - 'trust-krylov' :ref:`(see here) ` + - custom - a callable object (added in version 0.14.0), + see below for description. + If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``, + depending on if the problem has constraints or bounds. + + bounds : sequence or `Bounds`, optional + Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and + trust-constr methods. There are two ways to specify the bounds: + 1. Instance of `Bounds` class. + 2. Sequence of ``(min, max)`` pairs for each element in `x`. None + is used to specify no bound. + Note that in order to use `bounds` you will need to manually flatten + them in the same order as your inputs `x0`. + + constraints : {Constraint, dict} or List of {Constraint, dict}, optional + Constraints definition (only for COBYLA, SLSQP and trust-constr). + Constraints for 'trust-constr' are defined as a single object or a + list of objects specifying constraints to the optimization problem. + Available constraints are: + - `LinearConstraint` + - `NonlinearConstraint` + Constraints for COBYLA, SLSQP are defined as a list of dictionaries. + Each dictionary with fields: + type : str + Constraint type: 'eq' for equality, 'ineq' for inequality. + fun : callable + The function defining the constraint. + jac : callable, optional + The Jacobian of `fun` (only for SLSQP). + args : sequence, optional + Extra arguments to be passed to the function and Jacobian. 
+ Equality constraint means that the constraint function result is to + be zero whereas inequality means that it is to be non-negative. + Note that COBYLA only supports inequality constraints. + + Note that in order to use `constraints` you will need to manually flatten + them in the same order as your inputs `x0`. + + tol : float, optional + Tolerance for termination. For detailed control, use solver-specific + options. + + options : dict, optional + A dictionary of solver options. All methods accept the following + generic options: + maxiter : int + Maximum number of iterations to perform. Depending on the + method each iteration may use several function evaluations. + disp : bool + Set to True to print convergence messages. + For method-specific options, see :func:`show_options()`. + + callback : callable, optional + Called after each iteration. For 'trust-constr' it is a callable with + the signature: + ``callback(xk, OptimizeResult state) -> bool`` + where ``xk`` is the current parameter vector represented as a PyTree, + and ``state`` is an `OptimizeResult` object, with the same fields + as the ones from the return. If callback returns True the algorithm + execution is terminated. + + For all the other methods, the signature is: + ```callback(xk)``` + where `xk` is the current parameter vector, represented as a PyTree. + + Returns:: + + res : The optimization result represented as a ``OptimizeResult`` object. + Important attributes are: + ``x``: the solution array, represented as a JAX PyTree + ``success``: a Boolean flag indicating if the optimizer exited successfully + ``message``: describes the cause of the termination. + See `scipy.optimize.OptimizeResult` for a description of other attributes. 
+ + """ + if soptimize is None: + raise errors.PackageMissingError(f'"scipy" must be installed when user want to use ' + f'function: {scipy_minimize_with_jax}') + + # Use tree flatten and unflatten to convert params x0 from PyTrees to flat arrays + x0_flat, unravel = ravel_pytree(x0) + + # Wrap the objective function to consume flat _original_ + # numpy arrays and produce scalar outputs. + def fun_wrapper(x_flat, *args): + x = unravel(x_flat) + r = fun(x, *args) + r = r.value if isinstance(r, bm.BaseArray) else r + return float(r) + + # Wrap the gradient in a similar manner + jac = jit(grad(fun)) + + def jac_wrapper(x_flat, *args): + x = unravel(x_flat) + g_flat, _ = ravel_pytree(jac(x, *args)) + return np.array(g_flat) + + # Wrap the callback to consume a pytree + def callback_wrapper(x_flat, *args): + if callback is not None: + x = unravel(x_flat) + return callback(x, *args) + + # Minimize with scipy + results = soptimize.minimize(fun_wrapper, + x0_flat, + args=args, + method=method, + jac=jac_wrapper, + callback=callback_wrapper, + bounds=bounds, + constraints=constraints, + tol=tol, + options=options) + + # pack the output back into a PyTree + results["x"] = unravel(results["x"]) + return results def roots_of_1d_by_x(f, candidates, args=()): - """Find the roots of the given function by numerical methods. 
- """ - f = f_without_jaxarray_return(f) - candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates - args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args) - vals = f(candidates, *args) - signs = jnp.sign(vals) - zero_sign_idx = jnp.where(signs == 0)[0] - fps = candidates[zero_sign_idx] - candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0] - if len(candidate_ids) <= 0: - return fps - starts = candidates[candidate_ids] - ends = candidates[candidate_ids + 1] - f_opt = jax.jit(vmap(jax_brentq(f), in_axes=(0, 0, None))) - res = f_opt(starts, ends, args) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - fps2 = res['root'][valid_idx] - return jnp.concatenate([fps, fps2]) + """Find the roots of the given function by numerical methods. + """ + f = f_without_jaxarray_return(f) + candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates + args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args) + vals = f(candidates, *args) + signs = jnp.sign(vals) + zero_sign_idx = jnp.where(signs == 0)[0] + fps = candidates[zero_sign_idx] + candidate_ids = jnp.where(signs[:-1] * signs[1:] < 0)[0] + if len(candidate_ids) <= 0: + return fps + starts = candidates[candidate_ids] + ends = candidates[candidate_ids + 1] + f_opt = jax.jit(vmap(jax_brentq(f), in_axes=(0, 0, None))) + res = f_opt(starts, ends, args) + valid_idx = jnp.where(res['status'] == ECONVERGED)[0] + fps2 = res['root'][valid_idx] + return jnp.concatenate([fps, fps2]) def roots_of_1d_by_xy(f, starts, ends, args): - f = f_without_jaxarray_return(f) - f_opt = jax.jit(vmap(jax_brentq(f))) - res = f_opt(starts, ends, (args,)) - valid_idx = jnp.where(res['status'] == ECONVERGED)[0] - xs = res['root'][valid_idx] - ys = args[valid_idx] - return xs, ys + f = f_without_jaxarray_return(f) + f_opt = jax.jit(vmap(jax_brentq(f))) + res = f_opt(starts, ends, (args,)) + valid_idx = jnp.where(res['status'] == 
ECONVERGED)[0] + xs = res['root'][valid_idx] + ys = args[valid_idx] + return xs, ys # @tools.numba_jit def numpy_brentq(f, a, b, args=(), xtol=2e-14, maxiter=200, rtol=4 * np.finfo(float).eps): - """ - Find a root of a function in a bracketing interval using Brent's method - adapted from Scipy's brentq. - - Uses the classic Brent's method to find a zero of the function `f` on - the sign changing interval [a , b]. - - Parameters:: - - f : callable - Python function returning a number. `f` must be continuous. - a : number - One end of the bracketing interval [a,b]. - b : number - The other end of the bracketing interval [a,b]. - args : tuple, optional(default=()) - Extra arguments to be used in the function call. - xtol : number, optional(default=2e-12) - The computed root ``x0`` will satisfy ``np.allclose(x, x0, - atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The - parameter must be nonnegative. - rtol : number, optional(default=4*np.finfo(float).eps) - The computed root ``x0`` will satisfy ``np.allclose(x, x0, - atol=xtol, rtol=rtol)``, where ``x`` is the exact root. - maxiter : number, optional(default=100) - Maximum number of iterations. 
- """ - if xtol <= 0: - raise ValueError("xtol is too small (<= 0)") - if maxiter < 1: - raise ValueError("maxiter must be greater than 0") - - # Convert to float - xpre = a * 1.0 - xcur = b * 1.0 - - # Conditional checks for intervals in methods involving bisection - fpre = f(xpre, *args) - fcur = f(xcur, *args) - funcalls = 2 - - if fpre * fcur > 0: - raise ValueError("f(a) and f(b) must have different signs") - root = 0.0 - status = ECONVERR - - # Root found at either end of [a,b] - if fpre == 0: - root = xpre - status = ECONVERGED - if fcur == 0: - root = xcur - status = ECONVERGED - - root, status = root, status - - # Check for sign error and early termination - if status == ECONVERGED: - itr = 0 - else: - # Perform Brent's method - for itr in range(maxiter): - if fpre * fcur < 0: - xblk = xpre - fblk = fpre - spre = scur = xcur - xpre - if abs(fblk) < abs(fcur): - xpre = xcur - xcur = xblk - xblk = xpre - - fpre = fcur - fcur = fblk - fblk = fpre - - delta = (xtol + rtol * abs(xcur)) / 2 - sbis = (xblk - xcur) / 2 - - # Root found - if fcur == 0 or abs(sbis) < delta: + """ + Find a root of a function in a bracketing interval using Brent's method + adapted from Scipy's brentq. + + Uses the classic Brent's method to find a zero of the function `f` on + the sign changing interval [a , b]. + + Parameters:: + + f : callable + Python function returning a number. `f` must be continuous. + a : number + One end of the bracketing interval [a,b]. + b : number + The other end of the bracketing interval [a,b]. + args : tuple, optional(default=()) + Extra arguments to be used in the function call. + xtol : number, optional(default=2e-12) + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The + parameter must be nonnegative. + rtol : number, optional(default=4*np.finfo(float).eps) + The computed root ``x0`` will satisfy ``np.allclose(x, x0, + atol=xtol, rtol=rtol)``, where ``x`` is the exact root. 
+ maxiter : number, optional(default=100) + Maximum number of iterations. + """ + if xtol <= 0: + raise ValueError("xtol is too small (<= 0)") + if maxiter < 1: + raise ValueError("maxiter must be greater than 0") + + # Convert to float + xpre = a * 1.0 + xcur = b * 1.0 + + # Conditional checks for intervals in methods involving bisection + fpre = f(xpre, *args) + fcur = f(xcur, *args) + funcalls = 2 + + if fpre * fcur > 0: + raise ValueError("f(a) and f(b) must have different signs") + root = 0.0 + status = ECONVERR + + # Root found at either end of [a,b] + if fpre == 0: + root = xpre status = ECONVERGED + if fcur == 0: root = xcur - itr += 1 - break + status = ECONVERGED - if abs(spre) > delta and abs(fcur) < abs(fpre): - if xpre == xblk: - # interpolate - stry = -fcur * (xcur - xpre) / (fcur - fpre) - else: - # extrapolate - dpre = (fpre - fcur) / (xpre - xcur) - dblk = (fblk - fcur) / (xblk - xcur) - stry = -fcur * (fblk * dblk - fpre * dpre) / \ - (dblk * dpre * (fblk - fpre)) - - if 2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta): - # good short step - spre = scur - scur = stry - else: - # bisect - spre = sbis - scur = sbis - else: - # bisect - spre = sbis - scur = sbis - - xpre = xcur - fpre = fcur - if abs(scur) > delta: - xcur += scur - else: - xcur += (delta if sbis > 0 else -delta) - fcur = f(xcur, *args) - funcalls += 1 - - if status == ECONVERR: - raise RuntimeError("Failed to converge") - - # x, funcalls, iterations = root, funcalls, itr - return root, funcalls, itr + root, status = root, status + + # Check for sign error and early termination + if status == ECONVERGED: + itr = 0 + else: + # Perform Brent's method + for itr in range(maxiter): + if fpre * fcur < 0: + xblk = xpre + fblk = fpre + spre = scur = xcur - xpre + if abs(fblk) < abs(fcur): + xpre = xcur + xcur = xblk + xblk = xpre + + fpre = fcur + fcur = fblk + fblk = fpre + + delta = (xtol + rtol * abs(xcur)) / 2 + sbis = (xblk - xcur) / 2 + + # Root found + if fcur == 0 or abs(sbis) < 
delta: + status = ECONVERGED + root = xcur + itr += 1 + break + + if abs(spre) > delta and abs(fcur) < abs(fpre): + if xpre == xblk: + # interpolate + stry = -fcur * (xcur - xpre) / (fcur - fpre) + else: + # extrapolate + dpre = (fpre - fcur) / (xpre - xcur) + dblk = (fblk - fcur) / (xblk - xcur) + stry = -fcur * (fblk * dblk - fpre * dpre) / \ + (dblk * dpre * (fblk - fpre)) + + if 2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta): + # good short step + spre = scur + scur = stry + else: + # bisect + spre = sbis + scur = sbis + else: + # bisect + spre = sbis + scur = sbis + + xpre = xcur + fpre = fcur + if abs(scur) > delta: + xcur += scur + else: + xcur += (delta if sbis > 0 else -delta) + fcur = f(xcur, *args) + funcalls += 1 + + if status == ECONVERR: + raise RuntimeError("Failed to converge") + + # x, funcalls, iterations = root, funcalls, itr + return root, funcalls, itr # @tools.numba_jit def find_root_of_1d_numpy(f, f_points, args=(), tol=1e-8): - """Find the roots of the given function by numerical methods. - - Parameters:: - - f : callable - The function. - f_points : np.ndarray, list, tuple - The value points. - - Returns:: - - roots : list - The roots. - """ - vals = f(f_points, *args) - fs_len = len(f_points) - signs = np.sign(vals) - - roots = [] - sign_l = signs[0] - point_l = f_points[0] - idx = 1 - while idx < fs_len and sign_l == 0.: - roots.append(f_points[idx - 1]) - sign_l = signs[idx] - idx += 1 - while idx < fs_len: - sign_r = signs[idx] - point_r = f_points[idx] - if sign_r == 0.: - roots.append(point_r) - if idx + 1 < fs_len: - sign_l = sign_r - point_l = point_r - else: - break - idx += 1 - else: - if not np.isnan(sign_r) and sign_l != sign_r: - root, funcalls, itr = numpy_brentq(f, point_l, point_r, args) - if abs(f(root, *args)) < tol: roots.append(root) - sign_l = sign_r - point_l = point_r - idx += 1 - - return roots + """Find the roots of the given function by numerical methods. + + Parameters:: + + f : callable + The function. 
+ f_points : np.ndarray, list, tuple + The value points. + + Returns:: + + roots : list + The roots. + """ + vals = f(f_points, *args) + fs_len = len(f_points) + signs = np.sign(vals) + + roots = [] + sign_l = signs[0] + point_l = f_points[0] + idx = 1 + while idx < fs_len and sign_l == 0.: + roots.append(f_points[idx - 1]) + sign_l = signs[idx] + idx += 1 + while idx < fs_len: + sign_r = signs[idx] + point_r = f_points[idx] + if sign_r == 0.: + roots.append(point_r) + if idx + 1 < fs_len: + sign_l = sign_r + point_l = point_r + else: + break + idx += 1 + else: + if not np.isnan(sign_r) and sign_l != sign_r: + root, funcalls, itr = numpy_brentq(f, point_l, point_r, args) + if abs(f(root, *args)) < tol: roots.append(root) + sign_l = sign_r + point_l = point_r + idx += 1 + + return roots diff --git a/brainpy/_src/analysis/utils/others.py b/brainpy/_src/analysis/utils/others.py index 9b12aa708..7466cdbfb 100644 --- a/brainpy/_src/analysis/utils/others.py +++ b/brainpy/_src/analysis/utils/others.py @@ -12,170 +12,170 @@ from .measurement import euclidean_distance_jax __all__ = [ - 'Segment', - 'check_initials', - 'check_plot_durations', - 'get_sign', - 'get_sign2', - 'keep_unique', - 'rescale', + 'Segment', + 'check_initials', + 'check_plot_durations', + 'get_sign', + 'get_sign2', + 'keep_unique', + 'rescale', ] class Segment(object): - def __init__(self, targets, num_segments): - assert isinstance(targets, (tuple, list)) - # num segments - if isinstance(num_segments, int): - num_segments = tuple([num_segments] * len(targets)) - assert isinstance(num_segments, (tuple, list)) and len(num_segments) == len(targets) - arg_lens = tuple(len(p) for p in targets) - self. arg_pre_len = tuple(int(np.ceil(l / num_segments[i])) for i, l in enumerate(arg_lens)) - arg_id_segments = tuple(np.arange(0, l, self. arg_pre_len[i]) for i, l in enumerate(arg_lens)) - self. arg_id_segments = tuple(ids.flatten() for ids in np.meshgrid(*arg_id_segments)) - if len(arg_id_segments) == 0: - self. 
arg_id_segments = ((0,),) - self.targets = targets - - def __iter__(self): - for ids in zip(*self. arg_id_segments): - yield tuple(p[ids[i]: ids[i] + self. arg_pre_len[i]] for i, p in enumerate(self.targets)) + def __init__(self, targets, num_segments): + assert isinstance(targets, (tuple, list)) + # num segments + if isinstance(num_segments, int): + num_segments = tuple([num_segments] * len(targets)) + assert isinstance(num_segments, (tuple, list)) and len(num_segments) == len(targets) + arg_lens = tuple(len(p) for p in targets) + self.arg_pre_len = tuple(int(np.ceil(l / num_segments[i])) for i, l in enumerate(arg_lens)) + arg_id_segments = tuple(np.arange(0, l, self.arg_pre_len[i]) for i, l in enumerate(arg_lens)) + self.arg_id_segments = tuple(ids.flatten() for ids in np.meshgrid(*arg_id_segments)) + if len(arg_id_segments) == 0: + self.arg_id_segments = ((0,),) + self.targets = targets + + def __iter__(self): + for ids in zip(*self.arg_id_segments): + yield tuple(p[ids[i]: ids[i] + self.arg_pre_len[i]] for i, p in enumerate(self.targets)) def check_initials(initials, target_var_names): - # check the initial values - assert isinstance(initials, dict) - for p in target_var_names: - assert p in initials - initials = {p: bm.as_jax(initials[p], dtype=bm.float_) for p in target_var_names} - len_of_init = [] - for v in initials.values(): - assert isinstance(v, (tuple, list, np.ndarray, jnp.ndarray, bm.ndarray)) - len_of_init.append(len(v)) - len_of_init = np.unique(len_of_init) - assert len(len_of_init) == 1 - return initials + # check the initial values + assert isinstance(initials, dict) + for p in target_var_names: + assert p in initials + initials = {p: bm.as_jax(initials[p], dtype=bm.float_) for p in target_var_names} + len_of_init = [] + for v in initials.values(): + assert isinstance(v, (tuple, list, np.ndarray, jnp.ndarray, bm.ndarray)) + len_of_init.append(len(v)) + len_of_init = np.unique(len_of_init) + assert len(len_of_init) == 1 + return initials def 
check_plot_durations(plot_durations, duration, initials): - if plot_durations is None: - plot_durations = [(0., duration) for _ in range(len(initials))] - if isinstance(plot_durations[0], (int, float)): - assert len(plot_durations) == 2 - plot_durations = [plot_durations for _ in range(len(initials))] - else: - assert len(plot_durations) == len(initials) - for dur in plot_durations: - assert len(dur) == 2 - return plot_durations + if plot_durations is None: + plot_durations = [(0., duration) for _ in range(len(initials))] + if isinstance(plot_durations[0], (int, float)): + assert len(plot_durations) == 2 + plot_durations = [plot_durations for _ in range(len(initials))] + else: + assert len(plot_durations) == len(initials) + for dur in plot_durations: + assert len(dur) == 2 + return plot_durations def get_sign(f, xs, ys): - f = f_without_jaxarray_return(f) - xs = xs.value if isinstance(xs, bm.BaseArray) else xs - ys = ys.value if isinstance(ys, bm.BaseArray) else ys - Y, X = jnp.meshgrid(ys, xs) - return jnp.sign(f(X, Y)) + f = f_without_jaxarray_return(f) + xs = xs.value if isinstance(xs, bm.BaseArray) else xs + ys = ys.value if isinstance(ys, bm.BaseArray) else ys + Y, X = jnp.meshgrid(ys, xs) + return jnp.sign(f(X, Y)) def get_sign2(f, *xyz, args=()): - in_axes = tuple(range(len(xyz))) + tuple([None] * len(args)) - f = jax.jit(jax.vmap(f_without_jaxarray_return(f), in_axes=in_axes)) - xyz = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in xyz) - XYZ = jnp.meshgrid(*xyz) - XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ) - shape = (len(v) for v in xyz) - return jnp.sign(f(*(XYZ + args))).reshape(shape) + in_axes = tuple(range(len(xyz))) + tuple([None] * len(args)) + f = jax.jit(jax.vmap(f_without_jaxarray_return(f), in_axes=in_axes)) + xyz = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in xyz) + XYZ = jnp.meshgrid(*xyz) + XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ) + shape = (len(v) for v in xyz) + return 
jnp.sign(f(*(XYZ + args))).reshape(shape) def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]], - tolerance: float=2.5e-2): - """Filter unique fixed points by choosing a representative within tolerance. - - Parameters:: - - candidates: np.ndarray, dict - The fixed points with the shape of (num_point, num_dim). - tolerance: float - tolerance. - - Returns:: - - fps_and_ids : tuple - A 2-tuple of (kept fixed points, ids of kept fixed points). - """ - if isinstance(candidates, dict): - element = tuple(candidates.values())[0] - num_fps = element.shape[0] - dtype = element.dtype - else: - num_fps = candidates.shape[0] - dtype = candidates.dtype - keep_ids = np.arange(num_fps) - if tolerance <= 0.0: - return candidates, keep_ids - if num_fps <= 1: - return candidates, keep_ids - candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)) - - # If point A and point B are within identical_tol of each other, and the - # A is first in the list, we keep A. - distances = np.asarray(euclidean_distance_jax(candidates, num_fps)) - example_idxs = np.arange(num_fps) - all_drop_idxs = [] - for fidx in range(num_fps - 1): - distances_f = distances[fidx, fidx + 1:] - drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance] - all_drop_idxs += list(drop_idxs) - keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs)) - if keep_ids.shape[0] > 0: - unique_fps = tree_map(lambda a: a[keep_ids], candidates) - else: - unique_fps = np.array([], dtype=dtype) - return unique_fps, keep_ids + tolerance: float = 2.5e-2): + """Filter unique fixed points by choosing a representative within tolerance. + + Parameters:: + + candidates: np.ndarray, dict + The fixed points with the shape of (num_point, num_dim). + tolerance: float + tolerance. + + Returns:: + + fps_and_ids : tuple + A 2-tuple of (kept fixed points, ids of kept fixed points). 
+ """ + if isinstance(candidates, dict): + element = tuple(candidates.values())[0] + num_fps = element.shape[0] + dtype = element.dtype + else: + num_fps = candidates.shape[0] + dtype = candidates.dtype + keep_ids = np.arange(num_fps) + if tolerance <= 0.0: + return candidates, keep_ids + if num_fps <= 1: + return candidates, keep_ids + candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)) + + # If point A and point B are within identical_tol of each other, and the + # A is first in the list, we keep A. + distances = np.asarray(euclidean_distance_jax(candidates, num_fps)) + example_idxs = np.arange(num_fps) + all_drop_idxs = [] + for fidx in range(num_fps - 1): + distances_f = distances[fidx, fidx + 1:] + drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance] + all_drop_idxs += list(drop_idxs) + keep_ids = np.setdiff1d(example_idxs, np.unique(all_drop_idxs)) + if keep_ids.shape[0] > 0: + unique_fps = tree_map(lambda a: a[keep_ids], candidates) + else: + unique_fps = np.array([], dtype=dtype) + return unique_fps, keep_ids def keep_unique_jax(candidates, tolerance=2.5e-2): - """Filter unique fixed points by choosing a representative within tolerance. - - Parameters:: - - candidates: Tesnor - The fixed points with the shape of (num_point, num_dim). - - Returns:: - - fps_and_ids : tuple - A 2-tuple of (kept fixed points, ids of kept fixed points). - """ - keep_ids = np.arange(candidates.shape[0]) - if tolerance <= 0.0: - return candidates, keep_ids - if candidates.shape[0] <= 1: - return candidates, keep_ids - - # If point A and point B are within identical_tol of each other, and the - # A is first in the list, we keep A. 
- nfps = candidates.shape[0] - distances = euclidean_distance_jax(candidates) - example_idxs = np.arange(nfps) - all_drop_idxs = [] - for fidx in range(nfps - 1): - distances_f = distances[fidx, fidx + 1:] - drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance] - all_drop_idxs += list(drop_idxs) - keep_ids = np.setdiff1d(example_idxs, np.unique(np.asarray(all_drop_idxs))) - if keep_ids.shape[0] > 0: - unique_fps = candidates[keep_ids, :] - else: - unique_fps = np.array([], dtype=candidates.dtype) - return unique_fps, keep_ids + """Filter unique fixed points by choosing a representative within tolerance. + + Parameters:: + + candidates: Tesnor + The fixed points with the shape of (num_point, num_dim). + + Returns:: + + fps_and_ids : tuple + A 2-tuple of (kept fixed points, ids of kept fixed points). + """ + keep_ids = np.arange(candidates.shape[0]) + if tolerance <= 0.0: + return candidates, keep_ids + if candidates.shape[0] <= 1: + return candidates, keep_ids + + # If point A and point B are within identical_tol of each other, and the + # A is first in the list, we keep A. 
+ nfps = candidates.shape[0] + distances = euclidean_distance_jax(candidates) + example_idxs = np.arange(nfps) + all_drop_idxs = [] + for fidx in range(nfps - 1): + distances_f = distances[fidx, fidx + 1:] + drop_idxs = example_idxs[fidx + 1:][distances_f <= tolerance] + all_drop_idxs += list(drop_idxs) + keep_ids = np.setdiff1d(example_idxs, np.unique(np.asarray(all_drop_idxs))) + if keep_ids.shape[0] > 0: + unique_fps = candidates[keep_ids, :] + else: + unique_fps = np.array([], dtype=candidates.dtype) + return unique_fps, keep_ids def rescale(min_max, scale=0.01): - """Rescale lim.""" - min_, max_ = min_max - length = max_ - min_ - min_ -= scale * length - max_ += scale * length - return min_, max_ + """Rescale lim.""" + min_, max_ = min_max + length = max_ - min_ + min_ -= scale * length + max_ += scale * length + return min_, max_ diff --git a/brainpy/_src/analysis/utils/outputs.py b/brainpy/_src/analysis/utils/outputs.py index adc695490..7b0cf53a0 100644 --- a/brainpy/_src/analysis/utils/outputs.py +++ b/brainpy/_src/analysis/utils/outputs.py @@ -3,11 +3,11 @@ import sys __all__ = [ - 'output', + 'output', ] def output(msg, file=None): - if file is None: - file = sys.stderr - print(msg, file=file) + if file is None: + file = sys.stderr + print(msg, file=file) diff --git a/brainpy/_src/analysis/utils/visualization.py b/brainpy/_src/analysis/utils/visualization.py index f590e0f46..b563d8f95 100644 --- a/brainpy/_src/analysis/utils/visualization.py +++ b/brainpy/_src/analysis/utils/visualization.py @@ -4,32 +4,32 @@ def add_arrow(line, position=None, direction='right', size=15, color=None): - """ - add an arrow to a line. + """ + add an arrow to a line. - line: Line2D object - position: x-position of the arrow. If None, mean of xdata is taken - direction: 'left' or 'right' - size: size of the arrow in fontsize points - color: if None, line color is taken. 
- """ - if color is None: - color = line.get_color() + line: Line2D object + position: x-position of the arrow. If None, mean of xdata is taken + direction: 'left' or 'right' + size: size of the arrow in fontsize points + color: if None, line color is taken. + """ + if color is None: + color = line.get_color() - xdata = line.get_xdata() - ydata = line.get_ydata() + xdata = line.get_xdata() + ydata = line.get_ydata() - if position is None: - position = xdata.mean() - # find closest index - start_ind = np.argmin(np.absolute(xdata - position)) - if direction == 'right': - end_ind = start_ind + 1 - else: - end_ind = start_ind - 1 + if position is None: + position = xdata.mean() + # find closest index + start_ind = np.argmin(np.absolute(xdata - position)) + if direction == 'right': + end_ind = start_ind + 1 + else: + end_ind = start_ind - 1 - line.axes.annotate(text='', - xytext=(xdata[start_ind], ydata[start_ind]), - xy=(xdata[end_ind], ydata[end_ind]), - arrowprops=dict(arrowstyle="->", color=color), - size=size) + line.axes.annotate(text='', + xytext=(xdata[start_ind], ydata[start_ind]), + xy=(xdata[end_ind], ydata[end_ind]), + arrowprops=dict(arrowstyle="->", color=color), + size=size) diff --git a/brainpy/_src/checking.py b/brainpy/_src/checking.py index 41ad2c283..609b79166 100644 --- a/brainpy/_src/checking.py +++ b/brainpy/_src/checking.py @@ -5,48 +5,48 @@ from brainpy._src.deprecations import deprecation_getattr2 __deprecations = { - 'check_shape_consistency': ('brainpy.checking.check_shape_consistency', - 'brainpy.check.is_shape_consistency', - check.is_shape_consistency), - 'check_shape_broadcastable': ('brainpy.checking.check_shape_broadcastable', - 'brainpy.check.is_shape_broadcastable', - check.is_shape_broadcastable), - 'check_shape_except_batch': ('brainpy.checking.check_shape_except_batch', - 'brainpy.check.check_shape_except_batch', - check.check_shape_except_batch), - 'check_shape': ('brainpy.checking.check_shape', - 'brainpy.check.check_shape', - 
check.check_shape), - 'check_dict_data': ('brainpy.checking.check_dict_data', - 'brainpy.check.is_dict_data', - check.is_dict_data), - 'check_callable': ('brainpy.checking.check_callable', - 'brainpy.check.is_callable', - check.is_callable), - 'check_initializer': ('brainpy.checking.check_initializer', - 'brainpy.check.is_initializer', - check.is_initializer), - 'check_connector': ('brainpy.checking.check_connector', - 'brainpy.check.is_connector', - check.is_connector), - 'check_float': ('brainpy.checking.check_float', - 'brainpy.check.is_float', - check.is_float), - 'check_integer': ('brainpy.checking.check_integer', - 'brainpy.check.is_integer', - check.is_integer), - 'check_string': ('brainpy.checking.check_string', - 'brainpy.check.is_string', - check.is_string), - 'check_sequence': ('brainpy.checking.check_sequence', - 'brainpy.check.is_sequence', - check.is_sequence), - 'check_mode': ('brainpy.checking.check_mode', - 'brainpy.check.is_subclass', - check.is_subclass), - 'serialize_kwargs': ('brainpy.checking.serialize_kwargs', - 'brainpy.check.serialize_kwargs', - check.serialize_kwargs), + 'check_shape_consistency': ('brainpy.checking.check_shape_consistency', + 'brainpy.check.is_shape_consistency', + check.is_shape_consistency), + 'check_shape_broadcastable': ('brainpy.checking.check_shape_broadcastable', + 'brainpy.check.is_shape_broadcastable', + check.is_shape_broadcastable), + 'check_shape_except_batch': ('brainpy.checking.check_shape_except_batch', + 'brainpy.check.check_shape_except_batch', + check.check_shape_except_batch), + 'check_shape': ('brainpy.checking.check_shape', + 'brainpy.check.check_shape', + check.check_shape), + 'check_dict_data': ('brainpy.checking.check_dict_data', + 'brainpy.check.is_dict_data', + check.is_dict_data), + 'check_callable': ('brainpy.checking.check_callable', + 'brainpy.check.is_callable', + check.is_callable), + 'check_initializer': ('brainpy.checking.check_initializer', + 'brainpy.check.is_initializer', + 
check.is_initializer), + 'check_connector': ('brainpy.checking.check_connector', + 'brainpy.check.is_connector', + check.is_connector), + 'check_float': ('brainpy.checking.check_float', + 'brainpy.check.is_float', + check.is_float), + 'check_integer': ('brainpy.checking.check_integer', + 'brainpy.check.is_integer', + check.is_integer), + 'check_string': ('brainpy.checking.check_string', + 'brainpy.check.is_string', + check.is_string), + 'check_sequence': ('brainpy.checking.check_sequence', + 'brainpy.check.is_sequence', + check.is_sequence), + 'check_mode': ('brainpy.checking.check_mode', + 'brainpy.check.is_subclass', + check.is_subclass), + 'serialize_kwargs': ('brainpy.checking.serialize_kwargs', + 'brainpy.check.serialize_kwargs', + check.serialize_kwargs), } __getattr__ = deprecation_getattr2('brainpy.checking', __deprecations) del deprecation_getattr2 diff --git a/brainpy/_src/checkpoints/serialization.py b/brainpy/_src/checkpoints/serialization.py index 6f3ccc3bc..c259afb6a 100644 --- a/brainpy/_src/checkpoints/serialization.py +++ b/brainpy/_src/checkpoints/serialization.py @@ -1,1224 +1,27 @@ -"""Checkpointing helper functions. - -This module is rewritten from the Flax APIs (https://github.com/google/flax). 
-""" - -import enum -import functools -import logging -import os -import pathlib -import re -import shutil -import sys -import threading -import time -import warnings -from concurrent.futures import thread -from contextlib import contextmanager -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Any, Optional +import braintools import jax -import numpy as np -from jax import monitoring -from jax import process_index -from jax.experimental.multihost_utils import sync_global_devices -try: - from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa -except: - get_tensorstore_spec = GlobalAsyncCheckpointManager = None - -try: - import msgpack -except ModuleNotFoundError: - msgpack = None +from braintools.file.msg_checkpoint import register_serialization_state, AsyncManager from brainpy._src.math.ndarray import BaseArray -from brainpy.errors import (AlreadyExistsError, - MPACheckpointingRequiredError, - MPARestoreTargetRequiredError, - MPARestoreDataCorruptedError, - InvalidCheckpointPath, - InvalidCheckpointError) from brainpy.types import PyTree __all__ = [ - # saving - 'save', 'multiprocess_save', 'save_pytree', 'load_pytree', - # loading - 'load', - # async - 'AsyncManager', + 'save_pytree', 'load_pytree', 'AsyncManager', ] -_LAST_CHECKPOINT_WRITE_TIME = time.time() -_READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec' -_WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec' - -# Single-group reg-exps for int or float numerical substrings. -# captures sign: -SIGNED_FLOAT_RE = re.compile(r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') -# does not capture sign: -UNSIGNED_FLOAT_RE = re.compile(r'[-+]?((?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') -# Module name followed by number. -MODULE_NUM_RE = re.compile(r'(.*)_\d+$') -# Alternative schemes handled by `gfile`, e.g. on Google Cloud Storage (GCS). 
-SCHEME_RE = re.compile('^(?P[a-z][a-z0-9.+-]+://)?(?P.*)', re.I) - -# Multiprocess arrays (GlobalDeviceArray, or JAX array with multiprocess -# sharding) is across processes and will be stored in directories with this -# postfix, seperated from the non-distributed data (e.g. the larger pytree) -MP_ARRAY_POSTFIX = '_gda' -# Occurrences of multiprocess arrays in the target pytree will be -# replaced by this string placeholder. -MP_ARRAY_PH = '//GDAPlaceholder:' - -# Add a copy-success file to a distributed array directory to indicate the -# array save is complete. -# We need this for GCS because GCS's directory move is not atomic. -COMMIT_SUCCESS_FILE = 'commit_success.txt' - -# Orbax main checkpoint file name. -ORBAX_CKPT_FILENAME = 'checkpoint' - -# Chunking array leaves - -# msgpack has a hard limit of 2**31 - 1 bytes per object leaf. To circumvent -# this limit for giant arrays (e.g. embedding tables), we traverse the tree -# and break up arrays near the limit into flattened array chunks. -# True limit is 2**31 - 1, but leave a margin for encoding padding. -MAX_CHUNK_SIZE = 2 ** 30 - -# containing jax.Array attribute. -MultiprocessArrayType = Any - -_STATE_DICT_REGISTRY: Dict[Any, Any] = {} - - -class _ErrorContext(threading.local): - """Context for deserialization error messages.""" - - def __init__(self): - self.path = [] - - -_error_context = _ErrorContext() - - -@contextmanager -def _record_path(name): - try: - _error_context.path.append(name) - yield - finally: - _error_context.path.pop() - - -def check_msgpack(): - if msgpack is None: - raise ModuleNotFoundError('\nbrainpy.checkpoints needs "msgpack" package. 
Please install msgpack via:\n' - '> pip install msgpack') - - -def current_path(): - """Current state_dict path during deserialization for error messages.""" - return '/'.join(_error_context.path) - - -class _NamedTuple: - """Fake type marker for namedtuple for registry.""" - pass - - -def _is_namedtuple(x): - """Duck typing test for namedtuple factory-generated objects.""" - return isinstance(x, tuple) and hasattr(x, '_fields') - - -def from_state_dict(target, state: Dict[str, Any], name: str = '.'): - """Restores the state of the given target using a state dict. - - This function takes the current target as an argument. This - lets us know the exact structure of the target, - as well as lets us add assertions that shapes and dtypes don't change. - - In practice, none of the leaf values in `target` are actually - used. Only the tree structure, shapes and dtypes. - - Args: - target: the object of which the state should be restored. - state: a dictionary generated by `to_state_dict` with the desired new - state for `target`. - name: name of branch taken, used to improve deserialization error messages. - Returns: - A copy of the object with the restored state. - """ - ty = _NamedTuple if _is_namedtuple(target) else type(target) - for t in _STATE_DICT_REGISTRY.keys(): - if issubclass(ty, t): - ty = t - break - else: - return state - ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1] - with _record_path(name): - return ty_from_state_dict(target, state) - - -def to_state_dict(target) -> Dict[str, Any]: - """Returns a dictionary with the state of the given target.""" - ty = _NamedTuple if _is_namedtuple(target) else type(target) - - for t in _STATE_DICT_REGISTRY.keys(): - if issubclass(ty, t): - ty = t - break - else: - return target - - ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0] - state_dict = ty_to_state_dict(target) - if isinstance(state_dict, dict): - for key in state_dict.keys(): - assert isinstance(key, str), 'A state dict must only have string keys.' 
- return state_dict - elif isinstance(state_dict, jax.Array): - return state_dict - else: - raise TypeError - - -def register_serialization_state(ty, - ty_to_state_dict, - ty_from_state_dict, - override=False): - """Register a type for serialization. - - Args: - ty: the type to be registered - ty_to_state_dict: a function that takes an instance of ty and - returns its state as a dictionary. - ty_from_state_dict: a function that takes an instance of ty and - a state dict, and returns a copy of the instance with the restored state. - override: override a previously registered serialization handler - (default: False). - """ - if ty in _STATE_DICT_REGISTRY and not override: - raise ValueError(f'a serialization handler for "{ty.__name__}"' - ' is already registered') - _STATE_DICT_REGISTRY[ty] = (ty_to_state_dict, ty_from_state_dict) - - -def _list_state_dict(xs: List[Any]) -> Dict[str, Any]: - return {str(i): to_state_dict(x) for i, x in enumerate(xs)} - - -def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]: - if len(state_dict) != len(xs): - raise ValueError('The size of the list and the state dict do not match,' - f' got {len(xs)} and {len(state_dict)} ' - f'at path {current_path()}') - ys = [] - for i in range(len(state_dict)): - y = from_state_dict(xs[i], state_dict[str(i)], name=str(i)) - ys.append(y) - return ys - def _array_dict_state(x: BaseArray) -> Dict[str, jax.Array]: - return x.value + return x.value def _restore_array(x, state_dict: jax.Array) -> BaseArray: - x.value = state_dict - return x - - -def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]: - str_keys = set(str(k) for k in xs.keys()) - if len(str_keys) != len(xs): - raise ValueError('Dict keys do not have a unique string representation: ' - f'{str_keys} vs given: {xs}') - return {str(key): to_state_dict(value) for key, value in xs.items()} - - -def _restore_dict(xs, states: Dict[str, Any]) -> Dict[str, Any]: - diff = set(map(str, xs.keys())).difference(states.keys()) - if diff: - 
raise ValueError('The target dict keys and state dict keys do not match,' - f' target dict contains keys {diff} which are not present in state dict ' - f'at path {current_path()}') - - return {key: from_state_dict(value, states[str(key)], name=str(key)) - for key, value in xs.items()} - - -def _namedtuple_state_dict(nt) -> Dict[str, Any]: - return {key: to_state_dict(getattr(nt, key)) for key in nt._fields} - - -def _restore_namedtuple(xs, state_dict: Dict[str, Any]): - """Rebuild namedtuple from serialized dict.""" - if set(state_dict.keys()) == {'name', 'fields', 'values'}: - state_dict = {state_dict['fields'][str(i)]: state_dict['values'][str(i)] - for i in range(len(state_dict['fields']))} - - sd_keys = set(state_dict.keys()) - nt_keys = set(xs._fields) - - if sd_keys != nt_keys: - raise ValueError('The field names of the state dict and the named tuple do not match,' - f' got {sd_keys} and {nt_keys} at path {current_path()}') - fields = { - k: from_state_dict(getattr(xs, k), v, name=k) - for k, v in state_dict.items() - } - return type(xs)(**fields) + x.value = state_dict + return x register_serialization_state(BaseArray, _array_dict_state, _restore_array) -register_serialization_state(dict, _dict_state_dict, _restore_dict) -# register_serialization_state(DotDict, _dict_state_dict, _restore_dict) -# register_serialization_state(Collector, _dict_state_dict, _restore_dict) -# register_serialization_state(ArrayCollector, _dict_state_dict, _restore_dict) -register_serialization_state(list, _list_state_dict, _restore_list) -register_serialization_state(tuple, - _list_state_dict, - lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict))) -register_serialization_state(_NamedTuple, - _namedtuple_state_dict, - _restore_namedtuple) -register_serialization_state( - jax.tree_util.Partial, - lambda x: {"args": to_state_dict(x.args), - "keywords": to_state_dict(x.keywords), }, - lambda x, sd: jax.tree_util.Partial(x.func, - *from_state_dict(x.args, sd["args"]), - 
**from_state_dict(x.keywords, sd["keywords"])) -) - - -# On-the-wire / disk serialization format - -# We encode state-dicts via msgpack, using its custom type extension. -# https://github.com/msgpack/msgpack/blob/master/spec.md -# -# - ndarrays and DeviceArrays are serialized to nested msgpack-encoded string -# of (shape-tuple, dtype-name (e.g. 'float32'), row-major array-bytes). -# Note: only simple ndarray types are supported, no objects or fields. -# -# - native complex scalars are converted to nested msgpack-encoded tuples -# (real, imag). - - -def _ndarray_to_bytes(arr) -> bytes: - """Save ndarray to simple msgpack encoding.""" - if isinstance(arr, jax.Array): - arr = np.array(arr) - if arr.dtype.hasobject or arr.dtype.isalignedstruct: - raise ValueError('Object and structured dtypes not supported ' - 'for serialization of ndarrays.') - tpl = (arr.shape, arr.dtype.name, arr.tobytes('C')) - return msgpack.packb(tpl, use_bin_type=True) - - -def _dtype_from_name(name: str): - """Handle JAX bfloat16 dtype correctly.""" - if name == b'bfloat16': - return jax.numpy.bfloat16 - else: - return np.dtype(name) - - -def _ndarray_from_bytes(data: bytes) -> np.ndarray: - """Load ndarray from simple msgpack encoding.""" - shape, dtype_name, buffer = msgpack.unpackb(data, raw=True) - return np.frombuffer(buffer, - dtype=_dtype_from_name(dtype_name), - count=-1, - offset=0).reshape(shape, order='C') - - -class _MsgpackExtType(enum.IntEnum): - """Messagepack custom type ids.""" - ndarray = 1 - native_complex = 2 - npscalar = 3 - - -def _msgpack_ext_pack(x): - """Messagepack encoders for custom types.""" - # TODO: Array here only work when they are fully addressable. - # If they are not fully addressable, use the GDA path for checkpointing. 
- if isinstance(x, (np.ndarray, jax.Array)): - return msgpack.ExtType(_MsgpackExtType.ndarray, _ndarray_to_bytes(x)) - if np.issctype(type(x)): - # pack scalar as ndarray - return msgpack.ExtType(_MsgpackExtType.npscalar, - _ndarray_to_bytes(np.asarray(x))) - elif isinstance(x, complex): - return msgpack.ExtType(_MsgpackExtType.native_complex, - msgpack.packb((x.real, x.imag))) - return x - - -def _msgpack_ext_unpack(code, data): - """Messagepack decoders for custom types.""" - if code == _MsgpackExtType.ndarray: - return _ndarray_from_bytes(data) - elif code == _MsgpackExtType.native_complex: - complex_tuple = msgpack.unpackb(data) - return complex(complex_tuple[0], complex_tuple[1]) - elif code == _MsgpackExtType.npscalar: - ar = _ndarray_from_bytes(data) - return ar[()] # unpack ndarray to scalar - return msgpack.ExtType(code, data) - - -def _np_convert_in_place(d): - """Convert any jax devicearray leaves to numpy arrays in place.""" - if isinstance(d, dict): - for k, v in d.items(): - if isinstance(v, jax.Array): - d[k] = np.array(v) - elif isinstance(v, dict): - _np_convert_in_place(v) - elif isinstance(d, jax.Array): - return np.array(d) - return d - - -_tuple_to_dict = lambda tpl: {str(x): y for x, y in enumerate(tpl)} -_dict_to_tuple = lambda dct: tuple(dct[str(i)] for i in range(len(dct))) - - -def _chunk(arr) -> Dict[str, Any]: - """Convert array to a canonical dictionary of chunked arrays.""" - chunksize = max(1, int(MAX_CHUNK_SIZE / arr.dtype.itemsize)) - data = {'__msgpack_chunked_array__': True, - 'shape': _tuple_to_dict(arr.shape)} - flatarr = arr.reshape(-1) - chunks = [flatarr[i:i + chunksize] for i in range(0, flatarr.size, chunksize)] - data['chunks'] = _tuple_to_dict(chunks) - return data - - -def _unchunk(data: Dict[str, Any]): - """Convert canonical dictionary of chunked arrays back into array.""" - assert '__msgpack_chunked_array__' in data - shape = _dict_to_tuple(data['shape']) - flatarr = np.concatenate(_dict_to_tuple(data['chunks'])) - 
return flatarr.reshape(shape) - - -def _chunk_array_leaves_in_place(d): - """Convert oversized array leaves to safe chunked form in place.""" - if isinstance(d, dict): - for k, v in d.items(): - if isinstance(v, np.ndarray): - if v.size * v.dtype.itemsize > MAX_CHUNK_SIZE: - d[k] = _chunk(v) - elif isinstance(v, dict): - _chunk_array_leaves_in_place(v) - elif isinstance(d, np.ndarray): - if d.size * d.dtype.itemsize > MAX_CHUNK_SIZE: - return _chunk(d) - return d - - -def _unchunk_array_leaves_in_place(d): - """Convert chunked array leaves back into array leaves, in place.""" - if isinstance(d, dict): - if '__msgpack_chunked_array__' in d: - return _unchunk(d) - else: - for k, v in d.items(): - if isinstance(v, dict) and '__msgpack_chunked_array__' in v: - d[k] = _unchunk(v) - elif isinstance(v, dict): - _unchunk_array_leaves_in_place(v) - return d - - -def msgpack_serialize(pytree, in_place: bool = False) -> bytes: - """Save data structure to bytes in msgpack format. - - Low-level function that only supports python trees with array leaves, - for custom objects use `to_bytes`. It splits arrays above MAX_CHUNK_SIZE into - multiple chunks. - - Args: - pytree: python tree of dict, list, tuple with python primitives - and array leaves. - in_place: boolean specifyng if pytree should be modified in place. - - Returns: - msgpack-encoded bytes of pytree. - """ - if not in_place: - pytree = jax.tree_util.tree_map(lambda x: x, pytree) - pytree = _np_convert_in_place(pytree) - pytree = _chunk_array_leaves_in_place(pytree) - return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True) - - -def msgpack_restore(encoded_pytree: bytes): - """Restore data structure from bytes in msgpack format. - - Low-level function that only supports python trees with array leaves, - for custom objects use `from_bytes`. - - Args: - encoded_pytree: msgpack-encoded bytes of python tree. - - Returns: - Python tree of dict, list, tuple with python primitive - and array leaves. 
- """ - state_dict = msgpack.unpackb( - encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False) - return _unchunk_array_leaves_in_place(state_dict) - - -def from_bytes(target, encoded_bytes: bytes): - """Restore optimizer or other object from msgpack-serialized state-dict. - - Args: - target: template object with state-dict registrations that matches - the structure being deserialized from `encoded_bytes`. - encoded_bytes: msgpack serialized object structurally isomorphic to - `target`. Typically a flax model or optimizer. - - Returns: - A new object structurally isomorphic to `target` containing the updated - leaf data from saved data. - """ - state_dict = msgpack_restore(encoded_bytes) - return from_state_dict(target, state_dict) - - -def to_bytes(target) -> bytes: - """Save optimizer or other object as msgpack-serialized state-dict. - - Args: - target: template object with state-dict registrations to be - serialized to msgpack format. Typically a flax model or optimizer. - - Returns: - Bytes of msgpack-encoded state-dict of `target` object. - """ - state_dict = to_state_dict(target) - return msgpack_serialize(state_dict, in_place=True) - - -# the empty node is a struct.dataclass to be compatible with JAX. -class _EmptyNode: - pass - - -def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None): - """Flatten a nested dictionary. - - The nested keys are flattened to a tuple. - See `unflatten_dict` on how to restore the - nested dictionary structure. - - Example:: - - xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}} - flat_xs = flatten_dict(xs) - print(flat_xs) - # { - # ('foo',): 1, - # ('bar', 'a'): 2, - # } - - Note that empty dictionaries are ignored and - will not be restored by `unflatten_dict`. - - Args: - xs: a nested dictionary - keep_empty_nodes: replaces empty dictionaries - with `traverse_util.empty_node`. 
- is_leaf: an optional function that takes the - next nested dictionary and nested keys and - returns True if the nested dictionary is a - leaf (i.e., should not be flattened further). - sep: if specified, then the keys of the returned - dictionary will be `sep`-joined strings (if - `None`, then keys will be tuples). - Returns: - The flattened dictionary. - """ - assert isinstance(xs, dict), f'expected (frozen)dict; got {type(xs)}' - - def _key(path): - if sep is None: - return path - return sep.join(path) - - def _flatten(xs, prefix): - if not isinstance(xs, dict) or (is_leaf and is_leaf(prefix, xs)): - return {_key(prefix): xs} - result = {} - is_empty = True - for key, value in xs.items(): - is_empty = False - path = prefix + (key,) - result.update(_flatten(value, path)) - if keep_empty_nodes and is_empty: - if prefix == (): # when the whole input is empty - return {} - return {_key(prefix): _EmptyNode()} - return result - - return _flatten(xs, ()) - - -def unflatten_dict(xs, sep=None): - """Unflatten a dictionary. - - See `flatten_dict` - - Example:: - - flat_xs = { - ('foo',): 1, - ('bar', 'a'): 2, - } - xs = unflatten_dict(flat_xs) - print(xs) - # { - # 'foo': 1 - # 'bar': {'a': 2} - # } - - Args: - xs: a flattened dictionary - sep: separator (same as used with `flatten_dict()`). - Returns: - The nested dictionary. 
- """ - assert isinstance(xs, dict), f'input is not a dict; it is a {type(xs)}' - result = {} - for path, value in xs.items(): - if sep is not None: - path = path.split(sep) - if isinstance(value, _EmptyNode): - value = {} - cursor = result - for key in path[:-1]: - if key not in cursor: - cursor[key] = {} - cursor = cursor[key] - cursor[path[-1]] = value - return result - - -def _rename_fn(src, dst, overwrite=False): - if os.path.exists(src): - if os.path.exists(dst) and not overwrite: - raise AlreadyExistsError(dst) - return os.rename(src, dst) - - -def _checkpoint_path(ckpt_dir: str, - step: Union[int, float, str], - prefix: str = 'checkpoint_') -> str: - return os.path.join(ckpt_dir, f'{prefix}{step}') - - -def _checkpoint_path_step(path: str) -> Optional[float]: - """Returns the step number of a checkpoint path.""" - for s in SIGNED_FLOAT_RE.split(path)[::-1]: - if SIGNED_FLOAT_RE.match(s): - return float(s) - return None - - -def _allowempty_listdir(path: str): - try: - return os.listdir(path) - except FileNotFoundError: - return [] - - -def _safe_remove(path: str): - """Identify whether a path is a dir or list and choose the correct remove method.""" - if os.path.isdir(path): - shutil.rmtree(path) - else: - os.remove(path) - - -class AsyncManager(object): - """A simple object to track async checkpointing. - - How to use: create an instance and pass to `brainpy.checkpoints.save()` calls: - am = AsyncManager() - brainpy.checkpoints.save(..., async_manager=am) - """ - - def __init__(self, max_workers: int = 1): - self.executor = thread.ThreadPoolExecutor(max_workers=max_workers) - self.save_future = None - - def wait_previous_save(self): - """Block until the previous save finishes, to keep files' consistency.""" - if self.save_future and not self.save_future.done(): - warnings.warn( - 'The previous async brainpy.checkpoints.save has not finished yet. 
Waiting ' - 'for it to complete before the next save.', - UserWarning - ) - self.save_future.result() - - def save_async(self, task: Callable[[], Any]): - """Run a task async. The future will be tracked as self.save_future. - - Args: - task: The callable to be executed asynchrously. - """ - self.wait_previous_save() - self.save_future = self.executor.submit(task) # type: ignore - - -def _use_multiprocess_serialization(value: Any) -> bool: - """Use GlobalAsyncCheckpointManager to save the array if it's only partially available on this host.""" - if isinstance(value, jax.Array): - return not value.is_fully_addressable - return False - - -def _split_mp_arrays( - target: Dict[str, Any] -) -> Tuple[Dict[str, Any], List[Tuple[MultiprocessArrayType, str]]]: - """Split out the multiprocess arrays from the target pytree to save.""" - # When target is a single leaf instead of a pytree dict. - if not isinstance(target, dict): - if _use_multiprocess_serialization(target): - return MP_ARRAY_PH, [(target, '')] - return target, [] - # Traverse the target and handle distributed arrays. - flattened = flatten_dict(target, keep_empty_nodes=True) - mpa_targets = [] - for key, value in flattened.items(): - if _use_multiprocess_serialization(value): - subpath = '/'.join(key) - mpa_targets.append((value, subpath)) - flattened[key] = MP_ARRAY_PH + subpath - target = unflatten_dict(flattened) - return target, mpa_targets - - -def _make_mpa_dirs( - mpa_targets: List[Tuple[MultiprocessArrayType, str]], - tmp_path: str -): - # Temporary array path is not used in GCS. - if tmp_path.startswith('gs://'): - return - mpa_tmp_path = tmp_path + MP_ARRAY_POSTFIX - # Clean up the previous MPA dir, in case some leftover from last preemption - # lingers. 
- if os.path.exists(mpa_tmp_path): - warnings.warn('Removing outdated MPA temporary files at %s' % mpa_tmp_path, UserWarning) - shutil.rmtree(mpa_tmp_path) - _, mpa_subpaths = zip(*mpa_targets) - for subpath in mpa_subpaths: - os.makedirs(os.path.join(mpa_tmp_path, subpath), exist_ok=True) - - -def _save_mpas(gda_manager, - mpa_targets: List[Tuple[MultiprocessArrayType, str]], - tmp_path: str, - final_path: str, - base_path: str, - keep: int, - overwrite: bool, - keep_every_n_steps: Optional[int], - ckpt_start_time: float, - async_manager: Optional[AsyncManager] = None): - """Save the multiprocess arrays given the paths.""" - mpa_list, mpa_subpaths = zip(*mpa_targets) - mpa_tmp_path, mpa_final_path = tmp_path + MP_ARRAY_POSTFIX, final_path + MP_ARRAY_POSTFIX - write_commit_success = False - # If the checkpoint directory is a GCS directory, then keep the final - # checkpoint directory as the temporary checkpoint directory. This is because - # renames are not atomic on GCS. When restoring check for the existence of a - # success file. - # TODO: figure out a way to unit-test the behavior. 
- if tmp_path.startswith('gs://'): - mpa_tmp_path = mpa_final_path - write_commit_success = True - mpa_paths = [os.path.join(mpa_tmp_path, x) for x in mpa_subpaths] - ts_specs = [get_tensorstore_spec(x) for x in mpa_paths] - gda_manager.serialize( - list(mpa_list), - ts_specs, - on_commit_callback=functools.partial( - _save_commit, - tmp_path, - final_path, - base_path, - keep, - overwrite, - keep_every_n_steps, - ckpt_start_time, - has_mpa=True, - write_commit_success=write_commit_success, - async_manager=async_manager)) - - -def _restore_mpas(state_dict, - target: Optional[Any], - ckpt_path: str, - step: Optional[Union[int, float]], - gda_manager: Optional[Any], - allow_partial: bool = False): - """Restore the multiprocess arrays given the target structure and type.""" - - def _check_mpa_errors(): - if not gda_manager: - raise MPACheckpointingRequiredError(ckpt_path, step) - if not target and not allow_partial: - raise MPARestoreTargetRequiredError(ckpt_path, step) - - def _safe_deserialize( - target_mpas: List[Tuple[Tuple[Any, ...], MultiprocessArrayType, str]], - gda_manager: Any - ) -> List[MultiprocessArrayType]: - gda_manager.wait_until_finished() - - # Check if reading from GCS and the array dir is potentially corrupted. - if ckpt_path.startswith('gs://') and not os.path.exists( - os.path.join(ckpt_path + MP_ARRAY_POSTFIX, COMMIT_SUCCESS_FILE)): - raise MPARestoreDataCorruptedError(step, ckpt_path) - - # Check if the given target array types are valid. - shardings = [] - for _, arr, path in target_mpas: - if isinstance(arr, jax.Array): - shardings.append(arr.sharding) - - # Restore the arrays. - ts_specs = [get_tensorstore_spec(path) for _, _, path in target_mpas] - return gda_manager.deserialize(shardings, ts_specs) - - # When target is a single leaf instead of a pytree dict. 
- if not isinstance(state_dict, dict): - if (_use_multiprocess_serialization(target) and - isinstance(state_dict, str) and - state_dict.startswith(MP_ARRAY_PH)): - _check_mpa_errors() - return _safe_deserialize([((), target, ckpt_path + MP_ARRAY_POSTFIX)], gda_manager)[0] - return state_dict - - # Go through the restored checkpoint pytree for all MPAs - flattened = flatten_dict(state_dict, keep_empty_nodes=True) - if target: - target_flattened = flatten_dict(to_state_dict(target), keep_empty_nodes=True) - # A list of (state_dict_key, target_array, array_file_path) for every array - # to be restored - target_mpas = [] - for key, value in flattened.items(): - if isinstance(value, str) and value.startswith(MP_ARRAY_PH): - _check_mpa_errors() - if (not target or - (key not in target_flattened) or - (not _use_multiprocess_serialization(target_flattened[key]))): - if allow_partial: - warnings.warn(f'Multiprocess array {key} could not be restored ' - 'because a valid array is not found in target at ' - 'the corresponding location. Proceed to restore ' - 'other arrays because allow_partial_restoration=True', - UserWarning) - else: - raise MPARestoreTargetRequiredError(ckpt_path, step, key) - else: - mpa_path = os.path.join(ckpt_path + MP_ARRAY_POSTFIX, value[len(MP_ARRAY_PH):]) - target_mpas.append((key, target_flattened[key], mpa_path)) - - # If any MPA needs to be restored, call deserialize - if target_mpas: - mpa_list = _safe_deserialize(target_mpas, gda_manager) - for mpa, (key, _, _) in zip(mpa_list, target_mpas): - flattened[key] = mpa - state_dict = unflatten_dict(flattened) - return state_dict - - -def _natural_sort(file_list: Iterable[str], - signed: bool = True) -> List[str]: - """Natural sort for filenames with numerical substrings. - - Args: - file_list: list of paths to sort containing numerical substrings. - signed: bool: if leading '-' (or '+') signs should be included in numerical - substrings as a sign or treated as a separator. 
- - Returns: - List of filenames sorted 'naturally', not lexicographically: any - integer substrings are used to subsort numerically. e.g. - file_1, file_10, file_2 --> file_1, file_2, file_10 - file_0.1, file_-0.2, file_2.0 --> file_-0.2, file_0.1, file_2.0 - """ - float_re = SIGNED_FLOAT_RE if signed else UNSIGNED_FLOAT_RE - - def maybe_num(s): - if float_re.match(s): - return float(s) - else: - return s - - def split_keys(s): - return [maybe_num(c) for c in float_re.split(s)] - - return sorted(file_list, key=split_keys) - - -def _safe_normpath(path: str) -> str: - """Normalizes path safely to get around `io.glob()` limitations.""" - match = SCHEME_RE.match(path) - assert match is not None - d = match.groupdict() - return (d['scheme'] or '') + os.path.normpath(d['path']) - - -def _remove_invalid_ckpts(ckpt_path: str, - base_path: str, - keep: int, - overwrite: bool, - keep_every_n_steps: Optional[int], - has_mpa: bool) -> None: - """Clean up the checkpoint space according to `overwrite`, `keep`, and `keep_every_n_steps` parameters. - """ - dir_path, prefix = os.path.split(base_path) - checkpoint_files: List[Any] = [pathlib.PurePath(c) for c in os.listdir(dir_path)] - checkpoint_files = [ - os.path.join(dir_path, c) - for c in checkpoint_files - if c.match(f'{prefix}*') and not c.match(f'*{MP_ARRAY_POSTFIX}') - ] - checkpoint_files = _natural_sort(checkpoint_files) - - # Remove newer checkpoints - if overwrite and ckpt_path in checkpoint_files: - ind = checkpoint_files.index(ckpt_path) + 1 - newer_ckpts = checkpoint_files[ind:] - checkpoint_files = checkpoint_files[:ind] - for path in newer_ckpts: - warnings.warn(f'Removing checkpoint at {path}', UserWarning) - if has_mpa: - # MPA might be removed already but the main ckpt is still there. This - # can happen if the job is previously preempted after deleting the MPA - # checkpoint folder and before deleting the main checkpoint. 
- if os.path.exists(path + MP_ARRAY_POSTFIX): - shutil.rmtree(path + MP_ARRAY_POSTFIX) - _safe_remove(path) - - # Remove old checkpoint files. - last_kept = -float('inf') - if len(checkpoint_files) > keep: - old_ckpts = checkpoint_files[:-keep] - # Note: old_ckpts is sorted from oldest to newest. - for path in old_ckpts: - if keep_every_n_steps: - step_number = _checkpoint_path_step(path) - if step_number and (step_number - last_kept) >= keep_every_n_steps: - logging.debug('Not deleting %s, because last_kept=%f and keeping ' - 'every %d steps.', - path, last_kept, keep_every_n_steps) - last_kept = step_number - continue - logging.info('Removing checkpoint at %s', path) - if has_mpa: - # MPA might be removed already but the main ckpt is still there. - if os.path.exists(path + MP_ARRAY_POSTFIX): - shutil.rmtree(path + MP_ARRAY_POSTFIX) - _safe_remove(path) - - -def _record_saved_duration(checkpoint_start_time: float): - """Record program duration that is accounted for by this checkpoint. - - For the very first checkpoint, this is the interval between program init and - current checkpoint start time. - - Note that we use the checkpoint start time instead of end time. The saved - duration should not include prallel training duration while the async - checkpoint is being written in the background. - - Args: - checkpoint_start_time: Start time of current checkpoint. - """ - global _LAST_CHECKPOINT_WRITE_TIME - # Note: for the very first checkpoint, this is the interval between program - # init and the current checkpoint start time. 
- duration_since_last_checkpoint = checkpoint_start_time - _LAST_CHECKPOINT_WRITE_TIME - if monitoring is not None: - monitoring.record_event_duration_secs( - '/jax/checkpoint/write/duration_since_last_checkpoint_secs', - duration_since_last_checkpoint) - _LAST_CHECKPOINT_WRITE_TIME = checkpoint_start_time - - -def _save_commit(ckpt_tmp_path: str, - ckpt_path: str, - base_path: str, - keep: int, - overwrite: bool, - keep_every_n_steps: Optional[int], - ckpt_start_time: float, - has_mpa: bool, - write_commit_success: bool, - async_manager: Optional[AsyncManager] = None) -> None: - """Commit changes after saving checkpoints to disk. - - This function does the following, sequentially: - 1. Make sure all ckpt writing finishes, and rename them from temp path to - the final path. - 2. Remove newer checkpoints (files that ordered larger than this save) if - `overwrite=True`. - 3. Remove old checkpoint files based on `keep` and `keep_every_n_steps`. - 4. Record program duration saved by this checkpoint. - """ - mpa_ckpt_tmp_path, mpa_ckpt_path = ckpt_tmp_path + MP_ARRAY_POSTFIX, ckpt_path + MP_ARRAY_POSTFIX - # Rename the multiprocess array path once serialization and writing finished. - if has_mpa: - if write_commit_success: - commit_success_path = os.path.join(mpa_ckpt_path, COMMIT_SUCCESS_FILE) - with open(commit_success_path, 'w', encoding='utf-8') as f: - f.write(f'Checkpoint commit was successful to {mpa_ckpt_path}') - else: - # Commits are a two stage process (renaming the array folder and renaming - # the main ckpt file in sequential order). We always try to overwrite - # here because the array ckpt might be already renamed in a previously - # interrupted commit. NOTE: io.rename does not support overwriting - # directories via `rename` so we manually overwrite it. 
- if os.path.exists(mpa_ckpt_path): - logging.info('Removing outdated checkpoint at %s', mpa_ckpt_path) - shutil.rmtree(mpa_ckpt_path) - _rename_fn(mpa_ckpt_tmp_path, mpa_ckpt_path) - # Commit the main checkpoint file after arrays (if any) are committed - if async_manager: - async_manager.wait_previous_save() - _rename_fn(ckpt_tmp_path, ckpt_path, overwrite=overwrite) - logging.info('Saved checkpoint at %s', ckpt_path) - - # Remove newer and older invalid checkpoints. - _remove_invalid_ckpts(ckpt_path, base_path, keep, overwrite, - keep_every_n_steps, has_mpa) - _record_saved_duration(ckpt_start_time) - - -def _check_overwrite_error(ckpt_tmp_path: str, - ckpt_path: str, - base_path: str, - step: int): - """Throw error if a ckpt file of this step or higher already exists.""" - dir_path, prefix = os.path.split(base_path) - checkpoint_files: List[Any] = [pathlib.PurePath(c) for c in _allowempty_listdir(dir_path)] - checkpoint_files = [ - os.path.join(dir_path, c) - for c in checkpoint_files - if c.match(f'{prefix}*') and not c.match(f'*{MP_ARRAY_POSTFIX}') - ] - if ckpt_path in checkpoint_files: - raise InvalidCheckpointError(ckpt_path, step) - checkpoint_files.append(ckpt_path) - - checkpoint_files = _natural_sort(checkpoint_files) - # Handle the case if the job was preempted after the temporary checkpoint - # was written, but before it was renamed to the final checkpoint name - if checkpoint_files[-1] == ckpt_tmp_path: - checkpoint_files.pop() - if ckpt_path != checkpoint_files[-1]: - raise InvalidCheckpointError(ckpt_path, step) - - -def _save_main_ckpt_file(target: bytes, has_mpa: bool, paths: Tuple[str, str], - base_path: str, step: int, - keep: int, overwrite: bool, - keep_every_n_steps: Optional[int], - ckpt_start_time: float): - """Save the main checkpoint file via file system.""" - ckpt_tmp_path, ckpt_path = paths - os.makedirs(os.path.dirname(ckpt_path), exist_ok=True) - - with open(ckpt_tmp_path, 'wb') as fp: - fp.write(target) - - # Postpone the commitment 
of checkpoint to after MPA writes are done. - if not has_mpa: - _save_commit( - ckpt_tmp_path, - ckpt_path, - base_path, - keep, - overwrite, - keep_every_n_steps, - ckpt_start_time, - has_mpa=False, - write_commit_success=False) - - -def _get_checkpoint_paths( - ckpt_dir: Union[str, os.PathLike], - step: Union[int, float], - prefix: str = 'checkpoint_' -) -> Tuple[str, str, str]: - """Generate the checkpoint paths used in this save operation.""" - ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str - logging.info('Saving checkpoint at step: %s', step) - # normalize path because io.glob() can modify path './', '//' ... - ckpt_dir = _safe_normpath(ckpt_dir) - ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix) - ckpt_path = _checkpoint_path(ckpt_dir, step, prefix) - base_path = os.path.join(ckpt_dir, prefix) - return ckpt_path, ckpt_tmp_path, base_path - - -def save( - ckpt_dir: Union[str, os.PathLike], - target: PyTree, - step: Union[int, float], - prefix: str = 'checkpoint_', - keep: int = 1, - overwrite: bool = False, - keep_every_n_steps: Optional[int] = None, - async_manager: Optional[AsyncManager] = None, -) -> str: - """Save a checkpoint of the model. Suitable for single-host. - - In this method, every JAX process saves the checkpoint on its own. Do not - use it if you have multiple processes and you intend for them to save data - to a common directory (e.g., a GCloud bucket). To save multi-process - checkpoints to a shared storage or to save `GlobalDeviceArray`s, use - `multiprocess_save()` instead. - - Pre-emption safe by writing to temporary before a final rename and cleanup - of past files. However, if async_manager is used, the final - commit will happen inside an async callback, which can be explicitly waited - by calling `async_manager.wait_previous_save()`. - - Parameters:: - - ckpt_dir: str, PathLike - str or pathlib-like path to store checkpoint files in. - target: Any - serializable flax object, usually a flax optimizer. 
- step: int, float - training step number or other metric number. - prefix: str - checkpoint file name prefix. - keep: int - number of past checkpoint files to keep. - overwrite: bool - overwrite existing checkpoint files if a checkpoint at the - current or a later step already exits (default: False). - keep_every_n_steps: int - if defined, keep every checkpoints every n steps (in - addition to keeping the last 'keep' checkpoints). - async_manager: optional, AsyncManager - if defined, the save will run without blocking the main - thread. Only works for single host. Note that an ongoing save will still - block subsequent saves, to make sure overwrite/keep logic works correctly. - - Returns:: - - out: str - Filename of saved checkpoint. - """ - check_msgpack() - start_time = time.time() - # Make sure all saves are finished before the logic of checking and removing - # outdated checkpoints happens. - if async_manager: - async_manager.wait_previous_save() - - ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths( - ckpt_dir, step, prefix - ) - - if not overwrite: - _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore - - target = to_bytes(target) - - # Save the files via I/O sync or async. - def save_main_ckpt_task(): - return _save_main_ckpt_file(target, False, (ckpt_tmp_path, ckpt_path), - base_path, step, keep, overwrite, - keep_every_n_steps, start_time) - - if async_manager: - async_manager.save_async(save_main_ckpt_task) - else: - save_main_ckpt_task() - end_time = time.time() - if monitoring is not None: - monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, - end_time - start_time) - return ckpt_path - - -def _save_commit2(filename: str, - overwrite: bool, - ckpt_start_time: float, - has_mpa: bool, - write_commit_success: bool, - async_manager: Optional[AsyncManager] = None) -> None: - """Commit changes after saving checkpoints to disk. - - This function does the following, sequentially: - 1. 
Make sure all ckpt writing finishes, and rename them from temp path to - the final path. - 2. Remove newer checkpoints (files that ordered larger than this save) if - `overwrite=True`. - 3. Remove old checkpoint files based on `keep` and `keep_every_n_steps`. - 4. Record program duration saved by this checkpoint. - """ - ckpt_path = os.path.dirname(filename) - ckpt_tmp_path = os.path.join(ckpt_path, 'tmp') - mpa_ckpt_tmp_path, mpa_ckpt_path = ckpt_tmp_path + MP_ARRAY_POSTFIX, ckpt_path + MP_ARRAY_POSTFIX - # Rename the multiprocess array path once serialization and writing finished. - if has_mpa: - if write_commit_success: - commit_success_path = os.path.join(mpa_ckpt_path, COMMIT_SUCCESS_FILE) - with open(commit_success_path, 'w', encoding='utf-8') as f: - f.write(f'Checkpoint commit was successful to {mpa_ckpt_path}') - else: - # Commits are a two stage process (renaming the array folder and renaming - # the main ckpt file in sequential order). We always try to overwrite - # here because the array ckpt might be already renamed in a previously - # interrupted commit. NOTE: io.rename does not support overwriting - # directories via `rename` so we manually overwrite it. - if os.path.exists(mpa_ckpt_path): - logging.info('Removing outdated checkpoint at %s', mpa_ckpt_path) - shutil.rmtree(mpa_ckpt_path) - _rename_fn(mpa_ckpt_tmp_path, mpa_ckpt_path) - # Commit the main checkpoint file after arrays (if any) are committed - if async_manager: - async_manager.wait_previous_save() - _rename_fn(ckpt_tmp_path, ckpt_path, overwrite=overwrite) - logging.info('Saved checkpoint at %s', ckpt_path) - - # Remove newer and older invalid checkpoints. 
- _record_saved_duration(ckpt_start_time) - - -def _save_main_ckpt_file2(target: bytes, - has_mpa: bool, - filename: str, - overwrite: bool, - ckpt_start_time: float): - """Save the main checkpoint file via file system.""" - with open(filename, 'wb') as fp: - fp.write(target) - # Postpone the commitment of checkpoint to after MPA writes are done. - if not has_mpa: - _save_commit2(filename, overwrite, ckpt_start_time, has_mpa=False, write_commit_success=False) def save_pytree( @@ -1228,405 +31,70 @@ def save_pytree( async_manager: Optional[AsyncManager] = None, verbose: bool = True, ) -> None: - """Save a checkpoint of the model. Suitable for single-host. - - In this method, every JAX process saves the checkpoint on its own. Do not - use it if you have multiple processes and you intend for them to save data - to a common directory (e.g., a GCloud bucket). To save multi-process - checkpoints to a shared storage or to save `GlobalDeviceArray`s, use - `multiprocess_save()` instead. - - Pre-emption safe by writing to temporary before a final rename and cleanup - of past files. However, if async_manager is used, the final - commit will happen inside an async callback, which can be explicitly waited - by calling `async_manager.wait_previous_save()`. - - Parameters:: - - filename: str - str or pathlib-like path to store checkpoint files in. - target: Any - serializable flax object, usually a flax optimizer. - overwrite: bool - overwrite existing checkpoint files if a checkpoint at the - current or a later step already exits (default: False). - async_manager: optional, AsyncManager - if defined, the save will run without blocking the main - thread. Only works for single host. Note that an ongoing save will still - block subsequent saves, to make sure overwrite/keep logic works correctly. - verbose: bool - Whether output the print information. - - Returns:: - - out: str - Filename of saved checkpoint. 
- """ - check_msgpack() - if verbose: - print(f'Saving checkpoint into {filename}') - start_time = time.time() - # Make sure all saves are finished before the logic of checking and removing - # outdated checkpoints happens. - if async_manager: - async_manager.wait_previous_save() - - if os.path.splitext(filename)[-1] != '.bp': - filename = filename + '.bp' - if os.path.dirname(filename): - os.makedirs(os.path.dirname(filename), exist_ok=True) - if not overwrite and os.path.exists(filename): - raise InvalidCheckpointPath(filename) - target = to_bytes(target) - - # Save the files via I/O sync or async. - def save_main_ckpt_task(): - return _save_main_ckpt_file2(target, False, filename, overwrite, start_time) - - if async_manager: - async_manager.save_async(save_main_ckpt_task) - else: - save_main_ckpt_task() - end_time = time.time() - if monitoring is not None: - monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, - end_time - start_time) - - -def multiprocess_save( - ckpt_dir: Union[str, os.PathLike], - target: PyTree, - step: Union[int, float], - prefix: str = 'checkpoint_', - keep: int = 1, - overwrite: bool = False, - keep_every_n_steps: Optional[int] = None, - async_manager: Optional[AsyncManager] = None, - gda_manager: Optional[Any] = None, -) -> str: - """Save a checkpoint of the model in multi-process environment. - - Use this method to save `GlobalDeviceArray`s, or to save data to a - common directory. Only process 0 will save the main checkpoint file and - remove old checkpoint files. - - Pre-emption safe by writing to temporary before a final rename and cleanup - of past files. However, if async_manager or gda_manager is used, the final - commit will happen inside an async callback, which can be explicitly waited - by calling `async_manager.wait_previous_save()` or - `gda_manager.wait_until_finished()`. - - Parameters:: - - ckpt_dir: str, PathLike - str or pathlib-like path to store checkpoint files in. 
- target: Any - serializable flax object, usually a flax optimizer. - step: int, float - training step number or other metric number. - prefix: str - checkpoint file name prefix. - keep: int - number of past checkpoint files to keep. - overwrite: bool - overwrite existing checkpoint files if a checkpoint at the - current or a later step already exits (default: False). - keep_every_n_steps: int - if defined, keep every checkpoints every n steps (in - addition to keeping the last 'keep' checkpoints). - async_manager: optional, AsyncManager - if defined, the save will run without blocking the main - thread. Only works for single host. Note that an ongoing save will still - block subsequent saves, to make sure overwrite/keep logic works correctly. - gda_manager: optional, Any - required if target contains a JAX GlobalDeviceArray. Type - should be GlobalAsyncCheckpointManager (needs Tensorstore to be imported - correctly). Will save the GDAs to a separate subdirectory with postfix - "_gda" asynchronously. Same as async_manager, this will block subsequent - saves. - - Returns:: - - out: str - Filename of saved checkpoint. - """ - check_msgpack() - start_time = time.time() - # Make sure all saves are finished before the logic of checking and removing - # outdated checkpoints happens. 
- sync_global_devices('starting_save_checkpoint') - if async_manager: - async_manager.wait_previous_save() - if gda_manager: - gda_manager.wait_until_finished() - sync_global_devices('before_save_checkpoint') - - ckpt_path, ckpt_tmp_path, base_path = _get_checkpoint_paths( - ckpt_dir, step, prefix) - - target = to_state_dict(target) - target, mpa_targets = _split_mp_arrays(target) - target = msgpack_serialize(target) - has_mpa = len(mpa_targets) > 0 and (get_tensorstore_spec is not None) - - if not overwrite: - _check_overwrite_error(ckpt_tmp_path, ckpt_path, base_path, step) # type: ignore - sync_global_devices('check_overwrite_strictly_before_save') - - # Save the files via I/O sync or async. - def save_main_ckpt_task(): - return _save_main_ckpt_file(target, has_mpa, (ckpt_tmp_path, ckpt_path), - base_path, step, keep, overwrite, - keep_every_n_steps, start_time) - - # Write the main checkpoint file only via process 0, to avoid race condition. - if process_index() == 0: - if async_manager: - async_manager.save_async(save_main_ckpt_task) - else: - save_main_ckpt_task() - - if has_mpa: - if not gda_manager: - raise MPACheckpointingRequiredError(ckpt_path, step) - # Creating the directory containing GDAs explicitly. This should happen only - # on process 0 and before any worker starts to write GDA data. - if process_index() == 0: - _make_mpa_dirs(mpa_targets, ckpt_tmp_path) - sync_global_devices('Flax:Checkpointing:AfterCreateMPADir') - _save_mpas(gda_manager, mpa_targets, ckpt_tmp_path, ckpt_path, base_path, - keep, overwrite, keep_every_n_steps, start_time, async_manager) - - end_time = time.time() - if monitoring is not None: - monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT, - end_time - start_time) - return ckpt_path - - -def _load_latest_fn( - ckpt_dir: Union[str, os.PathLike], - prefix: str = 'checkpoint_' -) -> Optional[str]: - """Retrieve the path of the latest checkpoint in a directory. 
- - Parameters:: - - ckpt_dir: str - directory of checkpoints to restore from. - prefix: str - name prefix of checkpoint files. - - Returns:: - - out: PyTree, None - The latest checkpoint path or None if no checkpoints were found. - """ - ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str - checkpoint_files: List[Any] = [ - pathlib.PurePath(c) for c in _allowempty_listdir(ckpt_dir) - ] - checkpoint_files = [ - os.path.join(ckpt_dir, c) - for c in checkpoint_files - if (c.match(f'{prefix}*') and not c.match(f'{prefix}tmp') and - not c.match(f'*{MP_ARRAY_POSTFIX}')) - ] - checkpoint_files = _natural_sort(checkpoint_files) - if checkpoint_files: - return checkpoint_files[-1] - else: - return None - - -def load( - ckpt_dir: Union[str, os.PathLike], - step: Optional[Union[int, float]] = None, - prefix: str = 'checkpoint_', - target: Optional[Any] = None, - parallel: bool = True, - gda_manager: Optional[Any] = None, - allow_partial_mpa_restoration: bool = False, -) -> PyTree: - """Load last or best checkpoint from the given checkpoint path. - - Sorts the checkpoint files naturally, returning the highest-valued - file, e.g.: - - * ``ckpt_1, ckpt_2, ckpt_3 --> ckpt_3`` - - * ``ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1`` - - * ``ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5`` - - Parameters:: - - ckpt_dir: str - checkpoint file or directory of checkpoints to restore from. - step: int, float - Step number to load or None to load latest. If - specified, ckpt_dir must be a directory. - prefix: str - name prefix of checkpoint files. - target: Any - Matching object to rebuild via deserialized state-dict. If None, the - deserialized state-dict is returned as-is. - parallel: bool - whether to load seekable checkpoints in parallel, for speed. - gda_manager: Any - required if checkpoint contains a multiprocess array - (GlobalDeviceArray or jax Array from pjit). Type should be - GlobalAsyncCheckpointManager (needs Tensorstore to be imported - correctly). 
Will read the arrays from the separate subdirectory with - postfix "_gda". - allow_partial_mpa_restoration: bool - If true, the given `target` doesn't have to - contain all valid multiprocess arrays. As a result, the restored Pytree - may have some MPAs not restored correctly. Use this if you cannot provide - a fully valid ``target`` and don't need all the MPAs in the checkpoint - to be restored. - - Returns:: - - out: Any - Restored `target` updated from checkpoint file, or if no step specified and - no checkpoint files present, returns the passed-in `target` unchanged. - If a file path is specified and is not found, the passed-in `target` will be - returned. This is to match the behavior of the case where a directory path - is specified but the directory has not yet been created. - """ - check_msgpack() - start_time = time.time() - - ckpt_dir = os.fspath(ckpt_dir) # Pathlib -> str - ckpt_dir = _safe_normpath(ckpt_dir) - if step: - ckpt_path = _checkpoint_path(ckpt_dir, step, prefix) - if not os.path.exists(ckpt_path): - raise ValueError(f'Matching checkpoint not found: {ckpt_path}') - else: - if not os.path.exists(ckpt_dir): - logging.info('Found no checkpoint directory at %s', ckpt_dir) - return target - if os.path.isdir(ckpt_dir): - # This means the given dir is an orbax checkpoint. - if os.path.exists(os.path.join(ckpt_dir, ORBAX_CKPT_FILENAME)): - ckpt_path = ckpt_dir - else: - ckpt_path = _load_latest_fn(ckpt_dir, prefix) # type: ignore - if not ckpt_path: - warnings.warn(f'Found no checkpoint files in {ckpt_dir} with prefix {prefix}', - UserWarning) - return target - else: - ckpt_path = ckpt_dir - - sys.stdout.write(f'Loading checkpoint from {ckpt_path}\n') - sys.stdout.flush() - file_size = os.path.getsize(ckpt_path) - - with open(ckpt_path, 'rb') as fp: - if parallel and fp.seekable(): - buf_size = 128 << 20 # 128M buffer. 
- num_bufs = file_size / buf_size - logging.debug('num_bufs: %d', num_bufs) - checkpoint_contents = bytearray(file_size) - - def read_chunk(i): - # NOTE: We have to re-open the file to read each chunk, otherwise the - # parallelism has no effect. But we could reuse the file pointers - # within each thread. - with open(ckpt_path, 'rb') as f: - f.seek(i * buf_size) - buf = f.read(buf_size) - if buf: - checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf - return len(buf) / buf_size - - pool_size = 32 - pool = thread.ThreadPoolExecutor(pool_size) - results = pool.map(read_chunk, range(int(num_bufs) + 1)) - pool.shutdown(wait=False) - logging.debug(f'results: {list(results)}') - else: - checkpoint_contents = fp.read() - - state_dict = msgpack_restore(checkpoint_contents) - if (get_tensorstore_spec is not None): - state_dict = _restore_mpas(state_dict, target, ckpt_path, step, gda_manager, - allow_partial_mpa_restoration) - - if target is None: - restored_checkpoint = state_dict - else: - restored_checkpoint = from_state_dict(target, state_dict) - - end_time = time.time() - if monitoring is not None: - monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time) - - return restored_checkpoint + """Save a checkpoint of the model. Suitable for single-host. + + In this method, every JAX process saves the checkpoint on its own. Do not + use it if you have multiple processes and you intend for them to save data + to a common directory (e.g., a GCloud bucket). To save multi-process + checkpoints to a shared storage or to save `GlobalDeviceArray`s, use + `multiprocess_save()` instead. + + Pre-emption safe by writing to temporary before a final rename and cleanup + of past files. However, if async_manager is used, the final + commit will happen inside an async callback, which can be explicitly waited + by calling `async_manager.wait_previous_save()`. + + Parameters:: + + filename: str + str or pathlib-like path to store checkpoint files in. 
+ target: Any + serializable flax object, usually a flax optimizer. + overwrite: bool + overwrite existing checkpoint files if a checkpoint at the + current or a later step already exits (default: False). + async_manager: optional, AsyncManager + if defined, the save will run without blocking the main + thread. Only works for single host. Note that an ongoing save will still + block subsequent saves, to make sure overwrite/keep logic works correctly. + verbose: bool + Whether output the print information. + + Returns:: + + out: str + Filename of saved checkpoint. + """ + return braintools.file.msgpack_save( + filename, + target, + overwrite=overwrite, + async_manager=async_manager, + verbose=verbose, + ) def load_pytree( filename: str, + target: Optional[Any] = None, parallel: bool = True, ) -> PyTree: - """Load the checkpoint from the given checkpoint path. - - Parameters:: - - filename: str - checkpoint file or directory of checkpoints to restore from. - parallel: bool - whether to load seekable checkpoints in parallel, for speed. + """Load the checkpoint from the given checkpoint path. - Returns:: + Parameters:: - out: Any - Restored `target` updated from checkpoint file, or if no step specified and - no checkpoint files present, returns the passed-in `target` unchanged. - If a file path is specified and is not found, the passed-in `target` will be - returned. This is to match the behavior of the case where a directory path - is specified but the directory has not yet been created. - """ - check_msgpack() - start_time = time.time() - if not os.path.exists(filename): - raise ValueError(f'Checkpoint not found: {filename}') - sys.stdout.write(f'Loading checkpoint from {filename}\n') - sys.stdout.flush() - file_size = os.path.getsize(filename) + filename: str + checkpoint file or directory of checkpoints to restore from. + parallel: bool + whether to load seekable checkpoints in parallel, for speed. 
- with open(filename, 'rb') as fp: - if parallel and fp.seekable(): - buf_size = 128 << 20 # 128M buffer. - num_bufs = file_size / buf_size - logging.debug('num_bufs: %d', num_bufs) - checkpoint_contents = bytearray(file_size) + Returns:: - def read_chunk(i): - # NOTE: We have to re-open the file to read each chunk, otherwise the - # parallelism has no effect. But we could reuse the file pointers - # within each thread. - with open(filename, 'rb') as f: - f.seek(i * buf_size) - buf = f.read(buf_size) - if buf: - checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf - return len(buf) / buf_size - - pool_size = 32 - pool = thread.ThreadPoolExecutor(pool_size) - results = pool.map(read_chunk, range(int(num_bufs) + 1)) - pool.shutdown(wait=False) - logging.debug(f'results: {list(results)}') - else: - checkpoint_contents = fp.read() - - state_dict = msgpack_restore(checkpoint_contents) - end_time = time.time() - if monitoring is not None: - monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time) - - return state_dict + out: Any + Restored `target` updated from checkpoint file, or if no step specified and + no checkpoint files present, returns the passed-in `target` unchanged. + If a file path is specified and is not found, the passed-in `target` will be + returned. This is to match the behavior of the case where a directory path + is specified but the directory has not yet been created. 
+ """ + return braintools.file.msgpack_load(filename, target=target, parallel=parallel) diff --git a/brainpy/_src/connect/base.py b/brainpy/_src/connect/base.py index 34898b6fa..6c4eed57b 100644 --- a/brainpy/_src/connect/base.py +++ b/brainpy/_src/connect/base.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import abc +import textwrap from typing import Union, List, Tuple import jax.numpy as jnp @@ -9,30 +10,28 @@ from brainpy import tools, math as bm from brainpy.errors import ConnectorError -import textwrap - __all__ = [ - # the connection types - 'CONN_MAT', - 'PRE_IDS', 'POST_IDS', - 'PRE2POST', 'POST2PRE', - 'PRE2SYN', 'POST2SYN', - 'SUPPORTED_SYN_STRUCTURE', - - # the connection dtypes - 'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE', 'get_idx_type', - - # brainpy_object class - 'Connector', 'TwoEndConnector', 'OneEndConnector', - - # methods - 'mat2coo', 'mat2csc', 'mat2csr', - 'csr2csc', 'csr2mat', 'csr2coo', - 'coo2csr', 'coo2csc', 'coo2mat', - 'coo2mat_num', 'mat2mat_num', - - # visualize - 'visualizeMat', + # the connection types + 'CONN_MAT', + 'PRE_IDS', 'POST_IDS', + 'PRE2POST', 'POST2PRE', + 'PRE2SYN', 'POST2SYN', + 'SUPPORTED_SYN_STRUCTURE', + + # the connection dtypes + 'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE', 'get_idx_type', + + # brainpy_object class + 'Connector', 'TwoEndConnector', 'OneEndConnector', + + # methods + 'mat2coo', 'mat2csc', 'mat2csr', + 'csr2csc', 'csr2mat', 'csr2coo', + 'coo2csr', 'coo2csc', 'coo2mat', + 'coo2mat_num', 'mat2mat_num', + + # visualize + 'visualizeMat', ] CONN_MAT = 'conn_mat' @@ -60,744 +59,746 @@ def get_idx_type(): - return IDX_DTYPE + return IDX_DTYPE def set_default_dtype(mat_dtype=None, idx_dtype=None): - """Set the default dtype. - - Use this method, you can set the default dtype for connetion matrix and - connection index. 
- - For examples: - - >>> import numpy as np - >>> import brainpy as bp - >>> - >>> conn = bp.conn.GridFour()(4, 4) - >>> conn.require('conn_mat') - Array([[False, True, False, False], - [ True, False, True, False], - [False, True, False, True], - [False, False, True, False]], dtype=bool) - >>> bp.connect.set_default_dtype(mat_dtype=np.float32) - >>> conn = bp.conn.GridFour()(4, 4) - >>> conn.require('conn_mat') - Array([[0., 1., 0., 0.], - [1., 0., 1., 0.], - [0., 1., 0., 1.], - [0., 0., 1., 0.]], dtype=float32) - - Parameters:: - - mat_dtype : type - The default dtype for connection matrix. - idx_dtype : type - The default dtype for connection index. - """ - if mat_dtype is not None: - global MAT_DTYPE - MAT_DTYPE = mat_dtype - if idx_dtype is not None: - global IDX_DTYPE - IDX_DTYPE = idx_dtype + """Set the default dtype. + Use this method, you can set the default dtype for connetion matrix and + connection index. -class Connector(abc.ABC): - """Base Synaptic Connector Class.""" - pass - + For examples: -class TwoEndConnector(Connector): - """Synaptic connector to build connections between two neuron groups. - - If users want to customize their `Connector`, there are two ways: - - 1. Implementing ``build_conn(self)`` function, which returns one of - the connection data ``csr`` (CSR sparse data, a tuple of ), - ``coo`` (COO sparse data, a tuple of ), or ``mat`` - (a binary connection matrix). For instance, - - .. code-block:: python - - import brainpy as bp - class MyConnector(bp.conn.TwoEndConnector): - def build_conn(self): - return dict(csr=, mat=, coo=) - - 2. Implementing functions ``build_mat()``, ``build_csr()``, and - ``build_coo()``. Users can provide all three functions, or one of them. - - .. 
code-block:: python - - import brainpy as bp - class MyConnector(bp.conn.TwoEndConnector): - def build_mat(self, ): - return conn_matrix - - def build_csr(self, ): - return post_ids, inptr - - def build_coo(self, ): - return pre_ids, post_ids - - """ - - def __init__( - self, - pre: Union[int, Tuple[int, ...]] = None, - post: Union[int, Tuple[int, ...]] = None, - ): - self.pre_size = None - self.post_size = None - self.pre_num = None - self.post_num = None - if pre is not None: - if isinstance(pre, int): - pre = (pre,) - else: - pre = tuple(pre) - self.pre_size = pre - self.pre_num = tools.size2num(self.pre_size) - if post is not None: - if isinstance(post, int): - post = (post,) - else: - post = tuple(post) - self.post_size = post - self.post_num = tools.size2num(self.post_size) - - def __repr__(self): - return self.__class__.__name__ - - def __call__(self, pre_size, post_size): - """Create the concrete connections between two end objects. + >>> import numpy as np + >>> import brainpy as bp + >>> + >>> conn = bp.conn.GridFour()(4, 4) + >>> conn.require('conn_mat') + Array([[False, True, False, False], + [ True, False, True, False], + [False, True, False, True], + [False, False, True, False]], dtype=bool) + >>> bp.connect.set_default_dtype(mat_dtype=np.float32) + >>> conn = bp.conn.GridFour()(4, 4) + >>> conn.require('conn_mat') + Array([[0., 1., 0., 0.], + [1., 0., 1., 0.], + [0., 1., 0., 1.], + [0., 0., 1., 0.]], dtype=float32) Parameters:: - - pre_size : int, tuple of int, list of int - The size of the pre-synaptic group. - post_size : int, tuple of int, list of int - The size of the post-synaptic group. - - Returns:: - - conn : TwoEndConnector - Return the self. 
- """ - if isinstance(pre_size, int): - pre_size = (pre_size,) - else: - pre_size = tuple(pre_size) - if isinstance(post_size, int): - post_size = (post_size,) - else: - post_size = tuple(post_size) - self.pre_size, self.post_size = pre_size, post_size - self.pre_num = tools.size2num(self.pre_size) - self.post_num = tools.size2num(self.post_size) - return self - def _reset_conn(self, pre_size, post_size): - """Reset connection attributes. - - Parameters:: - - pre_size : int, tuple of int, list of int - The size of the pre-synaptic group. - post_size : int, tuple of int, list of int - The size of the post-synaptic group. + mat_dtype : type + The default dtype for connection matrix. + idx_dtype : type + The default dtype for connection index. """ - self.__call__(pre_size, post_size) - - @property - def is_version2_style(self): - if ((hasattr(self.build_coo, 'not_customized') and self.build_coo.not_customized) and - (hasattr(self.build_csr, 'not_customized') and self.build_csr.not_customized) and - (hasattr(self.build_mat, 'not_customized') and self.build_mat.not_customized)): - return False - else: - return True - - def _check(self, structures: Union[Tuple, List, str]): - # check synaptic structures - if isinstance(structures, str): - structures = [structures] - if structures is None or len(structures) == 0: - raise ConnectorError('No synaptic structure is received.') - for n in structures: - if n not in SUPPORTED_SYN_STRUCTURE: - raise ConnectorError(f'Unknown synapse structure "{n}". 
' - f'Only {SUPPORTED_SYN_STRUCTURE} is supported.') - - def _return_by_mat(self, structures, mat, all_data: dict): - assert mat.ndim == 2 - if (CONN_MAT in structures) and (CONN_MAT not in all_data): - all_data[CONN_MAT] = bm.as_jax(mat, dtype=MAT_DTYPE) - - if len([s for s in structures - if s not in [CONN_MAT]]) > 0: - ij = mat2coo(mat) - self._return_by_coo(structures, coo=ij, all_data=all_data) - - def _return_by_csr(self, structures, csr: tuple, all_data: dict): - indices, indptr = csr - np = onp if isinstance(indices, onp.ndarray) else bm - assert self.pre_num == indptr.size - 1 - - if (CONN_MAT in structures) and (CONN_MAT not in all_data): - conn_mat = csr2mat((indices, indptr), self.pre_num, self.post_num) - all_data[CONN_MAT] = bm.as_jax(conn_mat, dtype=MAT_DTYPE) - - if (PRE_IDS in structures) and (PRE_IDS not in all_data): - pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr)) - all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type()) - - if (POST_IDS in structures) and (POST_IDS not in all_data): - all_data[POST_IDS] = bm.as_jax(indices, dtype=get_idx_type()) - - if (COO in structures) and (COO not in all_data): - pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr)) - all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()), - bm.as_jax(indices, dtype=get_idx_type())) - - if (PRE2POST in structures) and (PRE2POST not in all_data): - all_data[PRE2POST] = (bm.as_jax(indices, dtype=get_idx_type()), - bm.as_jax(indptr, dtype=get_idx_type())) - - if (CSR in structures) and (CSR not in all_data): - all_data[CSR] = (bm.as_jax(indices, dtype=get_idx_type()), - bm.as_jax(indptr, dtype=get_idx_type())) - - if (POST2PRE in structures) and (POST2PRE not in all_data): - indc, indptrc = csr2csc((indices, indptr), self.post_num) - all_data[POST2PRE] = (bm.as_jax(indc, dtype=get_idx_type()), - bm.as_jax(indptrc, dtype=get_idx_type())) - - if (CSC in structures) and (CSC not in all_data): - indc, indptrc = csr2csc((indices, indptr), 
self.post_num) - all_data[CSC] = (bm.as_jax(indc, dtype=get_idx_type()), - bm.as_jax(indptrc, dtype=get_idx_type())) - - if (PRE2SYN in structures) and (PRE2SYN not in all_data): - syn_seq = np.arange(indices.size, dtype=get_idx_type()) - all_data[PRE2SYN] = (bm.as_jax(syn_seq, dtype=get_idx_type()), - bm.as_jax(indptr, dtype=get_idx_type())) - - if (POST2SYN in structures) and (POST2SYN not in all_data): - syn_seq = np.arange(indices.size, dtype=get_idx_type()) - _, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq) - all_data[POST2SYN] = (bm.as_jax(syn_seqc, dtype=get_idx_type()), - bm.as_jax(indptrc, dtype=get_idx_type())) - - def _return_by_coo(self, structures, coo: tuple, all_data: dict): - pre_ids, post_ids = coo + if mat_dtype is not None: + global MAT_DTYPE + MAT_DTYPE = mat_dtype + if idx_dtype is not None: + global IDX_DTYPE + IDX_DTYPE = idx_dtype - if (CONN_MAT in structures) and (CONN_MAT not in all_data): - all_data[CONN_MAT] = bm.as_jax(coo2mat(coo, self.pre_num, self.post_num), dtype=MAT_DTYPE) - if (PRE_IDS in structures) and (PRE_IDS not in all_data): - all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type()) - - if (POST_IDS in structures) and (POST_IDS not in all_data): - all_data[POST_IDS] = bm.as_jax(post_ids, dtype=get_idx_type()) - - if (COO in structures) and (COO not in all_data): - all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()), - bm.as_jax(post_ids, dtype=get_idx_type())) - - if CSC in structures and CSC not in all_data: - csc = coo2csc(coo, self.post_num) - all_data[CSC] = (bm.as_jax(csc[0], dtype=get_idx_type()), - bm.as_jax(csc[1], dtype=get_idx_type())) +class Connector(abc.ABC): + """Base Synaptic Connector Class.""" + pass - if POST2PRE in structures and POST2PRE not in all_data: - csc = coo2csc(coo, self.post_num) - all_data[POST2PRE] = (bm.as_jax(csc[0], dtype=get_idx_type()), - bm.as_jax(csc[1], dtype=get_idx_type())) - if (len([s for s in structures - if s not in [CONN_MAT, PRE_IDS, POST_IDS, 
- COO, CSC, POST2PRE]]) > 0): - csr = coo2csr(coo, self.pre_num) - self._return_by_csr(structures, csr=csr, all_data=all_data) +class TwoEndConnector(Connector): + """Synaptic connector to build connections between two neuron groups. - def _make_returns(self, structures, conn_data): - """Make the desired synaptic structures and return them. - """ - csr = None - mat = None - coo = None - if isinstance(conn_data, dict): - csr = conn_data.get('csr', None) - mat = conn_data.get('mat', None) - coo = conn_data.get('coo', None) or conn_data.get('ij', None) - elif isinstance(conn_data, tuple): - if conn_data[0] == 'csr': - csr = conn_data[1] - elif conn_data[0] == 'mat': - mat = conn_data[1] - elif conn_data[0] in ['coo', 'ij']: - coo = conn_data[1] - else: - raise ConnectorError(f'Must provide one of "csr", "mat" or "coo". Got "{conn_data[0]}" instead.') - else: - raise ConnectorError('Unknown type') - - # checking - if (csr is None) and (mat is None) and (coo is None): - raise ConnectorError('Must provide one of "csr", "mat" or "coo".') - structures = (structures,) if isinstance(structures, str) else structures - assert isinstance(structures, (tuple, list)) - - all_data = dict() - # "csr" structure - if csr is not None: - if (PRE2POST in structures) and (PRE2POST not in all_data): - all_data[PRE2POST] = (bm.as_jax(csr[0], dtype=get_idx_type()), - bm.as_jax(csr[1], dtype=get_idx_type())) - self._return_by_csr(structures, csr=csr, all_data=all_data) - - # "mat" structure - if mat is not None: - assert mat.ndim == 2 - if (CONN_MAT in structures) and (CONN_MAT not in all_data): - all_data[CONN_MAT] = bm.as_jax(mat, dtype=MAT_DTYPE) - self._return_by_mat(structures, mat=mat, all_data=all_data) - - # "coo" structure - if coo is not None: - if (PRE_IDS in structures) and (PRE_IDS not in structures): - all_data[PRE_IDS] = bm.as_jax(coo[0], dtype=get_idx_type()) - if (POST_IDS in structures) and (POST_IDS not in structures): - all_data[POST_IDS] = bm.as_jax(coo[1], 
dtype=get_idx_type()) - self._return_by_coo(structures, coo=coo, all_data=all_data) - - # return - if len(structures) == 1: - return all_data[structures[0]] - else: - return tuple([all_data[n] for n in structures]) + If users want to customize their `Connector`, there are two ways: - def require(self, *structures): - """Require all the connection data needed. + 1. Implementing ``build_conn(self)`` function, which returns one of + the connection data ``csr`` (CSR sparse data, a tuple of ), + ``coo`` (COO sparse data, a tuple of ), or ``mat`` + (a binary connection matrix). For instance, - Examples:: - - >>> import brainpy as bp - >>> conn = bp.connect.FixedProb(0.1) - >>> mat = conn.require(10, 20, 'conn_mat') - >>> mat.shape - (10, 20) - """ + .. code-block:: python - if len(structures) > 0: - pre_size = None - post_size = None - if not isinstance(structures[0], str): - pre_size = structures[0] - structures = structures[1:] - if len(structures) > 0: - if not isinstance(structures[0], str): - post_size = structures[0] - structures = structures[1:] - if pre_size is not None: - self.__call__(pre_size, post_size) - else: - return tuple() - - if self.pre_num is None or self.post_num is None: - raise ConnectorError(f'self.pre_num or self.post_num is not defined. 
' - f'Please use "self.require(pre_size, post_size, DATA1, DATA2, ...)" ') - - _has_coo_imp = not hasattr(self.build_coo, 'not_customized') - _has_csr_imp = not hasattr(self.build_csr, 'not_customized') - _has_mat_imp = not hasattr(self.build_mat, 'not_customized') - - self._check(structures) - if (_has_coo_imp or _has_csr_imp or _has_mat_imp): - if len(structures) == 1: - if PRE2POST in structures and _has_csr_imp: - r = self.build_csr() - return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) - elif CSR in structures and _has_csr_imp: - r = self.build_csr() - return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) - elif CONN_MAT in structures and _has_mat_imp: - return bm.as_jax(self.build_mat(), dtype=MAT_DTYPE) - elif PRE_IDS in structures and _has_coo_imp: - return bm.as_jax(self.build_coo()[0], dtype=get_idx_type()) - elif POST_IDS in structures and _has_coo_imp: - return bm.as_jax(self.build_coo()[1], dtype=get_idx_type()) - elif COO in structures and _has_coo_imp: - r = self.build_coo() - return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) - - elif len(structures) == 2: - if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp): - r = self.build_coo() - if structures[0] == PRE_IDS: - return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) - else: - return bm.as_jax(r[1], dtype=get_idx_type()), bm.as_jax(r[0], dtype=get_idx_type()) - - if ((CSR in structures or PRE2POST in structures) - and _has_csr_imp and COO in structures and _has_coo_imp): - csr = self.build_csr() - csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type())) - coo = self.build_coo() - coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type())) - if structures[0] == COO: - return coo, csr - else: - return csr, coo - - if ((CSR in structures or PRE2POST in structures) - and _has_csr_imp and CONN_MAT in 
structures and _has_mat_imp): - csr = self.build_csr() - csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type())) - mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE) - if structures[0] == CONN_MAT: - return mat, csr - else: - return csr, mat - - if (COO in structures and _has_coo_imp and CONN_MAT in structures and _has_mat_imp): - coo = self.build_coo() - coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type())) - mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE) - if structures[0] == COO: - return coo, mat - else: - return mat, coo - - conn_data = dict(csr=None, ij=None, mat=None) - if _has_coo_imp: - conn_data['coo'] = self.build_coo() - # if (CSR in structures or PRE2POST in structures) and _has_csr_imp: - # conn_data['csr'] = self.build_csr() - # if CONN_MAT in structures and _has_mat_imp: - # conn_data['mat'] = self.build_mat() - elif _has_csr_imp: - conn_data['csr'] = self.build_csr() - # if COO in structures and _has_coo_imp: - # conn_data['coo'] = self.build_coo() - # if CONN_MAT in structures and _has_mat_imp: - # conn_data['mat'] = self.build_mat() - elif _has_mat_imp: - conn_data['mat'] = self.build_mat() - # if COO in structures and _has_coo_imp: - # conn_data['coo'] = self.build_coo() - # if (CSR in structures or PRE2POST in structures) and _has_csr_imp: - # conn_data['csr'] = self.build_csr() - else: - raise ValueError + import brainpy as bp + class MyConnector(bp.conn.TwoEndConnector): + def build_conn(self): + return dict(csr=, mat=, coo=) - else: - conn_data = self.build_conn() - return self._make_returns(structures, conn_data) - - def requires(self, *structures): - """Require all the connection data needed.""" - return self.require(*structures) - - @tools.not_customized - def build_conn(self): - """build connections with certain data type. 
- - If users want to customize their connections, please provide one - of the following functions: - - - ``build_mat()``: build a matrix binary connection matrix. - - ``build_csr()``: build a csr sparse connection data. - - ``build_coo()``: build a coo sparse connection data. - - ``build_conn()``: deprecated. - - Returns:: - - conn: tuple, dict - A tuple with two elements: connection type (str) and connection data. - For example: ``return 'csr', (ind, indptr)`` - Or a dict with three elements: csr, mat and coo. For example: - ``return dict(csr=(ind, indptr), mat=None, coo=None)`` - """ - pass + 2. Implementing functions ``build_mat()``, ``build_csr()``, and + ``build_coo()``. Users can provide all three functions, or one of them. - @tools.not_customized - def build_mat(self): - """Build a binary matrix connection data. + .. code-block:: python + import brainpy as bp + class MyConnector(bp.conn.TwoEndConnector): + def build_mat(self, ): + return conn_matrix - If users want to customize their connections, please provide one - of the following functions: + def build_csr(self, ): + return post_ids, inptr - - ``build_mat()``: build a matrix binary connection matrix. - - ``build_csr()``: build a csr sparse connection data. - - ``build_coo()``: build a coo sparse connection data. - - ``build_conn()``: deprecated. + def build_coo(self, ): + return pre_ids, post_ids - Returns:: - - conn: Array - A binary matrix with the shape ``(num_pre, num_post)``. """ - pass - - @tools.not_customized - def build_csr(self): - """Build a csr sparse connection data. - Returns:: - - conn: tuple - A tuple denoting the ``(indices, indptr)``. 
- """ - pass + def __init__( + self, + pre: Union[int, Tuple[int, ...]] = None, + post: Union[int, Tuple[int, ...]] = None, + ): + self.pre_size = None + self.post_size = None + self.pre_num = None + self.post_num = None + if pre is not None: + if isinstance(pre, int): + pre = (pre,) + else: + pre = tuple(pre) + self.pre_size = pre + self.pre_num = tools.size2num(self.pre_size) + if post is not None: + if isinstance(post, int): + post = (post,) + else: + post = tuple(post) + self.post_size = post + self.post_num = tools.size2num(self.post_size) + + def __repr__(self): + return self.__class__.__name__ + + def __call__(self, pre_size, post_size): + """Create the concrete connections between two end objects. + + Parameters:: + + pre_size : int, tuple of int, list of int + The size of the pre-synaptic group. + post_size : int, tuple of int, list of int + The size of the post-synaptic group. + + Returns:: + + conn : TwoEndConnector + Return the self. + """ + if isinstance(pre_size, int): + pre_size = (pre_size,) + else: + pre_size = tuple(pre_size) + if isinstance(post_size, int): + post_size = (post_size,) + else: + post_size = tuple(post_size) + self.pre_size, self.post_size = pre_size, post_size + self.pre_num = tools.size2num(self.pre_size) + self.post_num = tools.size2num(self.post_size) + return self + + def _reset_conn(self, pre_size, post_size): + """Reset connection attributes. + + Parameters:: + + pre_size : int, tuple of int, list of int + The size of the pre-synaptic group. + post_size : int, tuple of int, list of int + The size of the post-synaptic group. + """ + self.__call__(pre_size, post_size) - @tools.not_customized - def build_coo(self): - """Build a coo sparse connection data. 
+ @property + def is_version2_style(self): + if ((hasattr(self.build_coo, 'not_customized') and self.build_coo.not_customized) and + (hasattr(self.build_csr, 'not_customized') and self.build_csr.not_customized) and + (hasattr(self.build_mat, 'not_customized') and self.build_mat.not_customized)): + return False + else: + return True + + def _check(self, structures: Union[Tuple, List, str]): + # check synaptic structures + if isinstance(structures, str): + structures = [structures] + if structures is None or len(structures) == 0: + raise ConnectorError('No synaptic structure is received.') + for n in structures: + if n not in SUPPORTED_SYN_STRUCTURE: + raise ConnectorError(f'Unknown synapse structure "{n}". ' + f'Only {SUPPORTED_SYN_STRUCTURE} is supported.') + + def _return_by_mat(self, structures, mat, all_data: dict): + assert mat.ndim == 2 + if (CONN_MAT in structures) and (CONN_MAT not in all_data): + all_data[CONN_MAT] = bm.as_jax(mat, dtype=MAT_DTYPE) + + if len([s for s in structures + if s not in [CONN_MAT]]) > 0: + ij = mat2coo(mat) + self._return_by_coo(structures, coo=ij, all_data=all_data) + + def _return_by_csr(self, structures, csr: tuple, all_data: dict): + indices, indptr = csr + np = onp if isinstance(indices, onp.ndarray) else bm + assert self.pre_num == indptr.size - 1 + + if (CONN_MAT in structures) and (CONN_MAT not in all_data): + conn_mat = csr2mat((indices, indptr), self.pre_num, self.post_num) + all_data[CONN_MAT] = bm.as_jax(conn_mat, dtype=MAT_DTYPE) + + if (PRE_IDS in structures) and (PRE_IDS not in all_data): + pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr)) + all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type()) + + if (POST_IDS in structures) and (POST_IDS not in all_data): + all_data[POST_IDS] = bm.as_jax(indices, dtype=get_idx_type()) + + if (COO in structures) and (COO not in all_data): + pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr)) + all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()), + 
bm.as_jax(indices, dtype=get_idx_type())) + + if (PRE2POST in structures) and (PRE2POST not in all_data): + all_data[PRE2POST] = (bm.as_jax(indices, dtype=get_idx_type()), + bm.as_jax(indptr, dtype=get_idx_type())) + + if (CSR in structures) and (CSR not in all_data): + all_data[CSR] = (bm.as_jax(indices, dtype=get_idx_type()), + bm.as_jax(indptr, dtype=get_idx_type())) + + if (POST2PRE in structures) and (POST2PRE not in all_data): + indc, indptrc = csr2csc((indices, indptr), self.post_num) + all_data[POST2PRE] = (bm.as_jax(indc, dtype=get_idx_type()), + bm.as_jax(indptrc, dtype=get_idx_type())) + + if (CSC in structures) and (CSC not in all_data): + indc, indptrc = csr2csc((indices, indptr), self.post_num) + all_data[CSC] = (bm.as_jax(indc, dtype=get_idx_type()), + bm.as_jax(indptrc, dtype=get_idx_type())) + + if (PRE2SYN in structures) and (PRE2SYN not in all_data): + syn_seq = np.arange(indices.size, dtype=get_idx_type()) + all_data[PRE2SYN] = (bm.as_jax(syn_seq, dtype=get_idx_type()), + bm.as_jax(indptr, dtype=get_idx_type())) + + if (POST2SYN in structures) and (POST2SYN not in all_data): + syn_seq = np.arange(indices.size, dtype=get_idx_type()) + _, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq) + all_data[POST2SYN] = (bm.as_jax(syn_seqc, dtype=get_idx_type()), + bm.as_jax(indptrc, dtype=get_idx_type())) + + def _return_by_coo(self, structures, coo: tuple, all_data: dict): + pre_ids, post_ids = coo + + if (CONN_MAT in structures) and (CONN_MAT not in all_data): + all_data[CONN_MAT] = bm.as_jax(coo2mat(coo, self.pre_num, self.post_num), dtype=MAT_DTYPE) + + if (PRE_IDS in structures) and (PRE_IDS not in all_data): + all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type()) + + if (POST_IDS in structures) and (POST_IDS not in all_data): + all_data[POST_IDS] = bm.as_jax(post_ids, dtype=get_idx_type()) + + if (COO in structures) and (COO not in all_data): + all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()), + 
bm.as_jax(post_ids, dtype=get_idx_type())) + + if CSC in structures and CSC not in all_data: + csc = coo2csc(coo, self.post_num) + all_data[CSC] = (bm.as_jax(csc[0], dtype=get_idx_type()), + bm.as_jax(csc[1], dtype=get_idx_type())) + + if POST2PRE in structures and POST2PRE not in all_data: + csc = coo2csc(coo, self.post_num) + all_data[POST2PRE] = (bm.as_jax(csc[0], dtype=get_idx_type()), + bm.as_jax(csc[1], dtype=get_idx_type())) + + if (len([s for s in structures + if s not in [CONN_MAT, PRE_IDS, POST_IDS, + COO, CSC, POST2PRE]]) > 0): + csr = coo2csr(coo, self.pre_num) + self._return_by_csr(structures, csr=csr, all_data=all_data) + + def _make_returns(self, structures, conn_data): + """Make the desired synaptic structures and return them. + """ + csr = None + mat = None + coo = None + if isinstance(conn_data, dict): + csr = conn_data.get('csr', None) + mat = conn_data.get('mat', None) + coo = conn_data.get('coo', None) or conn_data.get('ij', None) + elif isinstance(conn_data, tuple): + if conn_data[0] == 'csr': + csr = conn_data[1] + elif conn_data[0] == 'mat': + mat = conn_data[1] + elif conn_data[0] in ['coo', 'ij']: + coo = conn_data[1] + else: + raise ConnectorError(f'Must provide one of "csr", "mat" or "coo". 
Got "{conn_data[0]}" instead.') + else: + raise ConnectorError('Unknown type') + + # checking + if (csr is None) and (mat is None) and (coo is None): + raise ConnectorError('Must provide one of "csr", "mat" or "coo".') + structures = (structures,) if isinstance(structures, str) else structures + assert isinstance(structures, (tuple, list)) + + all_data = dict() + # "csr" structure + if csr is not None: + if (PRE2POST in structures) and (PRE2POST not in all_data): + all_data[PRE2POST] = (bm.as_jax(csr[0], dtype=get_idx_type()), + bm.as_jax(csr[1], dtype=get_idx_type())) + self._return_by_csr(structures, csr=csr, all_data=all_data) + + # "mat" structure + if mat is not None: + assert mat.ndim == 2 + if (CONN_MAT in structures) and (CONN_MAT not in all_data): + all_data[CONN_MAT] = bm.as_jax(mat, dtype=MAT_DTYPE) + self._return_by_mat(structures, mat=mat, all_data=all_data) + + # "coo" structure + if coo is not None: + if (PRE_IDS in structures) and (PRE_IDS not in structures): + all_data[PRE_IDS] = bm.as_jax(coo[0], dtype=get_idx_type()) + if (POST_IDS in structures) and (POST_IDS not in structures): + all_data[POST_IDS] = bm.as_jax(coo[1], dtype=get_idx_type()) + self._return_by_coo(structures, coo=coo, all_data=all_data) + + # return + if len(structures) == 1: + return all_data[structures[0]] + else: + return tuple([all_data[n] for n in structures]) + + def require(self, *structures): + """Require all the connection data needed. + + Examples:: + + >>> import brainpy as bp + >>> conn = bp.connect.FixedProb(0.1) + >>> mat = conn.require(10, 20, 'conn_mat') + >>> mat.shape + (10, 20) + """ - Returns:: - - conn: tuple - A tuple denoting the ``(pre_ids, post_ids)``. 
- """ - pass + if len(structures) > 0: + pre_size = None + post_size = None + if not isinstance(structures[0], str): + pre_size = structures[0] + structures = structures[1:] + if len(structures) > 0: + if not isinstance(structures[0], str): + post_size = structures[0] + structures = structures[1:] + if pre_size is not None: + self.__call__(pre_size, post_size) + else: + return tuple() + + if self.pre_num is None or self.post_num is None: + raise ConnectorError(f'self.pre_num or self.post_num is not defined. ' + f'Please use "self.require(pre_size, post_size, DATA1, DATA2, ...)" ') + + _has_coo_imp = not hasattr(self.build_coo, 'not_customized') + _has_csr_imp = not hasattr(self.build_csr, 'not_customized') + _has_mat_imp = not hasattr(self.build_mat, 'not_customized') + + self._check(structures) + if (_has_coo_imp or _has_csr_imp or _has_mat_imp): + if len(structures) == 1: + if PRE2POST in structures and _has_csr_imp: + r = self.build_csr() + return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) + elif CSR in structures and _has_csr_imp: + r = self.build_csr() + return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) + elif CONN_MAT in structures and _has_mat_imp: + return bm.as_jax(self.build_mat(), dtype=MAT_DTYPE) + elif PRE_IDS in structures and _has_coo_imp: + return bm.as_jax(self.build_coo()[0], dtype=get_idx_type()) + elif POST_IDS in structures and _has_coo_imp: + return bm.as_jax(self.build_coo()[1], dtype=get_idx_type()) + elif COO in structures and _has_coo_imp: + r = self.build_coo() + return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) + + elif len(structures) == 2: + if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp): + r = self.build_coo() + if structures[0] == PRE_IDS: + return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type()) + else: + return bm.as_jax(r[1], dtype=get_idx_type()), bm.as_jax(r[0], 
dtype=get_idx_type()) + + if ((CSR in structures or PRE2POST in structures) + and _has_csr_imp and COO in structures and _has_coo_imp): + csr = self.build_csr() + csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type())) + coo = self.build_coo() + coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type())) + if structures[0] == COO: + return coo, csr + else: + return csr, coo + + if ((CSR in structures or PRE2POST in structures) + and _has_csr_imp and CONN_MAT in structures and _has_mat_imp): + csr = self.build_csr() + csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type())) + mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE) + if structures[0] == CONN_MAT: + return mat, csr + else: + return csr, mat + + if (COO in structures and _has_coo_imp and CONN_MAT in structures and _has_mat_imp): + coo = self.build_coo() + coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type())) + mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE) + if structures[0] == COO: + return coo, mat + else: + return mat, coo + + conn_data = dict(csr=None, ij=None, mat=None) + if _has_coo_imp: + conn_data['coo'] = self.build_coo() + # if (CSR in structures or PRE2POST in structures) and _has_csr_imp: + # conn_data['csr'] = self.build_csr() + # if CONN_MAT in structures and _has_mat_imp: + # conn_data['mat'] = self.build_mat() + elif _has_csr_imp: + conn_data['csr'] = self.build_csr() + # if COO in structures and _has_coo_imp: + # conn_data['coo'] = self.build_coo() + # if CONN_MAT in structures and _has_mat_imp: + # conn_data['mat'] = self.build_mat() + elif _has_mat_imp: + conn_data['mat'] = self.build_mat() + # if COO in structures and _has_coo_imp: + # conn_data['coo'] = self.build_coo() + # if (CSR in structures or PRE2POST in structures) and _has_csr_imp: + # conn_data['csr'] = self.build_csr() + else: + raise ValueError + + else: + conn_data = self.build_conn() + return 
self._make_returns(structures, conn_data) + + def requires(self, *structures): + """Require all the connection data needed.""" + return self.require(*structures) + + @tools.not_customized + def build_conn(self): + """build connections with certain data type. + + If users want to customize their connections, please provide one + of the following functions: + + - ``build_mat()``: build a matrix binary connection matrix. + - ``build_csr()``: build a csr sparse connection data. + - ``build_coo()``: build a coo sparse connection data. + - ``build_conn()``: deprecated. + + Returns:: + + conn: tuple, dict + A tuple with two elements: connection type (str) and connection data. + For example: ``return 'csr', (ind, indptr)`` + Or a dict with three elements: csr, mat and coo. For example: + ``return dict(csr=(ind, indptr), mat=None, coo=None)`` + """ + pass + + @tools.not_customized + def build_mat(self): + """Build a binary matrix connection data. + + + If users want to customize their connections, please provide one + of the following functions: + + - ``build_mat()``: build a matrix binary connection matrix. + - ``build_csr()``: build a csr sparse connection data. + - ``build_coo()``: build a coo sparse connection data. + - ``build_conn()``: deprecated. + + Returns:: + + conn: Array + A binary matrix with the shape ``(num_pre, num_post)``. + """ + pass + + @tools.not_customized + def build_csr(self): + """Build a csr sparse connection data. + + Returns:: + + conn: tuple + A tuple denoting the ``(indices, indptr)``. + """ + pass + + @tools.not_customized + def build_coo(self): + """Build a coo sparse connection data. + + Returns:: + + conn: tuple + A tuple denoting the ``(pre_ids, post_ids)``. 
+ """ + pass class OneEndConnector(TwoEndConnector): - """Synaptic connector to build synapse connections within a population of neurons.""" - - def __init__(self, *args, **kwargs): - super(OneEndConnector, self).__init__(*args, **kwargs) - - def __call__(self, pre_size, post_size=None): - if post_size is None: - post_size = pre_size - - try: - assert pre_size == post_size - except AssertionError: - raise ConnectorError( - f'The shape of pre-synaptic group should be the same with the post group. ' - f'But we got {pre_size} != {post_size}.') - - if isinstance(pre_size, int): - pre_size = (pre_size,) - else: - pre_size = tuple(pre_size) - if isinstance(post_size, int): - post_size = (post_size,) - else: - post_size = tuple(post_size) - self.pre_size, self.post_size = pre_size, post_size - self.pre_num = tools.size2num(self.pre_size) - self.post_num = tools.size2num(self.post_size) - return self - - def _reset_conn(self, pre_size, post_size=None): - self.__init__() - self.__call__(pre_size, post_size) + """Synaptic connector to build synapse connections within a population of neurons.""" + + def __init__(self, *args, **kwargs): + super(OneEndConnector, self).__init__(*args, **kwargs) + + def __call__(self, pre_size, post_size=None): + if post_size is None: + post_size = pre_size + + try: + assert pre_size == post_size + except AssertionError: + raise ConnectorError( + f'The shape of pre-synaptic group should be the same with the post group. 
' + f'But we got {pre_size} != {post_size}.') + + if isinstance(pre_size, int): + pre_size = (pre_size,) + else: + pre_size = tuple(pre_size) + if isinstance(post_size, int): + post_size = (post_size,) + else: + post_size = tuple(post_size) + self.pre_size, self.post_size = pre_size, post_size + self.pre_num = tools.size2num(self.pre_size) + self.post_num = tools.size2num(self.post_size) + return self + + def _reset_conn(self, pre_size, post_size=None): + self.__init__() + self.__call__(pre_size, post_size) def mat2csr(dense): - """convert a dense matrix to (indices, indptr).""" - if isinstance(dense, onp.ndarray): - pre_ids, post_ids = onp.where(dense > 0) - else: - pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0) - return coo2csr((pre_ids, post_ids), dense.shape[0]) + """convert a dense matrix to (indices, indptr).""" + if isinstance(dense, onp.ndarray): + pre_ids, post_ids = onp.where(dense > 0) + else: + pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0) + return coo2csr((pre_ids, post_ids), dense.shape[0]) def mat2coo(dense): - if isinstance(dense, onp.ndarray): - pre_ids, post_ids = onp.where(dense > 0) - else: - pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0) - return pre_ids.astype(dtype=get_idx_type()), post_ids.astype(dtype=get_idx_type()) + if isinstance(dense, onp.ndarray): + pre_ids, post_ids = onp.where(dense > 0) + else: + pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0) + return pre_ids.astype(dtype=get_idx_type()), post_ids.astype(dtype=get_idx_type()) def mat2csc(dense): - if isinstance(dense, onp.ndarray): - pre_ids, post_ids = onp.where(dense > 0) - else: - pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0) - return coo2csr((post_ids, pre_ids), dense.shape[1]) + if isinstance(dense, onp.ndarray): + pre_ids, post_ids = onp.where(dense > 0) + else: + pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0) + return coo2csr((post_ids, pre_ids), dense.shape[1]) def csr2mat(csr, num_pre, num_post): - """convert (indices, indptr) to a dense 
matrix.""" - indices, indptr = csr - if isinstance(indices, onp.ndarray): - d = onp.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post - pre_ids = onp.repeat(onp.arange(indptr.size - 1), onp.diff(indptr)) - d[pre_ids, indices] = True - return d - else: - d = bm.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post - pre_ids = jnp.repeat(jnp.arange(indptr.size - 1), jnp.diff(indptr)) - d[pre_ids, indices] = True - return d.value + """convert (indices, indptr) to a dense matrix.""" + indices, indptr = csr + if isinstance(indices, onp.ndarray): + d = onp.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post + pre_ids = onp.repeat(onp.arange(indptr.size - 1), onp.diff(indptr)) + d[pre_ids, indices] = True + return d + else: + d = bm.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post + pre_ids = jnp.repeat(jnp.arange(indptr.size - 1), jnp.diff(indptr)) + d[pre_ids, indices] = True + return d.value def csr2csc(csr, post_num, data=None): - """Convert csr to csc.""" - return coo2csc(csr2coo(csr), post_num, data) + """Convert csr to csc.""" + return coo2csc(csr2coo(csr), post_num, data) def csr2coo(csr): - np = onp if isinstance(csr[0], onp.ndarray) else jnp - indices, indptr = csr - pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr)) - return pre_ids, indices + np = onp if isinstance(csr[0], onp.ndarray) else jnp + indices, indptr = csr + pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr)) + return pre_ids, indices def coo2mat(ij, num_pre, num_post): - """convert (indices, indptr) to a dense matrix.""" - pre_ids, post_ids = ij - if isinstance(pre_ids, onp.ndarray): - d = onp.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post - d[pre_ids, post_ids] = True - return d - else: - d = bm.zeros((num_pre, num_post), dtype=MAT_DTYPE) - d[pre_ids, post_ids] = True - return d.value + """convert (indices, indptr) to a dense matrix.""" + pre_ids, post_ids = ij + if isinstance(pre_ids, onp.ndarray): + d 
= onp.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post + d[pre_ids, post_ids] = True + return d + else: + d = bm.zeros((num_pre, num_post), dtype=MAT_DTYPE) + d[pre_ids, post_ids] = True + return d.value def coo2csr(coo, num_pre): - """convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'""" - pre_ids, post_ids = coo - - if isinstance(pre_ids, onp.ndarray): - sort_ids = onp.argsort(pre_ids) - post_ids = onp.asarray(post_ids) - post_ids = post_ids[sort_ids] - indices = post_ids - unique_pre_ids, pre_count = onp.unique(pre_ids, return_counts=True) - final_pre_count = onp.zeros(num_pre, dtype=jnp.uint32) - final_pre_count[unique_pre_ids] = pre_count - else: - sort_ids = onp.argsort(bm.as_jax(pre_ids)) - post_ids = bm.as_jax(post_ids) - post_ids = post_ids[sort_ids] - indices = post_ids - unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True) - final_pre_count = bm.zeros(num_pre, dtype=jnp.uint32) - final_pre_count[unique_pre_ids] = pre_count - final_pre_count = bm.as_jax(final_pre_count) - indptr = final_pre_count.cumsum() - indptr = onp.insert(indptr, 0, 0) - return indices.astype(get_idx_type()), indptr.astype(get_idx_type()) + """convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'""" + pre_ids, post_ids = coo + + if isinstance(pre_ids, onp.ndarray): + sort_ids = onp.argsort(pre_ids) + post_ids = onp.asarray(post_ids) + post_ids = post_ids[sort_ids] + indices = post_ids + unique_pre_ids, pre_count = onp.unique(pre_ids, return_counts=True) + final_pre_count = onp.zeros(num_pre, dtype=jnp.uint32) + final_pre_count[unique_pre_ids] = pre_count + else: + sort_ids = onp.argsort(bm.as_jax(pre_ids)) + post_ids = bm.as_jax(post_ids) + post_ids = post_ids[sort_ids] + indices = post_ids + unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True) + final_pre_count = bm.zeros(num_pre, dtype=jnp.uint32) + final_pre_count[unique_pre_ids] = pre_count + final_pre_count = 
bm.as_jax(final_pre_count) + indptr = final_pre_count.cumsum() + indptr = onp.insert(indptr, 0, 0) + return indices.astype(get_idx_type()), indptr.astype(get_idx_type()) def coo2csc(coo, post_num, data=None): - """Convert csr to csc.""" - pre_ids, indices = coo - if isinstance(indices, onp.ndarray): - # to maintain the original order of the elements with the same value - sort_ids = onp.argsort(indices) - pre_ids_new = onp.asarray(pre_ids[sort_ids], dtype=get_idx_type()) - - unique_post_ids, count = onp.unique(indices, return_counts=True) - post_count = onp.zeros(post_num, dtype=get_idx_type()) - post_count[unique_post_ids] = count - - indptr_new = post_count.cumsum() - indptr_new = onp.insert(indptr_new, 0, 0) - indptr_new = onp.asarray(indptr_new, dtype=get_idx_type()) - - else: - pre_ids = bm.as_jax(pre_ids) - indices = bm.as_jax(indices) - - # to maintain the original order of the elements with the same value - sort_ids = jnp.argsort(indices) - pre_ids_new = jnp.asarray(pre_ids[sort_ids], dtype=get_idx_type()) - - unique_post_ids, count = jnp.unique(indices, return_counts=True) - post_count = bm.zeros(post_num, dtype=get_idx_type()) - post_count[unique_post_ids] = count - - indptr_new = post_count.value.cumsum() - indptr_new = jnp.insert(indptr_new, 0, 0) - indptr_new = jnp.asarray(indptr_new, dtype=get_idx_type()) - - if data is None: - return pre_ids_new, indptr_new - else: - data_new = data[sort_ids] - return pre_ids_new, indptr_new, data_new + """Convert csr to csc.""" + pre_ids, indices = coo + if isinstance(indices, onp.ndarray): + # to maintain the original order of the elements with the same value + sort_ids = onp.argsort(indices) + pre_ids_new = onp.asarray(pre_ids[sort_ids], dtype=get_idx_type()) + + unique_post_ids, count = onp.unique(indices, return_counts=True) + post_count = onp.zeros(post_num, dtype=get_idx_type()) + post_count[unique_post_ids] = count + + indptr_new = post_count.cumsum() + indptr_new = onp.insert(indptr_new, 0, 0) + indptr_new = 
onp.asarray(indptr_new, dtype=get_idx_type()) + + else: + pre_ids = bm.as_jax(pre_ids) + indices = bm.as_jax(indices) + + # to maintain the original order of the elements with the same value + sort_ids = jnp.argsort(indices) + pre_ids_new = jnp.asarray(pre_ids[sort_ids], dtype=get_idx_type()) + + unique_post_ids, count = jnp.unique(indices, return_counts=True) + post_count = bm.zeros(post_num, dtype=get_idx_type()) + post_count[unique_post_ids] = count + + indptr_new = post_count.value.cumsum() + indptr_new = jnp.insert(indptr_new, 0, 0) + indptr_new = jnp.asarray(indptr_new, dtype=get_idx_type()) + + if data is None: + return pre_ids_new, indptr_new + else: + data_new = data[sort_ids] + return pre_ids_new, indptr_new, data_new + def coo2mat_num(ij, num_pre, num_post, num, seed=0): - """ - convert (indices, indptr) to a dense connection number matrix.\n - Specific for FixedTotalNum. - """ - rng = bm.random.RandomState(seed) - mat = coo2mat(ij, num_pre, num_post) + """ + convert (indices, indptr) to a dense connection number matrix.\n + Specific for FixedTotalNum. 
+ """ + rng = bm.random.RandomState(seed) + mat = coo2mat(ij, num_pre, num_post) - # get nonzero indices and number - nonzero_idx = jnp.nonzero(mat) - nonzero_num = jnp.count_nonzero(mat) + # get nonzero indices and number + nonzero_idx = jnp.nonzero(mat) + nonzero_num = jnp.count_nonzero(mat) - # get multi connection number - multi_conn_num = num - nonzero_num + # get multi connection number + multi_conn_num = num - nonzero_num - # alter the element type to int - mat = mat.astype(jnp.int32) + # alter the element type to int + mat = mat.astype(jnp.int32) - # 随机在mat中选取nonzero_idx的元素,将其值加1 - index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False) - for i in index: - mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1) + # 随机在mat中选取nonzero_idx的元素,将其值加1 + index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False) + for i in index: + mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1) + + return mat - return mat def mat2mat_num(mat, num, seed=0): - """ - Convert boolean matrix to a dense connection number matrix.\n - Specific for FixedTotalNum. - """ - rng = bm.random.RandomState(seed) + """ + Convert boolean matrix to a dense connection number matrix.\n + Specific for FixedTotalNum. 
+ """ + rng = bm.random.RandomState(seed) - # get nonzero indices and number - nonzero_idx = jnp.nonzero(mat) - nonzero_num = jnp.count_nonzero(mat) + # get nonzero indices and number + nonzero_idx = jnp.nonzero(mat) + nonzero_num = jnp.count_nonzero(mat) - # get multi connection number - multi_conn_num = num - nonzero_num + # get multi connection number + multi_conn_num = num - nonzero_num - # alter the element type to int - mat = mat.astype(jnp.int32) + # alter the element type to int + mat = mat.astype(jnp.int32) - # 随机在mat中选取nonzero_idx的元素,将其值加1 - index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False) - for i in index: - mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1) + # 随机在mat中选取nonzero_idx的元素,将其值加1 + index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False) + for i in index: + mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1) - return mat + return mat def visualizeMat(mat, description='Untitled'): - """ - Visualize the matrix. (Need seaborn and matplotlib) - - parameters - ---------- - mat : jnp.ndarray - The matrix to be visualized. - description : str - The title of the figure. - """ - try: - import seaborn as sns - import matplotlib.pyplot as plt - except (ModuleNotFoundError, ImportError): - print('Please install seaborn and matplotlib for this function') - return - sns.heatmap(mat, cmap='viridis') - warpped_title = textwrap.fill(description, width=60) - plt.title(warpped_title) - plt.show() \ No newline at end of file + """ + Visualize the matrix. (Need seaborn and matplotlib) + + parameters + ---------- + mat : jnp.ndarray + The matrix to be visualized. + description : str + The title of the figure. 
+ """ + try: + import seaborn as sns + import matplotlib.pyplot as plt + except (ModuleNotFoundError, ImportError): + print('Please install seaborn and matplotlib for this function') + return + sns.heatmap(mat, cmap='viridis') + warpped_title = textwrap.fill(description, width=60) + plt.title(warpped_title) + plt.show() diff --git a/brainpy/_src/connect/custom_conn.py b/brainpy/_src/connect/custom_conn.py index 14376af12..f5ef04ba2 100644 --- a/brainpy/_src/connect/custom_conn.py +++ b/brainpy/_src/connect/custom_conn.py @@ -9,107 +9,107 @@ from .base import * __all__ = [ - 'MatConn', - 'IJConn', - 'CSRConn', - 'SparseMatConn' + 'MatConn', + 'IJConn', + 'CSRConn', + 'SparseMatConn' ] class MatConn(TwoEndConnector): - """Connector built from the dense connection matrix.""" + """Connector built from the dense connection matrix.""" - def __init__(self, conn_mat, **kwargs): - super(MatConn, self).__init__(**kwargs) + def __init__(self, conn_mat, **kwargs): + super(MatConn, self).__init__(**kwargs) - assert isinstance(conn_mat, (np.ndarray, bm.BaseArray, jax.Array)) and conn_mat.ndim == 2 - self.pre_num, self.post_num = conn_mat.shape - self.pre_size, self.post_size = (self.pre_num,), (self.post_num,) + assert isinstance(conn_mat, (np.ndarray, bm.BaseArray, jax.Array)) and conn_mat.ndim == 2 + self.pre_num, self.post_num = conn_mat.shape + self.pre_size, self.post_size = (self.pre_num,), (self.post_num,) - self.conn_mat = jnp.asarray(conn_mat).astype(MAT_DTYPE) + self.conn_mat = jnp.asarray(conn_mat).astype(MAT_DTYPE) - def __call__(self, pre_size, post_size): - assert self.pre_num == tools.size2num(pre_size) - assert self.post_num == tools.size2num(post_size) - return self + def __call__(self, pre_size, post_size): + assert self.pre_num == tools.size2num(pre_size) + assert self.post_num == tools.size2num(post_size) + return self - def build_mat(self): - assert self.conn_mat.shape[0] == self.pre_num - assert self.conn_mat.shape[1] == self.post_num - return self.conn_mat 
+ def build_mat(self): + assert self.conn_mat.shape[0] == self.pre_num + assert self.conn_mat.shape[1] == self.post_num + return self.conn_mat class IJConn(TwoEndConnector): - """Connector built from the ``pre_ids`` and ``post_ids`` connections.""" - - def __init__(self, i, j, **kwargs): - super(IJConn, self).__init__(**kwargs) - - assert isinstance(i, (np.ndarray, bm.BaseArray, jnp.ndarray)) and i.ndim == 1 - assert isinstance(j, (np.ndarray, bm.BaseArray, jnp.ndarray)) and j.ndim == 1 - assert i.size == j.size - - # initialize the class via "pre_ids" and "post_ids" - self.pre_ids = jnp.asarray(i).astype(IDX_DTYPE) - self.post_ids = jnp.asarray(j).astype(IDX_DTYPE) - self.max_pre = self.pre_ids.max() - self.max_post = self.post_ids.max() - - def __call__(self, pre_size, post_size): - super(IJConn, self).__call__(pre_size, post_size) - if self.max_pre >= self.pre_num: - raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than ' - f'the maximum id ({self.max_pre}) of self.pre_ids.') - if self.max_post >= self.post_num: - raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' - f'the maximum id ({self.max_post}) of self.post_ids.') - return self - - def build_coo(self): - if self.pre_num <= self.max_pre: - raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than ' - f'the maximum id ({self.max_pre}) of self.pre_ids.') - if self.post_num <= self.max_post: - raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' - f'the maximum id ({self.max_post}) of self.post_ids.') - return self.pre_ids, self.post_ids + """Connector built from the ``pre_ids`` and ``post_ids`` connections.""" + + def __init__(self, i, j, **kwargs): + super(IJConn, self).__init__(**kwargs) + + assert isinstance(i, (np.ndarray, bm.BaseArray, jnp.ndarray)) and i.ndim == 1 + assert isinstance(j, (np.ndarray, bm.BaseArray, jnp.ndarray)) and j.ndim == 1 + assert i.size == j.size + + # initialize the class via "pre_ids" and "post_ids" + 
self.pre_ids = jnp.asarray(i).astype(IDX_DTYPE) + self.post_ids = jnp.asarray(j).astype(IDX_DTYPE) + self.max_pre = self.pre_ids.max() + self.max_post = self.post_ids.max() + + def __call__(self, pre_size, post_size): + super(IJConn, self).__call__(pre_size, post_size) + if self.max_pre >= self.pre_num: + raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than ' + f'the maximum id ({self.max_pre}) of self.pre_ids.') + if self.max_post >= self.post_num: + raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' + f'the maximum id ({self.max_post}) of self.post_ids.') + return self + + def build_coo(self): + if self.pre_num <= self.max_pre: + raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than ' + f'the maximum id ({self.max_pre}) of self.pre_ids.') + if self.post_num <= self.max_post: + raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' + f'the maximum id ({self.max_post}) of self.post_ids.') + return self.pre_ids, self.post_ids class CSRConn(TwoEndConnector): - """Connector built from the CSR sparse connection matrix.""" + """Connector built from the CSR sparse connection matrix.""" - def __init__(self, indices, inptr, **kwargs): - super(CSRConn, self).__init__(**kwargs) + def __init__(self, indices, inptr, **kwargs): + super(CSRConn, self).__init__(**kwargs) - self.indices = jnp.asarray(indices, dtype=IDX_DTYPE) - self.inptr = jnp.asarray(inptr, dtype=IDX_DTYPE) - self.pre_num = self.inptr.size - 1 - self.max_post = self.indices.max() + self.indices = jnp.asarray(indices, dtype=IDX_DTYPE) + self.inptr = jnp.asarray(inptr, dtype=IDX_DTYPE) + self.pre_num = self.inptr.size - 1 + self.max_post = self.indices.max() - def build_csr(self): - if self.pre_num != self.pre_num: - raise ConnectorError(f'(pre_size, post_size) is inconsistent with ' - f'the shape of the sparse matrix.') - if self.post_num <= self.max_post: - raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' - 
f'the maximum id ({self.max_post}) of self.post_ids.') - return self.indices, self.inptr + def build_csr(self): + if self.pre_num != self.pre_num: + raise ConnectorError(f'(pre_size, post_size) is inconsistent with ' + f'the shape of the sparse matrix.') + if self.post_num <= self.max_post: + raise ConnectorError(f'post_num ({self.post_num}) should be greater than ' + f'the maximum id ({self.max_post}) of self.post_ids.') + return self.indices, self.inptr class SparseMatConn(CSRConn): - """Connector built from the sparse connection matrix""" - - def __init__(self, csr_mat, **kwargs): - try: - from scipy.sparse import csr_matrix - except (ModuleNotFoundError, ImportError): - raise ConnectorError(f'Using SparseMatConn requires the scipy package. ' - f'Please run "pip install scipy" to install scipy.') - - assert isinstance(csr_mat, csr_matrix) - self.csr_mat = csr_mat - super(SparseMatConn, self).__init__(indices=jnp.asarray(self.csr_mat.indices, dtype=IDX_DTYPE), - inptr=jnp.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE), - **kwargs) - self.pre_num = csr_mat.shape[0] - self.post_num = csr_mat.shape[1] + """Connector built from the sparse connection matrix""" + + def __init__(self, csr_mat, **kwargs): + try: + from scipy.sparse import csr_matrix + except (ModuleNotFoundError, ImportError): + raise ConnectorError(f'Using SparseMatConn requires the scipy package. 
' + f'Please run "pip install scipy" to install scipy.') + + assert isinstance(csr_mat, csr_matrix) + self.csr_mat = csr_mat + super(SparseMatConn, self).__init__(indices=jnp.asarray(self.csr_mat.indices, dtype=IDX_DTYPE), + inptr=jnp.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE), + **kwargs) + self.pre_num = csr_mat.shape[0] + self.post_num = csr_mat.shape[1] diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 14c2b474f..0a5ea3199 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -3,640 +3,639 @@ from functools import partial from typing import Optional -from jax import vmap, jit, numpy as jnp import numpy as np +from jax import vmap, jit, numpy as jnp import brainpy.math as bm +from brainpy._src.tools.package import SUPPORT_NUMBA from brainpy.errors import ConnectorError from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed -from brainpy._src.tools.package import SUPPORT_NUMBA from .base import * - __all__ = [ - 'FixedProb', - 'FixedPreNum', - 'FixedPostNum', - 'FixedTotalNum', - 'GaussianProb', - 'ProbDist', - - 'SmallWorld', - 'ScaleFreeBA', - 'ScaleFreeBADual', - 'PowerLaw', + 'FixedProb', + 'FixedPreNum', + 'FixedPostNum', + 'FixedTotalNum', + 'GaussianProb', + 'ProbDist', + + 'SmallWorld', + 'ScaleFreeBA', + 'ScaleFreeBADual', + 'PowerLaw', ] class FixedProb(TwoEndConnector): - """Connect the post-synaptic neurons with fixed probability. - - Parameters:: - - prob: float - The conn probability. - pre_ratio: float - The ratio of pre-synaptic neurons to connect. - include_self : bool - Whether create (i, i) conn? - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - seed : optional, int - Seed the random generator. 
- """ - - def __init__(self, - prob, - pre_ratio=1., - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedProb, self).__init__(**kwargs) - assert 0. <= prob <= 1. - assert 0. <= pre_ratio <= 1. - self.prob = prob - self.pre_ratio = pre_ratio - self.include_self = include_self - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self._jaxrand = bm.random.default_rng(self.seed) - self._nprand = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' - f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' - f'seed={self.seed})') - - def _iii(self): - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - - if self.pre_ratio < 1.: - pre_num_to_select = int(self.pre_num * self.pre_ratio) - pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) - else: - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - - post_num_total = self.post_num - post_num_to_select = int(self.post_num * self.prob) - - if self.allow_multi_conn: - selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self._nprand.randint(0, int(1e8))) - else: - rng = self._nprand - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, 
selected_post_ids, pre_ids = self._iii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) - - def build_mat(self): - if self.pre_ratio < 1.: - pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state - else: - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) - mat = bm.asarray(mat) - if not self.include_self: - bm.fill_diagonal(mat, False) - return mat.astype(MAT_DTYPE) + """Connect the post-synaptic neurons with fixed probability. + Parameters:: -class FixedTotalNum(TwoEndConnector): - """Connect the synaptic neurons with fixed total number. - - Parameters:: - - num : float,int - The conn total number. - allow_multi_conn : bool, optional - Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. - seed: int, optional - The random number seed. 
- """ - - def __init__(self, - num, - allow_multi_conn=False, - seed=None, **kwargs): - super().__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' - elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' - else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) - - def build_coo(self): - mat_element_num = self.pre_num * self.post_num - if self.num > mat_element_num: - raise ConnectorError(f'"num" must be smaller than "all2all num", ' - f'but got {self.num} > {mat_element_num}') - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) - selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) - else: - index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) - selected_pre_ids = index // self.post_num - selected_post_ids = index % self.post_num - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' + prob: float + The conn probability. + pre_ratio: float + The ratio of pre-synaptic neurons to connect. + include_self : bool + Whether create (i, i) conn? + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + .. versionadded:: 2.2.3.2 -class FixedNum(TwoEndConnector): - def __init__(self, - num, - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedNum, self).__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' - elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' 
- else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.include_self = include_self - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' + seed : optional, int + Seed the random generator. + """ + + def __init__(self, + prob, + pre_ratio=1., + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedProb, self).__init__(**kwargs) + assert 0. <= prob <= 1. + assert 0. <= pre_ratio <= 1. + self.prob = prob + self.pre_ratio = pre_ratio + self.include_self = include_self + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self._jaxrand = bm.random.default_rng(self.seed) + self._nprand = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' + f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' + f'seed={self.seed})') + + def _iii(self): + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). 
' + f'But `include_self` is set to True.') + + if self.pre_ratio < 1.: + pre_num_to_select = int(self.pre_num * self.pre_ratio) + pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) + else: + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + + post_num_total = self.post_num + post_num_to_select = int(self.post_num * self.prob) + + if self.allow_multi_conn: + selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self._nprand.randint(0, int(1e8))) + else: + rng = self._nprand + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._iii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + 
selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) + + def build_mat(self): + if self.pre_ratio < 1.: + pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state + else: + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) + mat = bm.asarray(mat) + if not self.include_self: + bm.fill_diagonal(mat, False) + return mat.astype(MAT_DTYPE) -class FixedPreNum(FixedNum): - """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. +class FixedTotalNum(TwoEndConnector): + """Connect the synaptic neurons with fixed total number. + + Parameters:: + + num : float,int + The conn total number. + allow_multi_conn : bool, optional + Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. + seed: int, optional + The random number seed. + """ + + def __init__(self, + num, + allow_multi_conn=False, + seed=None, **kwargs): + super().__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' 
+ else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) + + def build_coo(self): + mat_element_num = self.pre_num * self.post_num + if self.num > mat_element_num: + raise ConnectorError(f'"num" must be smaller than "all2all num", ' + f'but got {self.num} > {mat_element_num}') + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) + selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) + else: + index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) + selected_pre_ids = index // self.post_num + selected_post_ids = index % self.post_num + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' - Parameters:: - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - """ - - def build_coo(self): - if isinstance(self.num, int) and self.num > self.pre_num: - raise ConnectorError(f'"num" must be smaller than "pre_num", ' - f'but got {self.num} > {self.pre_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). 
' - f'But `include_self` is set to True.') - pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num - pre_num_total = self.pre_num - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(post_num_total): - posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) - return posts - - selected_pre_ids = jnp.asarray(single_conn()) - - post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select - if not self.include_self: - true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) - post_nums -= jnp.sum(true_ids, axis=1) - selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_pre_ids = selected_pre_ids.flatten() - selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) +class FixedNum(TwoEndConnector): + def __init__(self, + num, + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedNum, self).__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' 
+ else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.include_self = include_self + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' -class FixedPostNum(FixedNum): - """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. +class FixedPreNum(FixedNum): + """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. + + Parameters:: + + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). + include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def build_coo(self): + if isinstance(self.num, int) and self.num > self.pre_num: + raise ConnectorError(f'"num" must be smaller than "pre_num", ' + f'but got {self.num} > {self.pre_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). 
' + f'But `include_self` is set to True.') + pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num + pre_num_total = self.pre_num + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(post_num_total): + posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) + return posts + + selected_pre_ids = jnp.asarray(single_conn()) + + post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select + if not self.include_self: + true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) + post_nums -= jnp.sum(true_ids, axis=1) + selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_pre_ids = selected_pre_ids.flatten() + selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - Parameters:: - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. 
versionadded:: 2.2.3.2 - - """ - - def _ii(self): - if isinstance(self.num, int) and self.num > self.post_num: - raise ConnectorError(f'"num" must be smaller than "post_num", ' - f'but got {self.num} > {self.post_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, selected_post_ids, pre_ids = self._ii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == 
jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) +class FixedPostNum(FixedNum): + """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. + + Parameters:: + + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). + include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def _ii(self): + if isinstance(self.num, int) and self.num > self.post_num: + raise ConnectorError(f'"num" must be smaller than "post_num", ' + f'but got {self.num} > {self.post_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). 
' + f'But `include_self` is set to True.') + post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._ii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), 
selected_pre_inptr.astype(get_idx_type()) @jit @partial(vmap, in_axes=(0, None, None)) def gaussian_prob_dist_cal1(i_value, post_values, sigma): - dists = jnp.abs(i_value - post_values) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) + dists = jnp.abs(i_value - post_values) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) @jit @partial(vmap, in_axes=(0, None, None, None)) def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): - dists = jnp.abs(i_value - post_values) - dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) + dists = jnp.abs(i_value - post_values) + dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) class GaussianProb(OneEndConnector): - r"""Builds a Gaussian connectivity pattern within a population of neurons, - where the connection probability decay according to the gaussian function. - - Specifically, for any pair of neurons :math:`(i, j)`, - - .. math:: - - p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) - - where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. - - Parameters:: - - sigma : float - Width of the Gaussian function. - encoding_values : optional, list, tuple, int, float - The value ranges to encode for neurons at each axis. - - - If `values` is not provided, the neuron only encodes each positional - information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is - the index in the high-dimensional space. - - If `values` is a single tuple/list of int/float, neurons at each dimension - will encode the same range of values. 
For example, ``values=(0, np.pi)``, - neurons at each dimension will encode a continuous value space ``[0, np.pi]``. - - If `values` is a tuple/list of list/tuple, it means the value space will be - different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. - - periodic_boundary : bool - Whether the neuron encode the value space with the periodic boundary. - normalize : bool - Whether normalize the connection probability . - include_self : bool - Whether create the connection at the same position. - seed : int - The random seed. - """ - - def __init__( - self, - sigma: float, - encoding_values: Optional[np.ndarray] = None, - normalize: bool = True, - include_self: bool = True, - periodic_boundary: bool = False, - seed: int = None, - **kwargs - ): - super(GaussianProb, self).__init__(**kwargs) - self.sigma = sigma - self.encoding_values = encoding_values - self.normalize = normalize - self.include_self = include_self - self.periodic_boundary = periodic_boundary - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(sigma={self.sigma}, ' - f'normalize={self.normalize}, ' - f'periodic_boundary={self.periodic_boundary}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - self.rng = np.random.RandomState(self.seed) - # value range to encode - if self.encoding_values is None: - value_ranges = tuple([(0, s) for s in self.pre_size]) - elif isinstance(self.encoding_values, (tuple, list)): - if len(self.encoding_values) == 0: - raise ConnectorError(f'encoding_values has a length of 0.') - elif isinstance(self.encoding_values[0], (int, float)): - assert len(self.encoding_values) == 2 - assert self.encoding_values[0] < self.encoding_values[1] - value_ranges = tuple([self.encoding_values for _ in self.pre_size]) - elif isinstance(self.encoding_values[0], (tuple, list)): - if 
len(self.encoding_values) != len(self.pre_size): - raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' - f'the encoded values provided only has {len(self.encoding_values)}-D. ' - f'Error in {str(self)}.') - for v in self.encoding_values: - assert isinstance(v[0], (int, float)) - assert len(v) == 2 - value_ranges = tuple(self.encoding_values) - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - - # values - values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] - # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) - post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) - value_sizes = np.array([v[1] - v[0] for v in value_ranges]) - if value_sizes.ndim < post_values.ndim: - value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - - # probability of connections - if isOptimized: - i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) - for i in range(self.pre_num): - list_index = i - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - i_value_list[list_index] = i_value - - if self.periodic_boundary: - prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) - else: - prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) - else: - prob_mat = [] - for i in range(self.pre_num): - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - 
i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - # distances - dists = np.abs(i_value - post_values) - if self.periodic_boundary: - dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) - exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) - prob_mat.append(exp_dists) - prob_mat = np.stack(prob_mat) - - if self.normalize: - prob_mat /= prob_mat.max() - - # connectivity - conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) - if not self.include_self: - np.fill_diagonal(conn_mat, False) - return conn_mat + r"""Builds a Gaussian connectivity pattern within a population of neurons, + where the connection probability decay according to the gaussian function. + + Specifically, for any pair of neurons :math:`(i, j)`, + + .. math:: + + p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) + + where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. + + Parameters:: + + sigma : float + Width of the Gaussian function. + encoding_values : optional, list, tuple, int, float + The value ranges to encode for neurons at each axis. + + - If `values` is not provided, the neuron only encodes each positional + information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is + the index in the high-dimensional space. + - If `values` is a single tuple/list of int/float, neurons at each dimension + will encode the same range of values. For example, ``values=(0, np.pi)``, + neurons at each dimension will encode a continuous value space ``[0, np.pi]``. + - If `values` is a tuple/list of list/tuple, it means the value space will be + different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. + + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. 
+ normalize : bool + Whether normalize the connection probability . + include_self : bool + Whether create the connection at the same position. + seed : int + The random seed. + """ + + def __init__( + self, + sigma: float, + encoding_values: Optional[np.ndarray] = None, + normalize: bool = True, + include_self: bool = True, + periodic_boundary: bool = False, + seed: int = None, + **kwargs + ): + super(GaussianProb, self).__init__(**kwargs) + self.sigma = sigma + self.encoding_values = encoding_values + self.normalize = normalize + self.include_self = include_self + self.periodic_boundary = periodic_boundary + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(sigma={self.sigma}, ' + f'normalize={self.normalize}, ' + f'periodic_boundary={self.periodic_boundary}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + self.rng = np.random.RandomState(self.seed) + # value range to encode + if self.encoding_values is None: + value_ranges = tuple([(0, s) for s in self.pre_size]) + elif isinstance(self.encoding_values, (tuple, list)): + if len(self.encoding_values) == 0: + raise ConnectorError(f'encoding_values has a length of 0.') + elif isinstance(self.encoding_values[0], (int, float)): + assert len(self.encoding_values) == 2 + assert self.encoding_values[0] < self.encoding_values[1] + value_ranges = tuple([self.encoding_values for _ in self.pre_size]) + elif isinstance(self.encoding_values[0], (tuple, list)): + if len(self.encoding_values) != len(self.pre_size): + raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' + f'the encoded values provided only has {len(self.encoding_values)}-D. 
' + f'Error in {str(self)}.') + for v in self.encoding_values: + assert isinstance(v[0], (int, float)) + assert len(v) == 2 + value_ranges = tuple(self.encoding_values) + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + + # values + values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] + # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) + post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) + value_sizes = np.array([v[1] - v[0] for v in value_ranges]) + if value_sizes.ndim < post_values.ndim: + value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + + # probability of connections + if isOptimized: + i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) + for i in range(self.pre_num): + list_index = i + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + i_value_list[list_index] = i_value + + if self.periodic_boundary: + prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) + else: + prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) + else: + prob_mat = [] + for i in range(self.pre_num): + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + # 
distances + dists = np.abs(i_value - post_values) + if self.periodic_boundary: + dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) + exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) + prob_mat.append(exp_dists) + prob_mat = np.stack(prob_mat) + + if self.normalize: + prob_mat /= prob_mat.max() + + # connectivity + conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) + if not self.include_self: + np.fill_diagonal(conn_mat, False) + return conn_mat class SmallWorld(TwoEndConnector): - r"""Build a Watts–Strogatz small-world graph. - - Parameters:: - - num_neighbor : int - Each node is joined with its `k` nearest neighbors in a ring - topology. - prob : float - The probability of rewiring each edge - directed : bool - Whether the graph is a directed graph. - include_self : bool - Whether include the node self. - - Notes:: - - First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is - joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors - if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as - follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with - :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new - edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. - - References:: - - .. [1] Duncan J. Watts and Steven H. Strogatz, - Collective dynamics of small-world networks, - Nature, 393, pp. 440--442, 1998. 
- """ - - def __init__( - self, - num_neighbor, - prob, - directed=False, - include_self=False, - seed=None, - **kwargs - ): - super(SmallWorld, self).__init__(**kwargs) - self.prob = prob - self.directed = directed - self.num_neighbor = num_neighbor - self.include_self = include_self - - self.seed = format_seed(seed) - self.rng = np.random.RandomState(seed=self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _smallworld_rewire(i, all_j): - if rng.random(1) < prob: - non_connected = np.where(np.logical_not(all_j))[0] - if len(non_connected) <= 1: - return -1 - # Enforce no self-loops or multiple edges - w = rng.choice(non_connected) - while (not include_self) and w == i: - # non_connected.remove(w) - w = rng.choice(non_connected) - return w - else: - return -1 - - self._connect = numba_jit(_smallworld_rewire) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, ' - f'directed={self.directed}, ' - f'num_neighbor={self.num_neighbor}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_conn(self): - assert self.pre_size == self.post_size - - # seed - self.seed = self.rng.randint(1, int(1e7)) - numba_seed(self.seed) - - if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): - num_node = self.pre_num - - if self.num_neighbor > num_node: - raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") - # If k == n, the graph is complete not Watts-Strogatz - if self.num_neighbor == num_node: - conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) - else: - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 - # connect each node to k/2 neighbors - for j in range(1, self.num_neighbor // 2 + 1): - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - conn[nodes, targets] = True - conn[targets, nodes] = True - - 
# rewire edges from each node - # loop over all nodes in order (label) and neighbors in order (distance) - # no self loops or multiple edges allowed - for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - if self.directed: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(prob=self.prob, i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[u, w] = True - w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) - if w != -1: - conn[v, u] = False - conn[w, u] = True - else: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[v, u] = False - conn[u, w] = True - conn[w, u] = True - # conn = np.asarray(conn, dtype=MAT_DTYPE) - else: - raise ConnectorError('Currently only support 1D ring connection.') - - return 'mat', conn + r"""Build a Watts–Strogatz small-world graph. + + Parameters:: + + num_neighbor : int + Each node is joined with its `k` nearest neighbors in a ring + topology. + prob : float + The probability of rewiring each edge + directed : bool + Whether the graph is a directed graph. + include_self : bool + Whether include the node self. + + Notes:: + + First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is + joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors + if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as + follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with + :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new + edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. + + References:: + + .. [1] Duncan J. Watts and Steven H. Strogatz, + Collective dynamics of small-world networks, + Nature, 393, pp. 
440--442, 1998. + """ + + def __init__( + self, + num_neighbor, + prob, + directed=False, + include_self=False, + seed=None, + **kwargs + ): + super(SmallWorld, self).__init__(**kwargs) + self.prob = prob + self.directed = directed + self.num_neighbor = num_neighbor + self.include_self = include_self + + self.seed = format_seed(seed) + self.rng = np.random.RandomState(seed=self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _smallworld_rewire(i, all_j): + if rng.random(1) < prob: + non_connected = np.where(np.logical_not(all_j))[0] + if len(non_connected) <= 1: + return -1 + # Enforce no self-loops or multiple edges + w = rng.choice(non_connected) + while (not include_self) and w == i: + # non_connected.remove(w) + w = rng.choice(non_connected) + return w + else: + return -1 + + self._connect = numba_jit(_smallworld_rewire) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, ' + f'directed={self.directed}, ' + f'num_neighbor={self.num_neighbor}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_conn(self): + assert self.pre_size == self.post_size + + # seed + self.seed = self.rng.randint(1, int(1e7)) + numba_seed(self.seed) + + if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): + num_node = self.pre_num + + if self.num_neighbor > num_node: + raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") + # If k == n, the graph is complete not Watts-Strogatz + if self.num_neighbor == num_node: + conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) + else: + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 + # connect each node to k/2 neighbors + for j in range(1, self.num_neighbor // 2 + 1): + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list + conn[nodes, targets] = True + conn[targets, 
nodes] = True + + # rewire edges from each node + # loop over all nodes in order (label) and neighbors in order (distance) + # no self loops or multiple edges allowed + for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list + if self.directed: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(prob=self.prob, i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[u, w] = True + w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) + if w != -1: + conn[v, u] = False + conn[w, u] = True + else: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[v, u] = False + conn[u, w] = True + conn[w, u] = True + # conn = np.asarray(conn, dtype=MAT_DTYPE) + else: + raise ConnectorError('Currently only support 1D ring connection.') + + return 'mat', conn # def _random_subset(seq, m, rng): @@ -655,591 +654,591 @@ def build_conn(self): class ScaleFreeBA(TwoEndConnector): - r"""Build a random graph according to the Barabási–Albert preferential - attachment model. - - A graph of :math:`num\_node` nodes is grown by attaching new nodes each with - :math:`m` edges that are preferentially attached to existing nodes - with high degree. - - Parameters:: - - m : int - Number of edges to attach from a new node to existing nodes - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises:: - - ConnectorError - If `m` does not satisfy ``1 <= m < n``. - - References:: - - .. [1] A. L. Barabási and R. Albert "Emergence of scaling in - random networks", Science 286, pp 509-512, 1999. 
- """ - - def __init__(self, m, directed=False, seed=None, **kwargs): - super(ScaleFreeBA, self).__init__(**kwargs) - self.m = m - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, ' - f'directed={self.directed}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m < 1 or self.m >= num_node: - raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " - f"m < n, while m = {self.m} and n = {num_node}") - - # Add m initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - # Target nodes for new edges - targets = list(range(self.m)) - # List of existing nodes, with nodes repeated once for each adjacent edge - - if not isOptimized: - repeated_nodes = [] - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. 
- repeated_nodes.extend([source] * self.m) - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), self.m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) - size_repeated_nodes = 0 - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets - size_repeated_nodes += self.m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source - size_repeated_nodes += self.m - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) - source += 1 - - return conn + r"""Build a random graph according to the Barabási–Albert preferential + attachment model. + + A graph of :math:`num\_node` nodes is grown by attaching new nodes each with + :math:`m` edges that are preferentially attached to existing nodes + with high degree. + + Parameters:: + + m : int + Number of edges to attach from a new node to existing nodes + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Raises:: + + ConnectorError + If `m` does not satisfy ``1 <= m < n``. + + References:: + + .. [1] A. L. Barabási and R. Albert "Emergence of scaling in + random networks", Science 286, pp 509-512, 1999. 
+ """ + + def __init__(self, m, directed=False, seed=None, **kwargs): + super(ScaleFreeBA, self).__init__(**kwargs) + self.m = m + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, ' + f'directed={self.directed}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m < 1 or self.m >= num_node: + raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " + f"m < n, while m = {self.m} and n = {num_node}") + + # Add m initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + # Target nodes for new edges + targets = list(range(self.m)) + # List of existing nodes, with nodes repeated once for each adjacent edge + + if not isOptimized: + repeated_nodes = [] + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. 
+ repeated_nodes.extend([source] * self.m) + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), self.m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) + size_repeated_nodes = 0 + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets + size_repeated_nodes += self.m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source + size_repeated_nodes += self.m + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) + source += 1 + + return conn class ScaleFreeBADual(TwoEndConnector): - r"""Build a random graph according to the dual Barabási–Albert preferential - attachment model. - - A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ - edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that - are preferentially attached to existing nodes with high degree. 
- - Parameters:: - - m1 : int - Number of edges to attach from a new node to existing nodes with probability :math:`p` - m2 : int - Number of edges to attach from a new node to existing nodes with probability :math:`1-p` - p : float - The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises:: - - ConnectorError - If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. - - References:: - - .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. - """ - - def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): - super(ScaleFreeBADual, self).__init__(**kwargs) - self.m1 = m1 - self.m2 = m2 - self.p = p - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' - f'p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m1 < 1 or self.m1 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " - f"while m1 = {self.m1} and num_node = {num_node}.") - if self.m2 < 1 or self.m2 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " - f"while m2 = {self.m2} and num_node = {num_node}.") - if self.p < 0 or self.p > 1: - raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = {self.p}") - - # Add 
max(m1,m2) initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - - if not isOptimized: - # List of existing nodes, with nodes repeated once for each adjacent edge - repeated_nodes = [] - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. - repeated_nodes.extend([source] * m) - # Pick which m to use next time (m1 or m2) + r"""Build a random graph according to the dual Barabási–Albert preferential + attachment model. + + A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ + edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that + are preferentially attached to existing nodes with high degree. + + Parameters:: + + m1 : int + Number of edges to attach from a new node to existing nodes with probability :math:`p` + m2 : int + Number of edges to attach from a new node to existing nodes with probability :math:`1-p` + p : float + The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Raises:: + + ConnectorError + If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. + + References:: + + .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. 
+ """ + + def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): + super(ScaleFreeBADual, self).__init__(**kwargs) + self.m1 = m1 + self.m2 = m2 + self.p = p + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' + f'p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m1 < 1 or self.m1 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " + f"while m1 = {self.m1} and num_node = {num_node}.") + if self.m2 < 1 or self.m2 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " + f"while m2 = {self.m2} and num_node = {num_node}.") + if self.p < 0 or self.p > 1: + raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = {self.p}") + + # Add max(m1,m2) initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + + if not isOptimized: + # List of existing nodes, with nodes repeated once for each adjacent edge + repeated_nodes = [] + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. 
+ origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. + repeated_nodes.extend([source] * m) + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) + size_repeated_nodes = 0 + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) - size_repeated_nodes = 0 - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. 
- repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets - size_repeated_nodes += m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source - size_repeated_nodes += m - # Pick which m to use next time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) - source += 1 - - return conn + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets + size_repeated_nodes += m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source + size_repeated_nodes += m + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) + source += 1 + + return conn class PowerLaw(TwoEndConnector): - """Holme and Kim algorithm for growing graphs with powerlaw - degree distribution and approximate average clustering. - - Parameters:: - - m : int - the number of random edges to add for each new node - p : float, - Probability of adding a triangle after adding a random edge - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Notes:: - - The average clustering has a hard time getting above a certain - cutoff that depends on :math:`m`. 
This cutoff is often quite low. The - transitivity (fraction of triangles to possible triangles) seems to - decrease with network size. - - It is essentially the Barabási–Albert (BA) growth model with an - extra step that each random edge is followed by a chance of - making an edge to one of its neighbors too (and thus a triangle). - - This algorithm improves on BA in the sense that it enables a - higher average clustering to be attained if desired. - - It seems possible to have a disconnected graph with this algorithm - since the initial :math:`m` nodes may not be all linked to a new node - on the first iteration like the BA model. - - Raises:: - - ConnectorError - If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not - satisfy :math:`0 <= p <= 1`. - - References:: - - .. [1] P. Holme and B. J. Kim, - "Growing scale-free networks with tunable clustering", - Phys. Rev. E, 65, 026107, 2002. - """ - - def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): - super(PowerLaw, self).__init__(**kwargs) - self.m = m - self.p = p - if self.p > 1 or self.p < 0: - raise ConnectorError(f"p must be in [0,1], while p={self.p}") - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - num_node = self.pre_num - if self.m < 1 or num_node < self.m: - raise ConnectorError(f"Must have m>1 and m 1 or self.p < 0: + raise ConnectorError(f"p must be in [0,1], while p={self.p}") + 
self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + num_node = self.pre_num + if self.m < 1 or num_node < self.m: + raise ConnectorError(f"Must have m>1 and m 1 else p.flatten() for p in pre_ids]) - size = np.prod(pre_size) - - for i in range(size): - pre_pos = np.asarray([p[i] for p in pre_ids]) - pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) - connected_pres.extend(pres) - connected_posts.extend(posts) - return np.asarray(connected_pres), np.asarray(connected_posts) + """Connection with a maximum distance under a probability `p`. + + .. versionadded:: 2.1.13 + + Parameters:: + + dist: float, int + The maximum distance between two points. + prob: float + The connection probability, within 0. and 1. + pre_ratio: float + The ratio of pre-synaptic neurons to connect. + seed: optional, int + The random seed. + include_self: bool + Whether include the point at the same position. 
+ + """ + + def __init__(self, dist=1, prob=1., pre_ratio=1., seed=None, include_self=True, **kwargs): + super(ProbDist, self).__init__(**kwargs) + + self.prob = prob + self.pre_ratio = pre_ratio + self.dist = dist + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + self.include_self = include_self + + rng = np.random if SUPPORT_NUMBA else self.rng + + @numba_jit + def _connect_1d_jit(pre_pos, pre_size, post_size, n_dim): + all_post_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) + all_pre_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) + size = 0 + + if rng.random() < pre_ratio: + normalized_pos = np.zeros(n_dim) + for i in range(n_dim): + pre_len = pre_size[i] + post_len = post_size[i] + normalized_pos[i] = pre_pos[i] * post_len / pre_len + for i in range(post_size[0]): + post_pos = np.asarray((i,)) + d = np.abs(pre_pos[0] - post_pos[0]) + if d <= dist: + if d == 0. and not include_self: + continue + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] + + @numba_jit + def _connect_2d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 + + if rng.random() < pre_ratio: + normalized_pos = np.zeros(n_dim) + for i in range(n_dim): + pre_len = pre_size[i] + post_len = post_size[i] + normalized_pos[i] = pre_pos[i] * post_len / pre_len + for i in range(post_size[0]): + for j in range(post_size[1]): + post_pos = np.asarray((i, j)) + d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) + if d <= dist: + if d == 0. 
and not include_self: + continue + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays + + @numba_jit + def _connect_3d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] * post_size[2] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 + + if rng.random() < pre_ratio: + normalized_pos = np.zeros(n_dim) + for i in range(n_dim): + pre_len = pre_size[i] + post_len = post_size[i] + normalized_pos[i] = pre_pos[i] * post_len / pre_len + for i in range(post_size[0]): + for j in range(post_size[1]): + for k in range(post_size[2]): + post_pos = np.asarray((i, j, k)) + d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) + if d <= dist: + if d == 0. and not include_self: + continue + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] + + @numba_jit + def _connect_4d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] * post_size[2] * post_size[3] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 + + if rng.random() < pre_ratio: + normalized_pos = np.zeros(n_dim) + for i in range(n_dim): + pre_len = pre_size[i] + post_len = post_size[i] + normalized_pos[i] = pre_pos[i] * post_len / pre_len + for i in range(post_size[0]): + for j in range(post_size[1]): + for k in range(post_size[2]): + for l in range(post_size[3]): + post_pos = np.asarray((i, j, k, l)) + d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) + if d <= dist: + if d == 0. 
and not include_self: + continue + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] + + self._connect_1d_jit = _connect_1d_jit + self._connect_2d_jit = _connect_2d_jit + self._connect_3d_jit = _connect_3d_jit + self._connect_4d_jit = _connect_4d_jit + + def build_coo(self, isOptimized=True): + if len(self.pre_size) != len(self.post_size): + raise ValueError('The dimensions of shapes of two objects to establish connections should ' + f'be the same. But we got dimension {len(self.pre_size)} != {len(self.post_size)}. ' + f'Specifically, pre size = {self.pre_size}, post size = {self.post_size}') + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + # connections + n_dim = len(self.pre_size) + if n_dim == 1: + f = self._connect_1d_jit + elif n_dim == 2: + f = self._connect_2d_jit + elif n_dim == 3: + f = self._connect_3d_jit + elif n_dim == 4: + f = self._connect_4d_jit + else: + raise NotImplementedError('Does not support the network dimension bigger than 4.') + + pre_size = np.asarray(self.pre_size) + post_size = np.asarray(self.post_size) + connected_pres = [] + connected_posts = [] + pre_ids = np.meshgrid(*(np.arange(p) for p in self.pre_size), indexing='ij') + pre_ids = tuple([(np.moveaxis(p, 0, 1).flatten()) if p.ndim > 1 else p.flatten() for p in pre_ids]) + size = np.prod(pre_size) + + for i in range(size): + pre_pos = np.asarray([p[i] for p in pre_ids]) + pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) + connected_pres.extend(pres) + connected_posts.extend(posts) + return np.asarray(connected_pres), np.asarray(connected_posts) diff --git a/brainpy/_src/connect/regular_conn.py b/brainpy/_src/connect/regular_conn.py index a038d373c..7ff5a115f 100644 --- a/brainpy/_src/connect/regular_conn.py +++ b/brainpy/_src/connect/regular_conn.py @@ -10,291 +10,291 @@ from .base import * 
__all__ = [ - 'One2One', 'one2one', - 'All2All', 'all2all', - 'GridFour', 'grid_four', - 'GridEight', 'grid_eight', - 'GridN', + 'One2One', 'one2one', + 'All2All', 'all2all', + 'GridFour', 'grid_four', + 'GridEight', 'grid_eight', + 'GridN', ] class One2One(TwoEndConnector): - """Connect two neuron groups one by one. This means - The two neuron groups should have the same size. - """ - - def __init__(self, *args, **kwargs): - super(One2One, self).__init__(*args, **kwargs) - - def __call__(self, pre_size, post_size): - super(One2One, self).__call__(pre_size, post_size) - try: - assert self.pre_num == self.post_num - except AssertionError: - raise ConnectorError(f'One2One connection must be defined in two groups with the ' - f'same size, but {self.pre_num} != {self.post_num}.') - return self - - def build_coo(self): - if self.pre_num != self.post_num: - raise ConnectorError(f'One2One connection must be defined in two groups with the ' - f'same size, but {self.pre_num} != {self.post_num}.') - return np.arange(self.pre_num, dtype=get_idx_type()), np.arange(self.post_num, dtype=get_idx_type()), - - def build_csr(self): - if self.pre_num != self.post_num: - raise ConnectorError(f'One2One connection must be defined in two groups with the ' - f'same size, but {self.pre_num} != {self.post_num}.') - ind = np.arange(self.pre_num) - indptr = np.arange(self.pre_num + 1) - return (np.asarray(ind, dtype=get_idx_type()), np.asarray(indptr, dtype=get_idx_type())) - - def build_mat(self): - if self.pre_num != self.post_num: - raise ConnectorError(f'One2One connection must be defined in two groups with the ' - f'same size, but {self.pre_num} != {self.post_num}.') - mat = np.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE) - np.fill_diagonal(mat, True) - return mat + """Connect two neuron groups one by one. This means + The two neuron groups should have the same size. 
+ """ + + def __init__(self, *args, **kwargs): + super(One2One, self).__init__(*args, **kwargs) + + def __call__(self, pre_size, post_size): + super(One2One, self).__call__(pre_size, post_size) + try: + assert self.pre_num == self.post_num + except AssertionError: + raise ConnectorError(f'One2One connection must be defined in two groups with the ' + f'same size, but {self.pre_num} != {self.post_num}.') + return self + + def build_coo(self): + if self.pre_num != self.post_num: + raise ConnectorError(f'One2One connection must be defined in two groups with the ' + f'same size, but {self.pre_num} != {self.post_num}.') + return np.arange(self.pre_num, dtype=get_idx_type()), np.arange(self.post_num, dtype=get_idx_type()), + + def build_csr(self): + if self.pre_num != self.post_num: + raise ConnectorError(f'One2One connection must be defined in two groups with the ' + f'same size, but {self.pre_num} != {self.post_num}.') + ind = np.arange(self.pre_num) + indptr = np.arange(self.pre_num + 1) + return (np.asarray(ind, dtype=get_idx_type()), np.asarray(indptr, dtype=get_idx_type())) + + def build_mat(self): + if self.pre_num != self.post_num: + raise ConnectorError(f'One2One connection must be defined in two groups with the ' + f'same size, but {self.pre_num} != {self.post_num}.') + mat = np.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE) + np.fill_diagonal(mat, True) + return mat one2one = One2One() class All2All(TwoEndConnector): - """Connect each neuron in first group to all neurons in the - post-synaptic neuron groups. It means this kind of conn - will create (num_pre x num_post) synapses. - """ + """Connect each neuron in first group to all neurons in the + post-synaptic neuron groups. It means this kind of conn + will create (num_pre x num_post) synapses. 
+ """ - def __init__(self, *args, include_self: bool = True, **kwargs): - self.include_self = include_self - super(All2All, self).__init__(*args, **kwargs) + def __init__(self, *args, include_self: bool = True, **kwargs): + self.include_self = include_self + super(All2All, self).__init__(*args, **kwargs) - def __repr__(self): - return f'{self.__class__.__name__}(include_self={self.include_self})' + def __repr__(self): + return f'{self.__class__.__name__}(include_self={self.include_self})' - def build_mat(self): - mat = np.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE) - if not self.include_self: - np.fill_diagonal(mat, False) - return mat + def build_mat(self): + mat = np.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE) + if not self.include_self: + np.fill_diagonal(mat, False) + return mat all2all = All2All(include_self=True) def get_size_length(sizes: Union[Tuple, List]): - if not isinstance(sizes, (tuple, list)): - raise TypeError - lengths = [] - a = 1 - for s in reversed(sizes): - lengths.insert(0, a) - a *= s - return np.asarray(lengths) + if not isinstance(sizes, (tuple, list)): + raise TypeError + lengths = [] + a = 1 + for s in reversed(sizes): + lengths.insert(0, a) + a *= s + return np.asarray(lengths) class GridConn(OneEndConnector): - def __init__( - self, - strides, - include_self: bool = False, - periodic_boundary: bool = False, - **kwargs - ): - super(GridConn, self).__init__(**kwargs) - self.strides = strides - self.include_self = include_self - self.periodic_boundary = periodic_boundary - - def __repr__(self): - return f'{self.__class__.__name__}(include_self={self.include_self}, periodic_boundary={self.periodic_boundary})' - - def _format(self): - dim = len(self.post_size) - if self.pre_num != self.post_num: - raise ConnectorError(f'{self.__class__.__name__} is used to for connection within ' - f'a same population. 
But we detect pre_num != post_num ' - f'({self.pre_num} != {self.post_num}).') - # point indices - indices = jnp.meshgrid(*(jnp.arange(size) for size in self.post_size), indexing='ij') - indices = jnp.asarray(indices) - indices = indices.reshape(dim, self.post_num).T - lengths = jnp.asarray(self.post_size) - return lengths, dim, indices - - def _get_strides(self, dim): - # increments - increments = np.asarray(np.meshgrid(*(self.strides for _ in range(dim)))).reshape(dim, -1).T - select_ids = self._select_stride(increments) - increments = jnp.asarray(increments[select_ids]) - return increments - - def _select_stride(self, stride: np.ndarray) -> np.ndarray: - raise NotImplementedError - - def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: - raise NotImplementedError - - def build_mat(self): - sizes, _, indices = self._format() - - @jax.vmap - def f_connect(pre_id): - # pre_id: R^(num_dim) - dist = jnp.abs(pre_id - indices) - if self.periodic_boundary: - dist = jnp.where(dist > sizes / 2, sizes - dist, dist) - return self._select_dist(dist) - - return jnp.asarray(f_connect(indices), dtype=MAT_DTYPE) - - def build_coo(self): - sizes, dim, indices = self._format() - strides = self._get_strides(dim) - - @jax.vmap - def f_connect(pre_id): - # pre_id: R^(num_dim) - post_ids = pre_id + strides - if self.periodic_boundary: - post_ids = post_ids % sizes - else: - post_ids = jnp.where(post_ids < sizes, post_ids, -1) - size = len(post_ids) - pre_ids = jnp.repeat(pre_id, size).reshape(dim, size).T - return pre_ids, post_ids - - pres, posts = f_connect(indices) - pres = pres.reshape(-1, dim) - posts = posts.reshape(-1, dim) - idx = jnp.nonzero(jnp.all(posts >= 0, axis=1))[0] - pres = pres[idx] - posts = posts[idx] - if dim == 1: - pres = pres.flatten() - posts = posts.flatten() - else: - strides = jnp.asarray(get_size_length(self.post_size)) - pres = jnp.sum(pres * strides, axis=1) - posts = jnp.sum(posts * strides, axis=1) - return jnp.asarray(pres, dtype=get_idx_type()), 
jnp.asarray(posts, dtype=get_idx_type()) + def __init__( + self, + strides, + include_self: bool = False, + periodic_boundary: bool = False, + **kwargs + ): + super(GridConn, self).__init__(**kwargs) + self.strides = strides + self.include_self = include_self + self.periodic_boundary = periodic_boundary + + def __repr__(self): + return f'{self.__class__.__name__}(include_self={self.include_self}, periodic_boundary={self.periodic_boundary})' + + def _format(self): + dim = len(self.post_size) + if self.pre_num != self.post_num: + raise ConnectorError(f'{self.__class__.__name__} is used to for connection within ' + f'a same population. But we detect pre_num != post_num ' + f'({self.pre_num} != {self.post_num}).') + # point indices + indices = jnp.meshgrid(*(jnp.arange(size) for size in self.post_size), indexing='ij') + indices = jnp.asarray(indices) + indices = indices.reshape(dim, self.post_num).T + lengths = jnp.asarray(self.post_size) + return lengths, dim, indices + + def _get_strides(self, dim): + # increments + increments = np.asarray(np.meshgrid(*(self.strides for _ in range(dim)))).reshape(dim, -1).T + select_ids = self._select_stride(increments) + increments = jnp.asarray(increments[select_ids]) + return increments + + def _select_stride(self, stride: np.ndarray) -> np.ndarray: + raise NotImplementedError + + def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: + raise NotImplementedError + + def build_mat(self): + sizes, _, indices = self._format() + + @jax.vmap + def f_connect(pre_id): + # pre_id: R^(num_dim) + dist = jnp.abs(pre_id - indices) + if self.periodic_boundary: + dist = jnp.where(dist > sizes / 2, sizes - dist, dist) + return self._select_dist(dist) + + return jnp.asarray(f_connect(indices), dtype=MAT_DTYPE) + + def build_coo(self): + sizes, dim, indices = self._format() + strides = self._get_strides(dim) + + @jax.vmap + def f_connect(pre_id): + # pre_id: R^(num_dim) + post_ids = pre_id + strides + if self.periodic_boundary: + post_ids = 
post_ids % sizes + else: + post_ids = jnp.where(post_ids < sizes, post_ids, -1) + size = len(post_ids) + pre_ids = jnp.repeat(pre_id, size).reshape(dim, size).T + return pre_ids, post_ids + + pres, posts = f_connect(indices) + pres = pres.reshape(-1, dim) + posts = posts.reshape(-1, dim) + idx = jnp.nonzero(jnp.all(posts >= 0, axis=1))[0] + pres = pres[idx] + posts = posts[idx] + if dim == 1: + pres = pres.flatten() + posts = posts.flatten() + else: + strides = jnp.asarray(get_size_length(self.post_size)) + pres = jnp.sum(pres * strides, axis=1) + posts = jnp.sum(posts * strides, axis=1) + return jnp.asarray(pres, dtype=get_idx_type()), jnp.asarray(posts, dtype=get_idx_type()) class GridFour(GridConn): - """The nearest four neighbors connection method. + """The nearest four neighbors connection method. - Parameters:: + Parameters:: - periodic_boundary : bool - Whether the neuron encode the value space with the periodic boundary. - .. versionadded:: 2.2.3.2 + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. + .. versionadded:: 2.2.3.2 - include_self : bool - Whether create connection at the same position. - """ + include_self : bool + Whether create connection at the same position. 
+ """ - def __init__( - self, - include_self: bool = False, - periodic_boundary: bool = False, - **kwargs - ): - super(GridFour, self).__init__(strides=np.asarray([-1, 0, 1]), - include_self=include_self, - periodic_boundary=periodic_boundary, - **kwargs) - self.include_self = include_self - self.periodic_boundary = periodic_boundary + def __init__( + self, + include_self: bool = False, + periodic_boundary: bool = False, + **kwargs + ): + super(GridFour, self).__init__(strides=np.asarray([-1, 0, 1]), + include_self=include_self, + periodic_boundary=periodic_boundary, + **kwargs) + self.include_self = include_self + self.periodic_boundary = periodic_boundary - def _select_stride(self, stride: np.ndarray) -> np.ndarray: - temp = abs(stride).sum(axis=1) - return (temp <= 1) if self.include_self else (temp == 1) + def _select_stride(self, stride: np.ndarray) -> np.ndarray: + temp = abs(stride).sum(axis=1) + return (temp <= 1) if self.include_self else (temp == 1) - def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: - dist = jnp.linalg.norm(dist, axis=1) - return dist <= 1 if self.include_self else dist == 1 + def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: + dist = jnp.linalg.norm(dist, axis=1) + return dist <= 1 if self.include_self else dist == 1 grid_four = GridFour() class GridN(GridConn): - """The nearest (2*N+1) * (2*N+1) neighbors conn method. - - Parameters:: - - N : int - Extend of the conn scope. For example: - When N=1, - [x x x] - [x I x] - [x x x] - When N=2, - [x x x x x] - [x x x x x] - [x x I x x] - [x x x x x] - [x x x x x] - include_self : bool - Whether create (i, i) conn ? - periodic_boundary: bool - Whether the neuron encode the value space with the periodic boundary. - .. 
versionadded:: 2.2.3.2 - """ - - def __init__( - self, - N: int = 1, - include_self: bool = False, - periodic_boundary: bool = False, - **kwargs - ): - super(GridN, self).__init__(strides=np.arange(-N, N + 1, 1), - include_self=include_self, - periodic_boundary=periodic_boundary, - **kwargs) - self.N = N - - def __repr__(self): - return (f'{self.__class__.__name__}(N={self.N}, ' - f'include_self={self.include_self}, ' - f'periodic_boundary={self.periodic_boundary})') - - def _select_stride(self, stride: np.ndarray) -> np.ndarray: - return (np.ones(len(stride), dtype=bool) - if self.include_self else - (np.sum(np.abs(stride), axis=1) > 0)) - - def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: - if self.include_self: - return jnp.all(dist <= self.N, axis=1) - else: - return jnp.logical_and(jnp.all(dist <= self.N, axis=1), - jnp.logical_not(jnp.all(dist == 0, axis=1))) - - -class GridEight(GridN): - """The nearest eight neighbors conn method. - - Parameters:: - - include_self : bool - Whether create (i, i) conn ? - periodic_boundary: bool - Whether the neurons encode the value space with the periodic boundary. - .. versionadded:: 2.2.3.2 - """ - - def __init__(self, - include_self: bool = False, - periodic_boundary: bool = False, - **kwargs): - super(GridEight, self).__init__(N=1, + """The nearest (2*N+1) * (2*N+1) neighbors conn method. + + Parameters:: + + N : int + Extend of the conn scope. For example: + When N=1, + [x x x] + [x I x] + [x x x] + When N=2, + [x x x x x] + [x x x x x] + [x x I x x] + [x x x x x] + [x x x x x] + include_self : bool + Whether create (i, i) conn ? + periodic_boundary: bool + Whether the neuron encode the value space with the periodic boundary. + .. 
versionadded:: 2.2.3.2 + """ + + def __init__( + self, + N: int = 1, + include_self: bool = False, + periodic_boundary: bool = False, + **kwargs + ): + super(GridN, self).__init__(strides=np.arange(-N, N + 1, 1), include_self=include_self, periodic_boundary=periodic_boundary, **kwargs) + self.N = N + + def __repr__(self): + return (f'{self.__class__.__name__}(N={self.N}, ' + f'include_self={self.include_self}, ' + f'periodic_boundary={self.periodic_boundary})') + + def _select_stride(self, stride: np.ndarray) -> np.ndarray: + return (np.ones(len(stride), dtype=bool) + if self.include_self else + (np.sum(np.abs(stride), axis=1) > 0)) + + def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: + if self.include_self: + return jnp.all(dist <= self.N, axis=1) + else: + return jnp.logical_and(jnp.all(dist <= self.N, axis=1), + jnp.logical_not(jnp.all(dist == 0, axis=1))) + + +class GridEight(GridN): + """The nearest eight neighbors conn method. + + Parameters:: + + include_self : bool + Whether create (i, i) conn ? + periodic_boundary: bool + Whether the neurons encode the value space with the periodic boundary. + .. 
versionadded:: 2.2.3.2 + """ + + def __init__(self, + include_self: bool = False, + periodic_boundary: bool = False, + **kwargs): + super(GridEight, self).__init__(N=1, + include_self=include_self, + periodic_boundary=periodic_boundary, + **kwargs) grid_eight = GridEight() diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py index 07422f65e..6cc5daaa5 100644 --- a/brainpy/_src/connect/tests/test_all_time.py +++ b/brainpy/_src/connect/tests/test_all_time.py @@ -8,12 +8,12 @@ pytest.skip('skip.', allow_module_level=True) try: - import pandas as pd + import pandas as pd - df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size', - 'build function', 'other parameter', 'time(ms)']) + df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size', + 'build function', 'other parameter', 'time(ms)']) except (ImportError, ModuleNotFoundError): - print('No pandas installed, skip test.') + print('No pandas installed, skip test.') # size_same = [100, 500, 2500, 12500, 25000, 37500, 50000] # size_same = [100, 500, 2500, 12500] @@ -24,643 +24,643 @@ def get_ms(value): - return round(value * 1000, 4) + return round(value * 1000, 4) def insert_row(connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used): - try: - df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + try: + df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') class OneEndConnector(unittest.TestCase): - def test_gaussian_prob(self): - print() - for size in size_same: - print('GaussianProb:', size) - conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) - - start = time.time() - 
conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'sigma=1/include_self=False', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['GaussianProb', - # 'OneEndConnector', - # f'{size}x{size}', - # 'build_coo', - # 'sigma=1/include_self=False', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'sigma=1/include_self=False', - time_used) - - def test_grid_four(self): - print() - for size in size_same: - print('GridFour:', size) - conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_eight(self): - print() - for size in size_same: - print('GridEight:', size) - conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', 
- time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_n(self): - print() - for size in size_same: - print('GridN:', size) - conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False/N=2', - time_used) + def test_gaussian_prob(self): + print() + for size in size_same: + print('GaussianProb:', size) + conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'sigma=1/include_self=False', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['GaussianProb', + # 'OneEndConnector', + # f'{size}x{size}', + 
# 'build_coo', + # 'sigma=1/include_self=False', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'sigma=1/include_self=False', + time_used) + + def test_grid_four(self): + print() + for size in size_same: + print('GridFour:', size) + conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_eight(self): + print() + for size in size_same: + print('GridEight:', size) + conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 
'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_n(self): + print() + for size in size_same: + print('GridN:', size) + conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False/N=2', + time_used) class TwoEndConnector(unittest.TestCase): - def test_fixed_prob(self): - print() - for size in size_same: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'prob=0.1', - time_used) - - for size in size_diff: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size[0], 
post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'prob=0.1', - time_used) - - def test_fixed_pre_num(self): - print() - for size in size_same: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'pre_num=10', - time_used) - - for size in size_diff: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 
'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_fixed_post_num(self): - print() - for size in size_same: - print('FixedPostNum:', size) - conn = bp.connect.FixedPostNum(num=10, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - mat = conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num=10', - time_used) - - for size in size_diff: - print('FixedPostNum:', size) - conn = bp.connect.FixedPreNum(num=10, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_prob_dist(self): - print() - for size in size_same: - 
print('ProbDist:', size) - conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.5', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - def test_small_world(self): - print() - for size in size_same: - print('SmallWorld:', size) - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - def test_scale_free_ba(self): - print() - for size in size_same: - print('ScaleFreeBA:', size) - conn = bp.connect.ScaleFreeBA(m=2) - conn(pre_size=size, post_size=size) - - start = time.time() - 
conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=2', - time_used) - - def test_scale_free_ba_dual(self): - print() - for size in size_same: - print('ScaleFreeBADual:', size) - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm1=2/m2=3/p=0.4', - time_used) - - def test_power_law(self): - print() - for size in size_same: - print('PowerLaw:', size) - conn = bp.connect.PowerLaw(m=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 
'build_coo', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=3/p=0.4', - time_used) - - def test_one2one(self): - print() - for size in size_same: - print('One2One:', size) - conn = bp.connect.One2One() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - '', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) - - def test_all2all(self): - print() - for size in size_same: - print('All2All:', size) - conn = bp.connect.All2All() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['All2All', - # 'TwoEndConnector', - # f'{size}x{size}', - # 'build_coo', - # '', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) + def test_fixed_prob(self): + print() + for size in size_same: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + 
conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'prob=0.1', + time_used) + + for size in size_diff: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'prob=0.1', + time_used) + + def test_fixed_pre_num(self): + print() + for size in size_same: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'pre_num=10', + time_used) 
+ + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'pre_num=10', + time_used) + + for size in size_diff: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_fixed_post_num(self): + print() + for size in size_same: + print('FixedPostNum:', size) + conn = bp.connect.FixedPostNum(num=10, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + mat = conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num=10', + time_used) + + for size in size_diff: + print('FixedPostNum:', size) + conn = bp.connect.FixedPreNum(num=10, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = 
time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_prob_dist(self): + print() + for size in size_same: + print('ProbDist:', size) + conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.5', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + def test_small_world(self): + print() + for size in size_same: + print('SmallWorld:', size) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 
'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + def test_scale_free_ba(self): + print() + for size in size_same: + print('ScaleFreeBA:', size) + conn = bp.connect.ScaleFreeBA(m=2) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=2', + time_used) + + def test_scale_free_ba_dual(self): + print() + for size in size_same: + print('ScaleFreeBADual:', size) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + 
conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm1=2/m2=3/p=0.4', + time_used) + + def test_power_law(self): + print() + for size in size_same: + print('PowerLaw:', size) + conn = bp.connect.PowerLaw(m=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=3/p=0.4', + time_used) + + def test_one2one(self): + print() + for size in size_same: + print('One2One:', size) + conn = bp.connect.One2One() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + '', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) + + def test_all2all(self): + print() + for size in size_same: + print('All2All:', size) + conn = bp.connect.All2All() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + 
insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['All2All', + # 'TwoEndConnector', + # f'{size}x{size}', + # 'build_coo', + # '', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) class TestSave(unittest.TestCase): - def test_save(self): - try: - df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', - index=False) - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + def test_save(self): + try: + df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', + index=False) + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') diff --git a/brainpy/_src/connect/tests/test_custom_conn.py b/brainpy/_src/connect/tests/test_custom_conn.py index d536555f3..cfe0c8b72 100644 --- a/brainpy/_src/connect/tests/test_custom_conn.py +++ b/brainpy/_src/connect/tests/test_custom_conn.py @@ -10,53 +10,53 @@ class TestIJConn(TestCase): - def test_ij(self): - conn = bp.connect.IJConn(i=np.array([0, 1, 2]), - j=np.array([0, 0, 0]))(pre_size=5, post_size=3) + def test_ij(self): + conn = bp.connect.IJConn(i=np.array([0, 1, 2]), + j=np.array([0, 0, 0]))(pre_size=5, post_size=3) - pre2post, post2pre, conn_mat = conn.requires('pre2post', 'post2pre', 'conn_mat') + pre2post, post2pre, conn_mat = conn.requires('pre2post', 'post2pre', 'conn_mat') - assert bp.math.array_equal(pre2post[0], bp.math.array([0, 0, 0])) - assert bp.math.array_equal(post2pre[0], bp.math.array([0, 1, 2])) + assert bp.math.array_equal(pre2post[0], bp.math.array([0, 0, 0])) + assert bp.math.array_equal(post2pre[0], bp.math.array([0, 1, 2])) - a = bp.math.array([[True, False, 
False], - [True, False, False], - [True, False, False], - [False, False, False], - [False, False, False]]) - print() - print('conn_mat', conn_mat) - assert bp.math.array_equal(conn_mat, a) + a = bp.math.array([[True, False, False], + [True, False, False], + [True, False, False], + [False, False, False], + [False, False, False]]) + print() + print('conn_mat', conn_mat) + assert bp.math.array_equal(conn_mat, a) class TestMatConn(TestCase): - def test_MatConn1(self): - bp.math.random.seed(123) - actual_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_) - conn = bp.connect.MatConn(conn_mat=actual_mat)(pre_size=5, post_size=3) + def test_MatConn1(self): + bp.math.random.seed(123) + actual_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_) + conn = bp.connect.MatConn(conn_mat=actual_mat)(pre_size=5, post_size=3) - pre2post, post2pre, conn_mat = conn.requires('pre2post', 'post2pre', 'conn_mat') + pre2post, post2pre, conn_mat = conn.requires('pre2post', 'post2pre', 'conn_mat') - print() - print('conn_mat', conn_mat) + print() + print('conn_mat', conn_mat) - assert bp.math.array_equal(conn_mat, actual_mat) + assert bp.math.array_equal(conn_mat, actual_mat) - def test_MatConn2(self): - conn = bp.connect.MatConn(conn_mat=np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)) - with pytest.raises(AssertionError): - conn(pre_size=5, post_size=1) + def test_MatConn2(self): + conn = bp.connect.MatConn(conn_mat=np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)) + with pytest.raises(AssertionError): + conn(pre_size=5, post_size=1) class TestSparseMatConn(TestCase): - def test_sparseMatConn(self): - conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_) - sparse_mat = csr_matrix(conn_mat) - conn = bp.conn.SparseMatConn(sparse_mat)(pre_size=sparse_mat.shape[0], post_size=sparse_mat.shape[1]) + def test_sparseMatConn(self): + conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_) + sparse_mat = csr_matrix(conn_mat) + conn = 
bp.conn.SparseMatConn(sparse_mat)(pre_size=sparse_mat.shape[0], post_size=sparse_mat.shape[1]) - print(conn.requires('pre2post')) + print(conn.requires('pre2post')) - print(conn.requires('conn_mat')) - print(csr_matrix.todense(sparse_mat)) + print(conn.requires('conn_mat')) + print(csr_matrix.todense(sparse_mat)) - assert bp.math.array_equal(conn_mat, bp.math.asarray(csr_matrix.todense(sparse_mat), dtype=bp.math.bool_)) + assert bp.math.array_equal(conn_mat, bp.math.asarray(csr_matrix.todense(sparse_mat), dtype=bp.math.bool_)) diff --git a/brainpy/_src/connect/tests/test_optimized_result.py b/brainpy/_src/connect/tests/test_optimized_result.py index 6eb4d5f2a..1e747d019 100644 --- a/brainpy/_src/connect/tests/test_optimized_result.py +++ b/brainpy/_src/connect/tests/test_optimized_result.py @@ -1,24 +1,20 @@ # -*- coding: utf-8 -*- from datetime import datetime +from time import time import pytest -import unittest -import pytest - import brainpy as bp -from time import time - try: - import pandas as pd + import pandas as pd - df = pd.DataFrame( - columns=['connector name', 'connect matrix size', - 'build function', 'other parameter', - 'time origin(ms)', 'time optimized(ms)']) + df = pd.DataFrame( + columns=['connector name', 'connect matrix size', + 'build function', 'other parameter', + 'time origin(ms)', 'time optimized(ms)']) except (ImportError, ModuleNotFoundError): - pytest.skip('No pandas installed, skip test.', allow_module_level=True) + pytest.skip('No pandas installed, skip test.', allow_module_level=True) # size_same = [100, 500, 2500, 12500, 25000, 37500, 50000] # size_same = [100, 500, 2500, 12500] @@ -26,218 +22,218 @@ def get_ms(value): - return round(value * 1000, 4) + return round(value * 1000, 4) def insert_row(connector_name, connect_matrix_size, build_function, other_parameter, time_origin_used, time_optimized_used): - try: - df.loc[len(df)] = [connector_name, connect_matrix_size, - build_function, other_parameter, - time_origin_used, 
time_optimized_used] - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + try: + df.loc[len(df)] = [connector_name, connect_matrix_size, + build_function, other_parameter, + time_origin_used, time_optimized_used] + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') def test_GaussianProb1(): - conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123) - for size in size_same: - conn(pre_size=size) - mat = conn.build_mat(isOptimized=True) - time0 = time() - mat1 = conn.build_mat(isOptimized=True) - time_optimized = get_ms(time() - time0) - - mat2 = conn.build_mat(isOptimized=False) - time0 = time() - mat2 = conn.build_mat(isOptimized=False) - time_origin = get_ms(time() - time0) - - assert bp.math.array_equal(mat1, mat2) - print() - print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') - insert_row('GaussianProb', - f'{size}x{size}', - 'build_mat', - 'sigma=1 / include_self=False', - time_origin, - time_optimized) + conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123) + for size in size_same: + conn(pre_size=size) + mat = conn.build_mat(isOptimized=True) + time0 = time() + mat1 = conn.build_mat(isOptimized=True) + time_optimized = get_ms(time() - time0) + + mat2 = conn.build_mat(isOptimized=False) + time0 = time() + mat2 = conn.build_mat(isOptimized=False) + time_origin = get_ms(time() - time0) + + assert bp.math.array_equal(mat1, mat2) + print() + print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') + insert_row('GaussianProb', + f'{size}x{size}', + 'build_mat', + 'sigma=1 / include_self=False', + time_origin, + time_optimized) def test_GaussianProb2(): - conn = bp.connect.GaussianProb(sigma=4, seed=123) - for size in size_same: - conn(pre_size=size) - mat = conn.build_mat(isOptimized=True) - time0 = time() - mat1 = conn.build_mat(isOptimized=True) - time_optimized = get_ms(time() - time0) - - mat2 = conn.build_mat(isOptimized=False) - time0 = 
time() - mat2 = conn.build_mat(isOptimized=False) - time_origin = get_ms(time() - time0) - - assert bp.math.array_equal(mat1, mat2) - print() - print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') - insert_row('GaussianProb', - f'{size}x{size}', - 'build_mat', - 'sigma=4', - time_origin, - time_optimized) + conn = bp.connect.GaussianProb(sigma=4, seed=123) + for size in size_same: + conn(pre_size=size) + mat = conn.build_mat(isOptimized=True) + time0 = time() + mat1 = conn.build_mat(isOptimized=True) + time_optimized = get_ms(time() - time0) + + mat2 = conn.build_mat(isOptimized=False) + time0 = time() + mat2 = conn.build_mat(isOptimized=False) + time_origin = get_ms(time() - time0) + + assert bp.math.array_equal(mat1, mat2) + print() + print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') + insert_row('GaussianProb', + f'{size}x{size}', + 'build_mat', + 'sigma=4', + time_origin, + time_optimized) def test_GaussianProb3(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123) - for size in size_same: - conn(pre_size=size) - mat = conn.build_mat(isOptimized=True) - time0 = time() - mat1 = conn.build_mat(isOptimized=True) - time_optimized = get_ms(time() - time0) - - mat2 = conn.build_mat(isOptimized=False) - time0 = time() - mat2 = conn.build_mat(isOptimized=False) - time_origin = get_ms(time() - time0) - - assert bp.math.array_equal(mat1, mat2) - print() - print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') - insert_row('GaussianProb', - f'{size}x{size}', - 'build_mat', - 'sigma=4 / periodic_boundary=True', - time_origin, - time_optimized) + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123) + for size in size_same: + conn(pre_size=size) + mat = conn.build_mat(isOptimized=True) + time0 = time() + mat1 = conn.build_mat(isOptimized=True) + time_optimized = get_ms(time() - time0) + + mat2 = conn.build_mat(isOptimized=False) + time0 = time() + mat2 = 
conn.build_mat(isOptimized=False) + time_origin = get_ms(time() - time0) + + assert bp.math.array_equal(mat1, mat2) + print() + print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') + insert_row('GaussianProb', + f'{size}x{size}', + 'build_mat', + 'sigma=4 / periodic_boundary=True', + time_origin, + time_optimized) def testGaussianProb4(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123) - for size in size_same: - conn(pre_size=size) - mat = conn.build_mat(isOptimized=True) - time0 = time() - mat1 = conn.build_mat(isOptimized=True) - time_optimized = get_ms(time() - time0) - - mat2 = conn.build_mat(isOptimized=False) - time0 = time() - mat2 = conn.build_mat(isOptimized=False) - time_origin = get_ms(time() - time0) - - assert bp.math.array_equal(mat1, mat2) - print() - print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') - insert_row('GaussianProb', - f'{size}x{size}', - 'build_mat', - 'sigma=4 / periodic_boundary=True', - time_origin, - time_optimized) + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True, seed=123) + for size in size_same: + conn(pre_size=size) + mat = conn.build_mat(isOptimized=True) + time0 = time() + mat1 = conn.build_mat(isOptimized=True) + time_optimized = get_ms(time() - time0) + + mat2 = conn.build_mat(isOptimized=False) + time0 = time() + mat2 = conn.build_mat(isOptimized=False) + time_origin = get_ms(time() - time0) + + assert bp.math.array_equal(mat1, mat2) + print() + print(f'time_optimized:{time_optimized}\ntime_origin:{time_origin}') + insert_row('GaussianProb', + f'{size}x{size}', + 'build_mat', + 'sigma=4 / periodic_boundary=True', + time_origin, + time_optimized) def test_ScaleFreeBA(): - conn = bp.connect.ScaleFreeBA(m=2, seed=123) - for size in size_same: - conn(pre_size=size, post_size=size) - mat = conn.build_mat(isOptimized=True) - time0 = time() - mat1 = conn.build_mat(isOptimized=True) - time_optimized = get_ms(time() - time0) - - mat2 = 
conn.build_mat(isOptimized=False) - time0 = time() - mat2 = conn.build_mat(isOptimized=False) - time_origin = get_ms(time() - time0) - - assert bp.math.array_equal(mat1, mat2) - insert_row('ScaleFreeBA', - f'{size}x{size}', - 'build_mat', - 'm=2', - time_origin, - time_optimized) + conn = bp.connect.ScaleFreeBA(m=2, seed=123) + for size in size_same: + conn(pre_size=size, post_size=size) + mat = conn.build_mat(isOptimized=True) + time0 = time() + mat1 = conn.build_mat(isOptimized=True) + time_optimized = get_ms(time() - time0) + + mat2 = conn.build_mat(isOptimized=False) + time0 = time() + mat2 = conn.build_mat(isOptimized=False) + time_origin = get_ms(time() - time0) + + assert bp.math.array_equal(mat1, mat2) + insert_row('ScaleFreeBA', + f'{size}x{size}', + 'build_mat', + 'm=2', + time_origin, + time_optimized) def test_ScaleFreeBADual(): - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4, seed=123) - for size in size_same: - conn(pre_size=size, post_size=size) - mat = conn.build_mat(isOptimized=True) - time0 = time() - mat1 = conn.build_mat(isOptimized=True) - time_optimized = get_ms(time() - time0) - - mat2 = conn.build_mat(isOptimized=False) - time0 = time() - mat2 = conn.build_mat(isOptimized=False) - time_origin = get_ms(time() - time0) - - assert bp.math.array_equal(mat1, mat2) - insert_row('ScaleFreeBADual', - f'{size}x{size}', - 'build_mat', - 'm1=2 / m2=3 / p=0.4', - time_origin, - time_optimized) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4, seed=123) + for size in size_same: + conn(pre_size=size, post_size=size) + mat = conn.build_mat(isOptimized=True) + time0 = time() + mat1 = conn.build_mat(isOptimized=True) + time_optimized = get_ms(time() - time0) + + mat2 = conn.build_mat(isOptimized=False) + time0 = time() + mat2 = conn.build_mat(isOptimized=False) + time_origin = get_ms(time() - time0) + + assert bp.math.array_equal(mat1, mat2) + insert_row('ScaleFreeBADual', + f'{size}x{size}', + 'build_mat', + 'm1=2 / m2=3 / p=0.4', + time_origin, 
+ time_optimized) def test_PowerLaw(): - conn = bp.connect.PowerLaw(m=3, p=0.4, seed=123) - for size in size_same: - conn(pre_size=size, post_size=size) - mat = conn.build_mat(isOptimized=True) - time0 = time() - mat1 = conn.build_mat(isOptimized=True) - time_optimized = get_ms(time() - time0) - - mat2 = conn.build_mat(isOptimized=False) - time0 = time() - mat2 = conn.build_mat(isOptimized=False) - time_origin = get_ms(time() - time0) - - assert bp.math.array_equal(mat1, mat2) - insert_row('PowerLaw', - f'{size}x{size}', - 'build_mat', - 'm=3 / p=0.4', - time_origin, - time_optimized) + conn = bp.connect.PowerLaw(m=3, p=0.4, seed=123) + for size in size_same: + conn(pre_size=size, post_size=size) + mat = conn.build_mat(isOptimized=True) + time0 = time() + mat1 = conn.build_mat(isOptimized=True) + time_optimized = get_ms(time() - time0) + + mat2 = conn.build_mat(isOptimized=False) + time0 = time() + mat2 = conn.build_mat(isOptimized=False) + time_origin = get_ms(time() - time0) + + assert bp.math.array_equal(mat1, mat2) + insert_row('PowerLaw', + f'{size}x{size}', + 'build_mat', + 'm=3 / p=0.4', + time_origin, + time_optimized) def test_ProbDist(): - conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=123, include_self=True) - # for size in [1000, (100, 20), (4, 20, 20), (4, 3, 8, 5)]: - for size in [10000]: - conn(pre_size=size, post_size=size) - pre_ids1, post_ids1 = conn.build_coo(isOptimized=True) - time0 = time() - pre_ids1, post_ids1 = conn.build_coo(isOptimized=True) - time_optimized = get_ms(time() - time0) - - pre_ids2, post_ids2 = conn.build_coo(isOptimized=False) - time0 = time() - pre_ids2, post_ids2 = conn.build_coo(isOptimized=False) - time_origin = get_ms(time() - time0) - - # assert (bp.math.array_equal(pre_ids1, pre_ids2) and bp.math.array_equal(post_ids1, post_ids2)) - print() - print(f'time origin: {time_origin}\ntime optimized: {time_optimized}') - insert_row('ProbDist', - {size}, - 'build_coo', - 'dist=1 / prob=0.5 / pre_ratio=0.3 / 
include_self=True', - time_origin, - time_optimized) + conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=123, include_self=True) + # for size in [1000, (100, 20), (4, 20, 20), (4, 3, 8, 5)]: + for size in [10000]: + conn(pre_size=size, post_size=size) + pre_ids1, post_ids1 = conn.build_coo(isOptimized=True) + time0 = time() + pre_ids1, post_ids1 = conn.build_coo(isOptimized=True) + time_optimized = get_ms(time() - time0) + + pre_ids2, post_ids2 = conn.build_coo(isOptimized=False) + time0 = time() + pre_ids2, post_ids2 = conn.build_coo(isOptimized=False) + time_origin = get_ms(time() - time0) + + # assert (bp.math.array_equal(pre_ids1, pre_ids2) and bp.math.array_equal(post_ids1, post_ids2)) + print() + print(f'time origin: {time_origin}\ntime optimized: {time_optimized}') + insert_row('ProbDist', + {size}, + 'build_coo', + 'dist=1 / prob=0.5 / pre_ratio=0.3 / include_self=True', + time_origin, + time_optimized) def test_save(): - try: - df.to_csv('opt_time_compare' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', - index=False) - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + try: + df.to_csv('opt_time_compare' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', + index=False) + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') diff --git a/brainpy/_src/connect/tests/test_random_conn.py b/brainpy/_src/connect/tests/test_random_conn.py index 68531ded7..1f17977f5 100644 --- a/brainpy/_src/connect/tests/test_random_conn.py +++ b/brainpy/_src/connect/tests/test_random_conn.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- -import pytest - import unittest +import pytest + import brainpy as bp @@ -87,11 +87,13 @@ def test_random_fix_post3(): conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) conn1.require(bp.connect.CONN_MAT) + def test_random_fix_total1(): conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=False, seed=1234)(pre_size=3, post_size=4) 
coo1 = conn1.require(bp.connect.COO) conn_mat = bp.connect.coo2mat_num(ij=coo1, num_pre=3, num_post=4, num=conn1.num, seed=1234) + def test_random_fix_total2(): conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=True, seed=1234)(pre_size=3, post_size=4) mat1 = conn1.require(bp.connect.CONN_MAT) diff --git a/brainpy/_src/connect/tests/test_random_conn_visualize.py b/brainpy/_src/connect/tests/test_random_conn_visualize.py index ba0d95f13..29b16e4ba 100644 --- a/brainpy/_src/connect/tests/test_random_conn_visualize.py +++ b/brainpy/_src/connect/tests/test_random_conn_visualize.py @@ -8,172 +8,172 @@ def test_random_fix_pre1(): - for num in [0.4, 20]: - conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat1 = conn1.require(bp.connect.CONN_MAT) + for num in [0.4, 20]: + conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat1 = conn1.require(bp.connect.CONN_MAT) - conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat2 = conn2.require(bp.connect.CONN_MAT) + conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat2 = conn2.require(bp.connect.CONN_MAT) - print() - print(f'num = {num}') - print('conn_mat 1\n', mat1) - print(mat1.sum()) - print('conn_mat 2\n', mat2) - print(mat2.sum()) + print() + print(f'num = {num}') + print('conn_mat 1\n', mat1) + print(mat1.sum()) + print('conn_mat 2\n', mat2) + print(mat2.sum()) - assert bp.math.array_equal(mat1, mat2) - bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) + assert bp.math.array_equal(mat1, mat2) + bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) def test_random_fix_pre2(): - for num in [0.5, 3]: - conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4) - mat1 = conn1.require(bp.connect.CONN_MAT) - print() - print(mat1) + for num in [0.5, 3]: + conn1 = 
bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4) + mat1 = conn1.require(bp.connect.CONN_MAT) + print() + print(mat1) - bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num) + bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num) def test_random_fix_pre3(): - with pytest.raises(bp.errors.ConnectorError): - conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4) - conn1.require(bp.connect.CONN_MAT) + with pytest.raises(bp.errors.ConnectorError): + conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4) + conn1.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4') + bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4') def test_random_fix_post1(): - for num in [0.4, 20]: - conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat1 = conn1.require(bp.connect.CONN_MAT) + for num in [0.4, 20]: + conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat1 = conn1.require(bp.connect.CONN_MAT) - conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat2 = conn2.require(bp.connect.CONN_MAT) + conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat2 = conn2.require(bp.connect.CONN_MAT) - print() - print('conn_mat 1\n', mat1) - print('conn_mat 2\n', mat2) + print() + print('conn_mat 1\n', mat1) + print('conn_mat 2\n', mat2) - assert bp.math.array_equal(mat1, mat2) - bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) + assert bp.math.array_equal(mat1, mat2) + bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) def test_random_fix_post2(): - for num in [0.5, 3]: - conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4) - mat1 = 
conn1.require(bp.connect.CONN_MAT) - print(mat1) - bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num) + for num in [0.5, 3]: + conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4) + mat1 = conn1.require(bp.connect.CONN_MAT) + print(mat1) + bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num) def test_random_fix_post3(): - with pytest.raises(bp.errors.ConnectorError): - conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) - conn1.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4') + with pytest.raises(bp.errors.ConnectorError): + conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) + conn1.require(bp.connect.CONN_MAT) + bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4') def test_gaussian_prob1(): - conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100') def test_gaussian_prob2(): - conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50)) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50)) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)') def test_gaussian_prob3(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50)) - mat = 
conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50)) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)') def test_gaussian_prob4(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10)) - conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - mat = conn.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)') + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10)) + conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + mat = conn.require(bp.connect.CONN_MAT) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)') def test_SmallWorld1(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=10, post_size=10) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=10, post_size=10) - mat = conn.require(bp.connect.CONN_MAT) + mat = conn.require(bp.connect.CONN_MAT) - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10') + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10') def test_SmallWorld3(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True) - conn(pre_size=20, post_size=20) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True) + 
conn(pre_size=20, post_size=20) - mat = conn.require(bp.connect.CONN_MAT) + mat = conn.require(bp.connect.CONN_MAT) - print('conn_mat', mat) + print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20') + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20') def test_SmallWorld2(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5) - conn(pre_size=(100,), post_size=(100,)) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)') - - -def test_ScaleFreeBA(): - conn = bp.connect.ScaleFreeBA(m=2) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5) + conn(pre_size=(100,), post_size=(100,)) mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, bp.connect.PRE_IDS, bp.connect.POST_IDS, bp.connect.PRE2POST, bp.connect.POST_IDS) print() print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size)) + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)') + + +def test_ScaleFreeBA(): + conn = bp.connect.ScaleFreeBA(m=2) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size)) def test_ScaleFreeBADual(): - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) 
- mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) def test_PowerLaw(): - conn = bp.connect.PowerLaw(m=3, p=0.4) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) + conn = bp.connect.PowerLaw(m=3, p=0.4) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) diff --git a/brainpy/_src/connect/tests/test_regular_conn.py b/brainpy/_src/connect/tests/test_regular_conn.py index 4fe2ab85d..4237d7c06 100644 --- a/brainpy/_src/connect/tests/test_regular_conn.py +++ b/brainpy/_src/connect/tests/test_regular_conn.py @@ -1,104 +1,104 @@ # -*- coding: utf-8 -*- +import unittest + import numpy as np import brainpy as bp from brainpy import connect -import unittest - class TestOne2One(unittest.TestCase): - def test_one2one(self): - for 
size in [100, (3, 4), (4, 5, 6)]: - conn = connect.One2One()(pre_size=size, post_size=size) + def test_one2one(self): + for size in [100, (3, 4), (4, 5, 6)]: + conn = connect.One2One()(pre_size=size, post_size=size) - conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \ - conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn') + conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \ + conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn') - num = bp.tools.size2num(size) + num = bp.tools.size2num(size) - actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_) - bp.math.fill_diagonal(actual_mat, True) + actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_) + bp.math.fill_diagonal(actual_mat, True) - assert bp.math.array_equal(actual_mat, conn_mat) - assert bp.math.array_equal(pre_ids, bp.math.arange(num)) - assert bp.math.array_equal(post_ids, bp.math.arange(num)) + assert bp.math.array_equal(actual_mat, conn_mat) + assert bp.math.array_equal(pre_ids, bp.math.arange(num)) + assert bp.math.array_equal(post_ids, bp.math.arange(num)) - print('conn_mat', conn_mat) - print('pre_ids', pre_ids) - print('post_ids', post_ids) - print('pre2post', pre2post) - print('post2pre', post2pre) - print('pre2syn', pre2syn) - print('post2syn', post2syn) + print('conn_mat', conn_mat) + print('pre_ids', pre_ids) + print('post_ids', post_ids) + print('pre2post', pre2post) + print('post2pre', post2pre) + print('pre2syn', pre2syn) + print('post2syn', post2syn) class TestAll2All(unittest.TestCase): - def test_all2all(self): - for has_self in [True, False]: - for size in [100, (3, 4), (4, 5, 6)]: - conn = connect.All2All(include_self=has_self)(pre_size=size, post_size=size) - mat = conn.require(connect.CONN_MAT) - conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \ - conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 
'post2syn') - num = bp.tools.size2num(size) - - print(mat) - actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_) - if not has_self: - bp.math.fill_diagonal(actual_mat, False) - assert bp.math.array_equal(actual_mat, mat) - - print() - print('conn_mat', conn_mat) - print('pre_ids', pre_ids) - print('post_ids', post_ids) - print('pre2post', pre2post) - print('post2pre', post2pre) - print('pre2syn', pre2syn) - print('post2syn', post2syn) + def test_all2all(self): + for has_self in [True, False]: + for size in [100, (3, 4), (4, 5, 6)]: + conn = connect.All2All(include_self=has_self)(pre_size=size, post_size=size) + mat = conn.require(connect.CONN_MAT) + conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \ + conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn') + num = bp.tools.size2num(size) + + print(mat) + actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_) + if not has_self: + bp.math.fill_diagonal(actual_mat, False) + assert bp.math.array_equal(actual_mat, mat) + + print() + print('conn_mat', conn_mat) + print('pre_ids', pre_ids) + print('post_ids', post_ids) + print('pre2post', pre2post) + print('post2pre', post2pre) + print('pre2syn', pre2syn) + print('post2syn', post2syn) class TestGridConn(unittest.TestCase): - def test_grid_four(self): - for periodic_boundary in [True, False]: - for include_self in [True, False]: - for size in (10, [10, 10], (4, 4, 5)): - conn = bp.conn.GridFour(include_self=include_self, - periodic_boundary=periodic_boundary)(size, size) - mat = conn.build_mat() - pre_ids, post_ids = conn.build_coo() - new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) - new_mat[pre_ids, post_ids] = True - - print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') - self.assertTrue(bp.math.allclose(mat, new_mat)) - - def test_grid_eight(self): - for periodic_boundary in [True, False]: - for include_self in [True, False]: - for size 
in (10, [10, 10], (4, 4, 5)): - conn = bp.conn.GridEight(include_self=include_self, - periodic_boundary=periodic_boundary)(size, size) - mat = conn.build_mat() - pre_ids, post_ids = conn.build_coo() - new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) - new_mat[pre_ids, post_ids] = True - - print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') - self.assertTrue(bp.math.allclose(mat, new_mat)) - - def test_grid_N(self): - for periodic_boundary in [True, False]: - for include_self in [True, False]: - for size in (10, [10, 10], (4, 4, 5)): - conn = bp.conn.GridN(include_self=include_self, - periodic_boundary=periodic_boundary, - N=2)(size, size) - mat = conn.build_mat() - pre_ids, post_ids = conn.build_coo() - new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) - new_mat[pre_ids, post_ids] = True - - print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') - self.assertTrue(bp.math.allclose(mat, new_mat)) + def test_grid_four(self): + for periodic_boundary in [True, False]: + for include_self in [True, False]: + for size in (10, [10, 10], (4, 4, 5)): + conn = bp.conn.GridFour(include_self=include_self, + periodic_boundary=periodic_boundary)(size, size) + mat = conn.build_mat() + pre_ids, post_ids = conn.build_coo() + new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) + new_mat[pre_ids, post_ids] = True + + print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') + self.assertTrue(bp.math.allclose(mat, new_mat)) + + def test_grid_eight(self): + for periodic_boundary in [True, False]: + for include_self in [True, False]: + for size in (10, [10, 10], (4, 4, 5)): + conn = bp.conn.GridEight(include_self=include_self, + periodic_boundary=periodic_boundary)(size, size) + mat = conn.build_mat() + pre_ids, post_ids = conn.build_coo() + new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) + 
new_mat[pre_ids, post_ids] = True + + print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') + self.assertTrue(bp.math.allclose(mat, new_mat)) + + def test_grid_N(self): + for periodic_boundary in [True, False]: + for include_self in [True, False]: + for size in (10, [10, 10], (4, 4, 5)): + conn = bp.conn.GridN(include_self=include_self, + periodic_boundary=periodic_boundary, + N=2)(size, size) + mat = conn.build_mat() + pre_ids, post_ids = conn.build_coo() + new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool) + new_mat[pre_ids, post_ids] = True + + print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}') + self.assertTrue(bp.math.allclose(mat, new_mat)) diff --git a/brainpy/_src/context.py b/brainpy/_src/context.py index 2f9336271..d5f7d0aed 100644 --- a/brainpy/_src/context.py +++ b/brainpy/_src/context.py @@ -11,61 +11,61 @@ from brainpy._src.tools.dicts import DotDict __all__ = [ - 'share', + 'share', ] class _ShareContext: - def __init__(self): - super().__init__() - - # Shared data across all nodes at current time step. - # ------------- - - self._arguments = DotDict() - self._category = dict() - - @property - def dt(self): - return brainstate.environ.get_dt() - - @dt.setter - def dt(self, dt): - self.set_dt(dt) - - def set_dt(self, dt: Union[int, float]): - brainstate.environ.set(dt=dt) - - def load(self, key, value: Any = None, desc: str = None): - """Load the shared data by the ``key``. - - Args: - key (str): the key to indicate the data. - value (Any): the default value when ``key`` is not defined in the shared. - desc: (str): the description of the key. 
- """ - return brainstate.environ.get(key, value, desc) - - def save(self, *args, **kwargs) -> None: - """Save shared arguments in the global context.""" - assert len(args) % 2 == 0 - for i in range(0, len(args), 2): - identifier = args[i] - data = args[i + 1] - brainstate.environ.set(**{identifier: data}) - brainstate.environ.set(**kwargs) - - def __setitem__(self, key, value): - """Enable setting the shared item by ``bp.share[key] = value``.""" - self.save(key, value) - - def __getitem__(self, item): - """Enable loading the shared parameter by ``bp.share[key]``.""" - return self.load(item) - - def get_shargs(self) -> DotDict: - """Get all shared arguments in the global context.""" - return DotDict(brainstate.environ.all()) + def __init__(self): + super().__init__() + + # Shared data across all nodes at current time step. + # ------------- + + self._arguments = DotDict() + self._category = dict() + + @property + def dt(self): + return brainstate.environ.get_dt() + + @dt.setter + def dt(self, dt): + self.set_dt(dt) + + def set_dt(self, dt: Union[int, float]): + brainstate.environ.set(dt=dt) + + def load(self, key, value: Any = None, desc: str = None): + """Load the shared data by the ``key``. + + Args: + key (str): the key to indicate the data. + value (Any): the default value when ``key`` is not defined in the shared. + desc: (str): the description of the key. 
+ """ + return brainstate.environ.get(key, value, desc) + + def save(self, *args, **kwargs) -> None: + """Save shared arguments in the global context.""" + assert len(args) % 2 == 0 + for i in range(0, len(args), 2): + identifier = args[i] + data = args[i + 1] + brainstate.environ.set(**{identifier: data}) + brainstate.environ.set(**kwargs) + + def __setitem__(self, key, value): + """Enable setting the shared item by ``bp.share[key] = value``.""" + self.save(key, value) + + def __getitem__(self, item): + """Enable loading the shared parameter by ``bp.share[key]``.""" + return self.load(item) + + def get_shargs(self) -> DotDict: + """Get all shared arguments in the global context.""" + return DotDict(brainstate.environ.all()) share = _ShareContext() diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index b21a231ab..7337a0cc3 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -19,549 +19,547 @@ from brainpy._src.mixin import ParamDesc, ReturnInfo, JointType, SupportAutoDelay from brainpy.check import jit_error - __all__ = [ - 'Delay', - 'VarDelay', - 'DataDelay', - 'DelayAccess', + 'Delay', + 'VarDelay', + 'DataDelay', + 'DelayAccess', ] - delay_identifier = '_*_delay_of_' def _get_delay(delay_time, delay_step): - if delay_time is None: - if delay_step is None: - return None, None + if delay_time is None: + if delay_step is None: + return None, None + else: + assert isinstance(delay_step, int), '"delay_step" should be an integer.' + delay_time = delay_step * bm.get_dt() else: - assert isinstance(delay_step, int), '"delay_step" should be an integer.' - delay_time = delay_step * bm.get_dt() - else: - assert delay_step is None, '"delay_step" should be None if "delay_time" is given.' - assert isinstance(delay_time, (int, float)) - delay_step = math.ceil(delay_time / bm.get_dt()) - return delay_time, delay_step + assert delay_step is None, '"delay_step" should be None if "delay_time" is given.' 
+ assert isinstance(delay_time, (int, float)) + delay_step = math.ceil(delay_time / bm.get_dt()) + return delay_time, delay_step class Delay(DynamicalSystem, ParamDesc): - """Base class for delay variables. - - Args: - time: The delay time. - init: The initial delay data. - method: The delay method. Can be ``rotation`` and ``concat``. - name: The delay name. - mode: The computing mode. - """ - - max_time: float - max_length: int - data: Optional[bm.Variable] - - def __init__( - self, - # delay time - time: Optional[Union[int, float]] = None, - - # delay init - init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, - - # delay method - method: Optional[str] = None, - - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # delay method - if method is None: - if self.mode.is_one_of(bm.NonBatchingMode, bm.BatchingMode): - method = ROTATE_UPDATE - elif self.mode.is_a(bm.TrainingMode): - method = CONCAT_UPDATE - else: - method = ROTATE_UPDATE - assert method in [ROTATE_UPDATE, CONCAT_UPDATE] - self.method = method - - # delay length - if time is None: - length = 0 - time = 0. - elif isinstance(time, (int, float)): - length = int(time / bm.get_dt()) - else: - raise TypeError('time must be a int or float or None.') - assert isinstance(length, int) - self.max_length = length - self.max_time = time - - # delay data - if init is not None: - assert isinstance(init, (numbers.Number, bm.BaseArray, jax.Array, Callable)) - self._init = init - - # other info - self._registered_entries = dict() - - def register_entry( - self, - entry: str, - delay_time: Optional[Union[float, bm.Array, Callable]] = None, - delay_step: Optional[int] = None - ) -> 'Delay': - """Register an entry to access the data. + """Base class for delay variables. Args: - entry: str. The entry to access the delay data. - delay_time: The delay time of the entry (can be a float). 
- delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``. - - Returns: - Return the self. + time: The delay time. + init: The initial delay data. + method: The delay method. Can be ``rotation`` and ``concat``. + name: The delay name. + mode: The computing mode. """ - raise NotImplementedError - def at(self, entry: str, *indices) -> bm.Array: - """Get the data at the given entry. + max_time: float + max_length: int + data: Optional[bm.Variable] + + def __init__( + self, + # delay time + time: Optional[Union[int, float]] = None, + + # delay init + init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, + + # delay method + method: Optional[str] = None, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # delay method + if method is None: + if self.mode.is_one_of(bm.NonBatchingMode, bm.BatchingMode): + method = ROTATE_UPDATE + elif self.mode.is_a(bm.TrainingMode): + method = CONCAT_UPDATE + else: + method = ROTATE_UPDATE + assert method in [ROTATE_UPDATE, CONCAT_UPDATE] + self.method = method + + # delay length + if time is None: + length = 0 + time = 0. + elif isinstance(time, (int, float)): + length = int(time / bm.get_dt()) + else: + raise TypeError('time must be a int or float or None.') + assert isinstance(length, int) + self.max_length = length + self.max_time = time - Args: - entry: str. The entry to access the data. - *indices: The slicing indices. + # delay data + if init is not None: + assert isinstance(init, (numbers.Number, bm.BaseArray, jax.Array, Callable)) + self._init = init - Returns: - The data. - """ - raise NotImplementedError + # other info + self._registered_entries = dict() - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. 
+ def register_entry( + self, + entry: str, + delay_time: Optional[Union[float, bm.Array, Callable]] = None, + delay_step: Optional[int] = None + ) -> 'Delay': + """Register an entry to access the data. - Parameters:: + Args: + entry: str. The entry to access the delay data. + delay_time: The delay time of the entry (can be a float). + delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``. - delay_step: int, ArrayType - The delay length used to retrieve the data. - """ - raise NotImplementedError() + Returns: + Return the self. + """ + raise NotImplementedError + def at(self, entry: str, *indices) -> bm.Array: + """Get the data at the given entry. -def _check_target_sharding(sharding, ndim, mode: bm.Mode): - if sharding is not None: - if len(sharding) == ndim: - sharding = list(sharding) - elif len(sharding) + 1 == ndim and mode.is_child_of(bm.BatchingMode): - sharding = list(sharding) - sharding.insert(0, bm.sharding.BATCH_AXIS) - else: - raise ValueError('sharding axis names do not match the target dimension. ') - return sharding + Args: + entry: str. The entry to access the data. + *indices: The slicing indices. + Returns: + The data. + """ + raise NotImplementedError -class VarDelay(Delay): - """Generate Delays for the given :py:class:`~.Variable` instance. - - The data in this delay variable is arranged as:: - - delay = 0 [ data - delay = 1 data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - - Args: - target: Variable. The delay target. - time: int, float. The delay time. - init: Any. The delay data. It can be a Python number, like float, int, boolean values. - It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: - - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - entries: optional, dict. The delay access entries. - name: str. 
The delay name. - method: str. The method used for updating delay. Default None. - mode: Mode. The computing mode. Default None. + def retrieve(self, delay_step, *indices): + """Retrieve the delay data according to the delay length. - """ + Parameters:: - not_desc_params = ('time', 'entries', 'name') + delay_step: int, ArrayType + The delay length used to retrieve the data. + """ + raise NotImplementedError() - def __init__( - self, - # delay target - target: bm.Variable, +def _check_target_sharding(sharding, ndim, mode: bm.Mode): + if sharding is not None: + if len(sharding) == ndim: + sharding = list(sharding) + elif len(sharding) + 1 == ndim and mode.is_child_of(bm.BatchingMode): + sharding = list(sharding) + sharding.insert(0, bm.sharding.BATCH_AXIS) + else: + raise ValueError('sharding axis names do not match the target dimension. ') + return sharding - # delay time - time: Optional[Union[int, float]] = None, - # delay init - init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, +class VarDelay(Delay): + """Generate Delays for the given :py:class:`~.Variable` instance. - # delay access entry - entries: Optional[Dict] = None, + The data in this delay variable is arranged as:: - # delay method - method: Optional[str] = None, + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(time=time, init=init, method=method, name=name, mode=mode) + Args: + target: Variable. The delay target. + time: int, float. The delay time. + init: Any. The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... 
+ delay = length-1 data + delay = length data ] + entries: optional, dict. The delay access entries. + name: str. The delay name. + method: str. The method used for updating delay. Default None. + mode: Mode. The computing mode. Default None. - # check - if not isinstance(target, bm.Variable): - raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}') + """ - if self.mode.is_child_of(bm.BatchingMode): - assert target.batch_axis is not None + not_desc_params = ('time', 'entries', 'name') - # sharding - sharding = None - if target.axis_names is not None: - sharding = list(target.axis_names) - sharding.insert(0, bm.sharding.TIME_AXIS) - sharding = tuple(sharding) - self.sharding = bm.sharding.get_sharding(sharding) + def __init__( + self, - # target - self.target = target + # delay target + target: bm.Variable, - # delay data - self._init = init - if self.max_length > 0: - self._init_data(self.max_length) - else: - self.data = None + # delay time + time: Optional[Union[int, float]] = None, - # other info - if entries is not None: - for entry, delay_time in entries.items(): - self.register_entry(entry, delay_time) + # delay init + init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, - def register_entry( - self, - entry: str, - delay_time: Optional[Union[int, float]] = None, - delay_step: Optional[int] = None, - ) -> 'Delay': - """Register an entry to access the data. + # delay access entry + entries: Optional[Dict] = None, - Args: - entry: str. The entry to access the delay data. - delay_time: The delay time of the entry (can be a float). - delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``. + # delay method + method: Optional[str] = None, - Returns: - Return the self. - """ - if entry in self._registered_entries: - raise KeyError(f'Entry {entry} has been registered. ' - f'The existing delay for the key {entry} is {self._registered_entries[entry]}. 
' - f'The new delay for the key {entry} is {delay_time}. ' - f'You can use another key. ') - - if isinstance(delay_time, (np.ndarray, jax.Array)): - assert delay_time.size == 1 and delay_time.ndim == 0 - delay_time = delay_time.item() - - _, delay_step = _get_delay(delay_time, delay_step) - - # delay variable - if delay_step is not None: - if self.max_length < delay_step: - self._init_data(delay_step) - self.max_length = delay_step - self.max_time = delay_time - self._registered_entries[entry] = delay_step - return self - - def at(self, entry: str, *indices) -> bm.Array: - """Get the data at the given entry. + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(time=time, init=init, method=method, name=name, mode=mode) - Args: - entry: str. The entry to access the data. - *indices: The slicing indices. Not include the slice at the batch dimension. + # check + if not isinstance(target, bm.Variable): + raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}') - Returns: - The data. 
- """ - assert isinstance(entry, str), 'entry should be a string for describing the ' - if entry not in self._registered_entries: - raise KeyError(f'Does not find delay entry "{entry}".') - delay_step = self._registered_entries[entry] - if isinstance(self.mode, bm.BatchingMode) and len(indices) > self.target.batch_axis: - indices = list(indices) - indices.insert(self.target.batch_axis, slice(None, None, None)) - indices = tuple(indices) - - if delay_step is None or delay_step == 0.: - if len(indices): - return self.target[indices] - else: - return self.target.value - else: - return self.retrieve(delay_step, *indices) + if self.mode.is_child_of(bm.BatchingMode): + assert target.batch_axis is not None - @property - def delay_target_shape(self): - """The data shape of the delay target.""" - return self.target.shape + # sharding + sharding = None + if target.axis_names is not None: + sharding = list(target.axis_names) + sharding.insert(0, bm.sharding.TIME_AXIS) + sharding = tuple(sharding) + self.sharding = bm.sharding.get_sharding(sharding) - def __repr__(self): - name = self.__class__.__name__ - return f'{name}(step={self.max_length}, shape={self.delay_target_shape}, method={self.method})' + # target + self.target = target - def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.max_length}. ' - f'But we got {delay_len}') + # delay data + self._init = init + if self.max_length > 0: + self._init_data(self.max_length) + else: + self.data = None + + # other info + if entries is not None: + for entry, delay_time in entries.items(): + self.register_entry(entry, delay_time) + + def register_entry( + self, + entry: str, + delay_time: Optional[Union[int, float]] = None, + delay_step: Optional[int] = None, + ) -> 'Delay': + """Register an entry to access the data. + + Args: + entry: str. The entry to access the delay data. + delay_time: The delay time of the entry (can be a float). 
+ delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``. + + Returns: + Return the self. + """ + if entry in self._registered_entries: + raise KeyError(f'Entry {entry} has been registered. ' + f'The existing delay for the key {entry} is {self._registered_entries[entry]}. ' + f'The new delay for the key {entry} is {delay_time}. ' + f'You can use another key. ') + + if isinstance(delay_time, (np.ndarray, jax.Array)): + assert delay_time.size == 1 and delay_time.ndim == 0 + delay_time = delay_time.item() + + _, delay_step = _get_delay(delay_time, delay_step) + + # delay variable + if delay_step is not None: + if self.max_length < delay_step: + self._init_data(delay_step) + self.max_length = delay_step + self.max_time = delay_time + self._registered_entries[entry] = delay_step + return self + + def at(self, entry: str, *indices) -> bm.Array: + """Get the data at the given entry. + + Args: + entry: str. The entry to access the data. + *indices: The slicing indices. Not include the slice at the batch dimension. + + Returns: + The data. + """ + assert isinstance(entry, str), 'entry should be a string for describing the ' + if entry not in self._registered_entries: + raise KeyError(f'Does not find delay entry "{entry}".') + delay_step = self._registered_entries[entry] + if isinstance(self.mode, bm.BatchingMode) and len(indices) > self.target.batch_axis: + indices = list(indices) + indices.insert(self.target.batch_axis, slice(None, None, None)) + indices = tuple(indices) + + if delay_step is None or delay_step == 0.: + if len(indices): + return self.target[indices] + else: + return self.target.value + else: + return self.retrieve(delay_step, *indices) - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. 
+ @property + def delay_target_shape(self): + """The data shape of the delay target.""" + return self.target.shape - Parameters:: + def __repr__(self): + name = self.__class__.__name__ + return f'{name}(step={self.max_length}, shape={self.delay_target_shape}, method={self.method})' - delay_step: int, Array - The delay length used to retrieve the data. - """ - assert self.data is not None - assert delay_step is not None - if check.is_checking(): - jit_error(delay_step > self.max_length, self._check_delay, delay_step) + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.max_length}. ' + f'But we got {delay_len}') - if self.method == ROTATE_UPDATE: - i = share.load('i') - delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length, dtype=jnp.int32) - delay_idx = jax.lax.stop_gradient(delay_idx) + def retrieve(self, delay_step, *indices): + """Retrieve the delay data according to the delay length. - elif self.method == CONCAT_UPDATE: - delay_idx = delay_step - 1 + Parameters:: - else: - raise ValueError(f'Unknown updating method "{self.method}"') + delay_step: int, Array + The delay length used to retrieve the data. 
+ """ + assert self.data is not None + assert delay_step is not None + if check.is_checking(): + jit_error(delay_step > self.max_length, self._check_delay, delay_step) - # the delay index - if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') - indices = (delay_idx,) + indices + if self.method == ROTATE_UPDATE: + i = share.load('i') + delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length, dtype=jnp.int32) + delay_idx = jax.lax.stop_gradient(delay_idx) - # the delay data - return self.data[indices] + elif self.method == CONCAT_UPDATE: + delay_idx = delay_step - 1 - def update( - self, - latest_value: Optional[Union[bm.Array, jax.Array]] = None - ) -> None: - """Update delay variable with the new data. - """ - if self.data is not None: - # jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value)) - - # get the latest target value - if latest_value is None: - latest_value = self.target.value - - # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: - i = share.load('i') - idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32) - self.data[jax.lax.stop_gradient(idx)] = latest_value - - # update the delay data at the first position - elif self.method == CONCAT_UPDATE: - if self.max_length > 1: - latest_value = bm.expand_dims(latest_value, 0) - self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0) else: - self.data[0] = latest_value - - else: - raise ValueError(f'Unknown updating method "{self.method}"') - - def reset_state(self, batch_size: int = None, **kwargs): - """Reset the delay data. 
- """ - # initialize delay data - if self.data is not None: - self._init_data(self.max_length, batch_size) - - def _init_data(self, length: int, batch_size: int = None): - if batch_size is not None: - if self.target.batch_size != batch_size: - raise ValueError(f'The batch sizes of delay variable and target variable differ ' - f'({self.target.batch_size} != {batch_size}). ' - 'Please reset the target variable first, because delay data ' - 'depends on the target variable. ') - - if self.target.batch_axis is None: - batch_axis = None - else: - batch_axis = self.target.batch_axis + 1 + raise ValueError(f'Unknown updating method "{self.method}"') + + # the delay index + if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + indices + + # the delay data + return self.data[indices] + + def update( + self, + latest_value: Optional[Union[bm.Array, jax.Array]] = None + ) -> None: + """Update delay variable with the new data. + """ + if self.data is not None: + # jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value)) + + # get the latest target value + if latest_value is None: + latest_value = self.target.value + + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + i = share.load('i') + idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32) + self.data[jax.lax.stop_gradient(idx)] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.max_length > 1: + latest_value = bm.expand_dims(latest_value, 0) + self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0) + else: + self.data[0] = latest_value + + else: + raise ValueError(f'Unknown updating method "{self.method}"') + + def reset_state(self, batch_size: int = None, **kwargs): + """Reset the delay data. 
+ """ + # initialize delay data + if self.data is not None: + self._init_data(self.max_length, batch_size) + + def _init_data(self, length: int, batch_size: int = None): + if batch_size is not None: + if self.target.batch_size != batch_size: + raise ValueError(f'The batch sizes of delay variable and target variable differ ' + f'({self.target.batch_size} != {batch_size}). ' + 'Please reset the target variable first, because delay data ' + 'depends on the target variable. ') + + if self.target.batch_axis is None: + batch_axis = None + else: + batch_axis = self.target.batch_axis + 1 - if self.sharding is None: - f = jnp.zeros - else: - f = jax.jit(jnp.zeros, static_argnums=0, static_argnames='dtype', out_shardings=self.sharding) + if self.sharding is None: + f = jnp.zeros + else: + f = jax.jit(jnp.zeros, static_argnums=0, static_argnames='dtype', out_shardings=self.sharding) - data = f((length,) + self.target.shape, dtype=self.target.dtype) - if self.data is None: - self.data = bm.Variable(data, batch_axis=batch_axis) - else: - self.data._value = data - # update delay data - if isinstance(self._init, (bm.BaseArray, jax.Array, numbers.Number)): - self.data[:] = self._init - elif callable(self._init): - self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype) - else: - assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}' + data = f((length,) + self.target.shape, dtype=self.target.dtype) + if self.data is None: + self.data = bm.Variable(data, batch_axis=batch_axis) + else: + self.data._value = data + # update delay data + if isinstance(self._init, (bm.BaseArray, jax.Array, numbers.Number)): + self.data[:] = self._init + elif callable(self._init): + self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype) + else: + assert self._init is None, f'init should be Array, Callable, or None. 
but got {self._init}' class DataDelay(VarDelay): - not_desc_params = ('time', 'entries', 'name') - - def __init__( - self, - - # delay target - data: bm.Variable, - data_init: Union[Callable, bm.Array, jax.Array], - - # delay time - time: Optional[Union[int, float]] = None, - - # delay init - init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, - - # delay access entry - entries: Optional[Dict] = None, - - # delay method - method: Optional[str] = None, - - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - self.target_init = data_init - super().__init__(target=data, - time=time, - init=init, - entries=entries, - method=method, - name=name, - mode=mode) - - def reset_state(self, batch_size: int = None, **kwargs): - """Reset the delay data. - """ - self.target.value = variable_(self.target_init, self.target.size_without_batch, batch_size) - if self.data is not None: - self._init_data(self.max_length, batch_size) - - def update( - self, - latest_value: Union[bm.Array, jax.Array] - ) -> None: - """Update delay variable with the new data. - """ - self.target.value = latest_value - super().update(latest_value) + not_desc_params = ('time', 'entries', 'name') + + def __init__( + self, + + # delay target + data: bm.Variable, + data_init: Union[Callable, bm.Array, jax.Array], + + # delay time + time: Optional[Union[int, float]] = None, + + # delay init + init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, + + # delay access entry + entries: Optional[Dict] = None, + + # delay method + method: Optional[str] = None, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + self.target_init = data_init + super().__init__(target=data, + time=time, + init=init, + entries=entries, + method=method, + name=name, + mode=mode) + + def reset_state(self, batch_size: int = None, **kwargs): + """Reset the delay data. 
+ """ + self.target.value = variable_(self.target_init, self.target.size_without_batch, batch_size) + if self.data is not None: + self._init_data(self.max_length, batch_size) + + def update( + self, + latest_value: Union[bm.Array, jax.Array] + ) -> None: + """Update delay variable with the new data. + """ + self.target.value = latest_value + super().update(latest_value) class DelayAccess(DynamicalSystem): - def __init__( - self, - delay: Delay, - time: Union[None, int, float], - *indices, - delay_entry: str = None - ): - super().__init__(mode=delay.mode) - self.refs = {'delay': delay} - assert isinstance(delay, Delay) - self._delay_entry = delay_entry or self.name - delay.register_entry(self._delay_entry, time) - self.indices = indices - - def update(self): - return self.refs['delay'].at(self._delay_entry, *self.indices) - - def reset_state(self, *args, **kwargs): - pass + def __init__( + self, + delay: Delay, + time: Union[None, int, float], + *indices, + delay_entry: str = None + ): + super().__init__(mode=delay.mode) + self.refs = {'delay': delay} + assert isinstance(delay, Delay) + self._delay_entry = delay_entry or self.name + delay.register_entry(self._delay_entry, time) + self.indices = indices + + def update(self): + return self.refs['delay'].at(self._delay_entry, *self.indices) + + def reset_state(self, *args, **kwargs): + pass def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_data=None) -> Delay: - """Initialize a delay class by the return info (usually is created by ``.return_info()`` function). - - Args: - info: the return information. - initial_delay_data: The initial delay data. - - Returns: - The decay instance. 
- """ - if isinstance(info, bm.Variable): - return VarDelay(info, init=initial_delay_data) - - elif isinstance(info, ReturnInfo): - # batch size - if isinstance(info.batch_or_mode, int): - shape = (info.batch_or_mode,) + tuple(info.size) - batch_axis = 0 - elif isinstance(info.batch_or_mode, bm.NonBatchingMode): - shape = tuple(info.size) - batch_axis = None - elif isinstance(info.batch_or_mode, bm.BatchingMode): - shape = (info.batch_or_mode.batch_size,) + tuple(info.size) - batch_axis = 0 - else: - shape = tuple(info.size) - batch_axis = None - - # init - if isinstance(info.data, Callable): - init = info.data(shape) - elif isinstance(info.data, (bm.BaseArray, jax.Array)): - init = info.data - else: - raise TypeError - assert init.shape == shape + """Initialize a delay class by the return info (usually is created by ``.return_info()`` function). - # axis names - if info.axis_names is not None: - assert init.ndim == len(info.axis_names) + Args: + info: the return information. + initial_delay_data: The initial delay data. - # variable - target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names) - return DataDelay(target, data_init=info.data, init=initial_delay_data) - else: - raise TypeError + Returns: + The decay instance. 
+ """ + if isinstance(info, bm.Variable): + return VarDelay(info, init=initial_delay_data) + + elif isinstance(info, ReturnInfo): + # batch size + if isinstance(info.batch_or_mode, int): + shape = (info.batch_or_mode,) + tuple(info.size) + batch_axis = 0 + elif isinstance(info.batch_or_mode, bm.NonBatchingMode): + shape = tuple(info.size) + batch_axis = None + elif isinstance(info.batch_or_mode, bm.BatchingMode): + shape = (info.batch_or_mode.batch_size,) + tuple(info.size) + batch_axis = 0 + else: + shape = tuple(info.size) + batch_axis = None + + # init + if isinstance(info.data, Callable): + init = info.data(shape) + elif isinstance(info.data, (bm.BaseArray, jax.Array)): + init = info.data + else: + raise TypeError + assert init.shape == shape + + # axis names + if info.axis_names is not None: + assert init.ndim == len(info.axis_names) + + # variable + target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names) + return DataDelay(target, data_init=info.data, init=initial_delay_data) + else: + raise TypeError def register_delay_by_return(target: JointType[DynamicalSystem, SupportAutoDelay]): - """Register delay class for the given target. - - Args: - target: The target class to register delay. - - Returns: - The delay registered for the given target. - """ - if not target.has_aft_update(delay_identifier): - delay_ins = init_delay_by_return(target.return_info()) - target.add_aft_update(delay_identifier, delay_ins) - delay_cls = target.get_aft_update(delay_identifier) - return delay_cls + """Register delay class for the given target. + + Args: + target: The target class to register delay. + + Returns: + The delay registered for the given target. 
+ """ + if not target.has_aft_update(delay_identifier): + delay_ins = init_delay_by_return(target.return_info()) + target.add_aft_update(delay_identifier, delay_ins) + delay_cls = target.get_aft_update(delay_identifier) + return delay_cls diff --git a/brainpy/_src/deprecations.py b/brainpy/_src/deprecations.py index 74a0103da..b13aa80ac 100644 --- a/brainpy/_src/deprecations.py +++ b/brainpy/_src/deprecations.py @@ -1,13 +1,12 @@ -import warnings import functools +import warnings __all__ = [ - 'deprecated', - 'deprecation_getattr', - 'deprecation_getattr2', + 'deprecated', + 'deprecation_getattr', + 'deprecation_getattr2', ] - _update_deprecate_msg = ''' From brainpy>=2.4.3, update() function no longer needs to receive a global shared argument. @@ -24,7 +23,6 @@ def update(self, *args, **kwagrs): ... ''' - _input_deprecate_msg = ''' From brainpy>=2.4.3, input() and monitor() function no longer needs to receive a global shared argument. @@ -42,52 +40,52 @@ def f_input_or_monitor(): def _deprecate(msg): - warnings.simplefilter('always', DeprecationWarning) # turn off filter - warnings.warn(msg, category=DeprecationWarning, stacklevel=2) - warnings.simplefilter('default', DeprecationWarning) # reset filter + warnings.simplefilter('always', DeprecationWarning) # turn off filter + warnings.warn(msg, category=DeprecationWarning, stacklevel=2) + warnings.simplefilter('default', DeprecationWarning) # reset filter def deprecated(func): - """This is a decorator which can be used to mark functions - as deprecated. It will result in a warning being emitted - when the function is used.""" + """This is a decorator which can be used to mark functions + as deprecated. 
It will result in a warning being emitted + when the function is used.""" - @functools.wraps(func) - def new_func(*args, **kwargs): - _deprecate("Call to deprecated function {}.".format(func.__name__)) - return func(*args, **kwargs) + @functools.wraps(func) + def new_func(*args, **kwargs): + _deprecate("Call to deprecated function {}.".format(func.__name__)) + return func(*args, **kwargs) - return new_func + return new_func def deprecation_getattr(module, deprecations, redirects=None, redirect_module=None): - redirects = redirects or {} + redirects = redirects or {} - def get_attr(name): - if name in deprecations: - message, fn = deprecations[name] - if fn is None: - raise AttributeError(message) - _deprecate(message) - return fn - if name in redirects: - return getattr(redirect_module, name) - raise AttributeError(f"module {module!r} has no attribute {name!r}") + def get_attr(name): + if name in deprecations: + message, fn = deprecations[name] + if fn is None: + raise AttributeError(message) + _deprecate(message) + return fn + if name in redirects: + return getattr(redirect_module, name) + raise AttributeError(f"module {module!r} has no attribute {name!r}") - return get_attr + return get_attr def deprecation_getattr2(module, deprecations): - def get_attr(name): - if name in deprecations: - old_name, new_name, fn = deprecations[name] - message = f"{old_name} is deprecated. " - if new_name is not None: - message += f'Use {new_name} instead.' - if fn is None: - raise AttributeError(message) - _deprecate(message) - return fn - raise AttributeError(f"module {module!r} has no attribute {name!r}") - - return get_attr + def get_attr(name): + if name in deprecations: + old_name, new_name, fn = deprecations[name] + message = f"{old_name} is deprecated. " + if new_name is not None: + message += f'Use {new_name} instead.' 
+ if fn is None: + raise AttributeError(message) + _deprecate(message) + return fn + raise AttributeError(f"module {module!r} has no attribute {name!r}") + + return get_attr diff --git a/brainpy/_src/dnn/__init__.py b/brainpy/_src/dnn/__init__.py index ae2e425ab..621f4735c 100644 --- a/brainpy/_src/dnn/__init__.py +++ b/brainpy/_src/dnn/__init__.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- -from .base import * from .activations import * -from .dropout import * +from .base import * from .conv import * -from .normalization import * -from .pooling import * +from .dropout import * from .function import * from .interoperation_flax import * +from .normalization import * +from .pooling import * diff --git a/brainpy/_src/dnn/activations.py b/brainpy/_src/dnn/activations.py index aa0319a3e..f8f292630 100644 --- a/brainpy/_src/dnn/activations.py +++ b/brainpy/_src/dnn/activations.py @@ -5,1170 +5,1170 @@ from brainpy.types import ArrayType __all__ = [ - 'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh', - 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU', - 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU', 'Softsign', 'Tanhshrink', - 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax' + 'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid', 'Tanh', + 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU', 'Hardshrink', 'LeakyReLU', + 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU', 'Softsign', 'Tanhshrink', + 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax' ] def _inplace(inp, val, inplace): - if inplace: - assert isinstance(inp, bm.BaseArray), 'input must be instance of brainpy.math.Array if inplace=True' - inp.value = val - return inp - else: - return val + if inplace: + assert isinstance(inp, bm.BaseArray), 'input must be instance of brainpy.math.Array if inplace=True' + inp.value = val + return inp + else: + return val class Threshold(Layer): - r"""Thresholds each 
element of the input Tensor. + r"""Thresholds each element of the input Tensor. - Threshold is defined as: + Threshold is defined as: - .. math:: - y = - \begin{cases} - x, &\text{ if } x > \text{threshold} \\ - \text{value}, &\text{ otherwise } - \end{cases} + .. math:: + y = + \begin{cases} + x, &\text{ if } x > \text{threshold} \\ + \text{value}, &\text{ otherwise } + \end{cases} - Args: - threshold: The value to threshold at - value: The value to replace with - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + threshold: The value to threshold at + value: The value to replace with + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Threshold(0.1, 20) - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['threshold', 'value', 'inplace'] + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Threshold(0.1, 20) + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['threshold', 'value', 'inplace'] - threshold: float - value: float - inplace: bool + threshold: float + value: float + inplace: bool - def __init__(self, threshold: float, value: float, inplace: bool = False) -> None: - super().__init__() - self.threshold = threshold - self.value = value - self.inplace = inplace + def __init__(self, threshold: float, value: float, inplace: bool = False) -> None: + super().__init__() + self.threshold = threshold + self.value = value + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - v = bm.where(input > self.threshold, input, self.value) - return _inplace(input, 
v, self.inplace) + def update(self, input: ArrayType) -> ArrayType: + v = bm.where(input > self.threshold, input, self.value) + return _inplace(input, v, self.inplace) - def extra_repr(self): - inplace_str = ', inplace=True' if self.inplace else '' - return 'threshold={}, value={}{}'.format( - self.threshold, self.value, inplace_str - ) + def extra_repr(self): + inplace_str = ', inplace=True' if self.inplace else '' + return 'threshold={}, value={}{}'.format( + self.threshold, self.value, inplace_str + ) class ReLU(Layer): - r"""Applies the rectified linear unit function element-wise: + r"""Applies the rectified linear unit function element-wise: - :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` + :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. 
- Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.ReLU() - >>> input = bm.random.randn(2) - >>> output = m(input) + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.ReLU() + >>> input = bm.random.randn(2) + >>> output = m(input) - An implementation of CReLU - https://arxiv.org/abs/1603.05201 + An implementation of CReLU - https://arxiv.org/abs/1603.05201 - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.ReLU() - >>> input = bm.random.randn(2).unsqueeze(0) - >>> output = bm.cat((m(input), m(-input))) - """ - __constants__ = ['inplace'] - inplace: bool + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.ReLU() + >>> input = bm.random.randn(2).unsqueeze(0) + >>> output = bm.cat((m(input), m(-input))) + """ + __constants__ = ['inplace'] + inplace: bool - def __init__(self, inplace: bool = False): - super().__init__() - self.inplace = inplace + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - v = bm.relu(input) - return _inplace(input, v, self.inplace) + def update(self, input: ArrayType) -> ArrayType: + v = bm.relu(input) + return _inplace(input, v, self.inplace) - def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' - return inplace_str + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str class RReLU(Layer): - r"""Applies the randomized leaky rectified liner unit function, element-wise, - as described in the paper: + r"""Applies the randomized leaky rectified liner unit function, element-wise, + as described in the paper: - `Empirical Evaluation of Rectified Activations in Convolutional Network`_. + `Empirical Evaluation of Rectified Activations in Convolutional Network`_. - The function is defined as: + The function is defined as: - .. 
math:: - \text{RReLU}(x) = - \begin{cases} - x & \text{if } x \geq 0 \\ - ax & \text{ otherwise } - \end{cases} + .. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} - where :math:`a` is randomly sampled from uniform distribution - :math:`\mathcal{U}(\text{lower}, \text{upper})`. + where :math:`a` is randomly sampled from uniform distribution + :math:`\mathcal{U}(\text{lower}, \text{upper})`. - See: https://arxiv.org/pdf/1505.00853.pdf + See: https://arxiv.org/pdf/1505.00853.pdf - Args: - lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` - upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` + upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.RReLU(0.1, 0.3) - >>> input = bm.random.randn(2) - >>> output = m(input) + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.RReLU(0.1, 0.3) + >>> input = bm.random.randn(2) + >>> output = m(input) - .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`: - https://arxiv.org/abs/1505.00853 - """ - __constants__ = ['lower', 'upper', 'inplace'] + .. 
_`Empirical Evaluation of Rectified Activations in Convolutional Network`: + https://arxiv.org/abs/1505.00853 + """ + __constants__ = ['lower', 'upper', 'inplace'] - lower: float - upper: float - inplace: bool + lower: float + upper: float + inplace: bool - def __init__( - self, - lower: float = 1. / 8, - upper: float = 1. / 3, - inplace: bool = False - ): - super().__init__() - self.lower = lower - self.upper = upper - self.inplace = inplace + def __init__( + self, + lower: float = 1. / 8, + upper: float = 1. / 3, + inplace: bool = False + ): + super().__init__() + self.lower = lower + self.upper = upper + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - x = bm.rrelu(input, self.lower, self.upper) - return _inplace(input, x, self.inplace) + def update(self, input: ArrayType) -> ArrayType: + x = bm.rrelu(input, self.lower, self.upper) + return _inplace(input, x, self.inplace) - def extra_repr(self): - inplace_str = ', inplace=True' if self.inplace else '' - return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str) + def extra_repr(self): + inplace_str = ', inplace=True' if self.inplace else '' + return 'lower={}, upper={}{}'.format(self.lower, self.upper, inplace_str) class Hardtanh(Layer): - r"""Applies the HardTanh function element-wise. - - HardTanh is defined as: - - .. math:: - \text{HardTanh}(x) = \begin{cases} - \text{max\_val} & \text{ if } x > \text{ max\_val } \\ - \text{min\_val} & \text{ if } x < \text{ min\_val } \\ - x & \text{ otherwise } \\ - \end{cases} - - Args: - min_val: minimum value of the linear region range. Default: -1 - max_val: maximum value of the linear region range. Default: 1 - inplace: can optionally do the operation in-place. Default: ``False`` - - Keyword arguments :attr:`min_value` and :attr:`max_value` - have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. - - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. 
- - Output: :math:`(*)`, same shape as the input. - - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Hardtanh(-2, 2) - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['min_val', 'max_val', 'inplace'] - - min_val: float - max_val: float - inplace: bool - - def __init__( - self, - min_val: float = -1., - max_val: float = 1., - inplace: bool = False, - ) -> None: - super().__init__() - self.min_val = min_val - self.max_val = max_val - self.inplace = inplace - assert self.max_val > self.min_val - - def update(self, input: ArrayType) -> ArrayType: - x = bm.hard_tanh(input, self.min_val, self.max_val) - return _inplace(input, x, self.inplace) - - def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return 'min_val={}, max_val={}{}'.format( - self.min_val, self.max_val, inplace_str - ) + r"""Applies the HardTanh function element-wise. + + HardTanh is defined as: + + .. math:: + \text{HardTanh}(x) = \begin{cases} + \text{max\_val} & \text{ if } x > \text{ max\_val } \\ + \text{min\_val} & \text{ if } x < \text{ min\_val } \\ + x & \text{ otherwise } \\ + \end{cases} + + Args: + min_val: minimum value of the linear region range. Default: -1 + max_val: maximum value of the linear region range. Default: 1 + inplace: can optionally do the operation in-place. Default: ``False`` + + Keyword arguments :attr:`min_value` and :attr:`max_value` + have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. 
+ + Examples:: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Hardtanh(-2, 2) + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['min_val', 'max_val', 'inplace'] + + min_val: float + max_val: float + inplace: bool + + def __init__( + self, + min_val: float = -1., + max_val: float = 1., + inplace: bool = False, + ) -> None: + super().__init__() + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + assert self.max_val > self.min_val + + def update(self, input: ArrayType) -> ArrayType: + x = bm.hard_tanh(input, self.min_val, self.max_val) + return _inplace(input, x, self.inplace) + + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'min_val={}, max_val={}{}'.format( + self.min_val, self.max_val, inplace_str + ) class ReLU6(Hardtanh): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{ReLU6}(x) = \min(\max(0,x), 6) + .. math:: + \text{ReLU6}(x) = \min(\max(0,x), 6) - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. 
- Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.test_ReLU6() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.test_ReLU6() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ - def __init__(self, inplace: bool = False): - super().__init__(0., 6., inplace) + def __init__(self, inplace: bool = False): + super().__init__(0., 6., inplace) - def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' - return inplace_str + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str class Sigmoid(Layer): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} + .. math:: + \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Sigmoid() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Sigmoid() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ - def update(self, input: ArrayType) -> ArrayType: - return bm.sigmoid(input) + def update(self, input: ArrayType) -> ArrayType: + return bm.sigmoid(input) class Hardsigmoid(Layer): - r"""Applies the Hardsigmoid function element-wise. + r"""Applies the Hardsigmoid function element-wise. - Hardsigmoid is defined as: + Hardsigmoid is defined as: - .. 
math:: - \text{Hardsigmoid}(x) = \begin{cases} - 0 & \text{if~} x \le -3, \\ - 1 & \text{if~} x \ge +3, \\ - x / 6 + 1 / 2 & \text{otherwise} - \end{cases} + .. math:: + \text{Hardsigmoid}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + 1 & \text{if~} x \ge +3, \\ + x / 6 + 1 / 2 & \text{otherwise} + \end{cases} - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Hardsigmoid() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['inplace'] + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Hardsigmoid() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace'] - inplace: bool + inplace: bool - def __init__(self, inplace: bool = False) -> None: - super().__init__() - self.inplace = inplace + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - x = bm.hard_sigmoid(input) - return _inplace(input, x, self.inplace) + def update(self, input: ArrayType) -> ArrayType: + x = bm.hard_sigmoid(input) + return _inplace(input, x, self.inplace) class Tanh(Layer): - r"""Applies the Hyperbolic Tangent (Tanh) function element-wise. + r"""Applies the Hyperbolic Tangent (Tanh) function element-wise. - Tanh is defined as: + Tanh is defined as: - .. math:: - \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)} + .. 
math:: + \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)} - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Tanh() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Tanh() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ - def update(self, input: ArrayType) -> ArrayType: - return bm.tanh(input) + def update(self, input: ArrayType) -> ArrayType: + return bm.tanh(input) class SiLU(Layer): - r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise. - The SiLU function is also known as the swish function. + r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise. + The SiLU function is also known as the swish function. - .. math:: - \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} + .. math:: + \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.} - .. note:: - See `Gaussian Error Linear Units (GELUs) `_ - where the SiLU (Sigmoid Linear Unit) was originally coined, and see - `Sigmoid-Weighted Linear Units for Neural Network Function Approximation - in Reinforcement Learning `_ and `Swish: - a Self-Gated Activation Function `_ - where the SiLU was experimented with later. - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + .. 
note:: + See `Gaussian Error Linear Units (GELUs) `_ + where the SiLU (Sigmoid Linear Unit) was originally coined, and see + `Sigmoid-Weighted Linear Units for Neural Network Function Approximation + in Reinforcement Learning `_ and `Swish: + a Self-Gated Activation Function `_ + where the SiLU was experimented with later. + Args: + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.SiLU() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['inplace'] - inplace: bool + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.SiLU() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace'] + inplace: bool - def __init__(self, inplace: bool = False): - super().__init__() - self.inplace = inplace + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - return _inplace(input, bm.silu(input), self.inplace) + def update(self, input: ArrayType) -> ArrayType: + return _inplace(input, bm.silu(input), self.inplace) - def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' - return inplace_str + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str class Mish(Layer): - r"""Applies the Mish function, element-wise. - Mish: A Self Regularized Non-Monotonic Neural Activation Function. + r"""Applies the Mish function, element-wise. + Mish: A Self Regularized Non-Monotonic Neural Activation Function. - .. 
math:: - \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) - .. note:: - See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Mish() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['inplace'] - inplace: bool + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Mish() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace'] + inplace: bool - def __init__(self, inplace: bool = False): - super().__init__() - self.inplace = inplace + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - return _inplace(input, bm.mish(input), inplace=self.inplace) + def update(self, input: ArrayType) -> ArrayType: + return _inplace(input, bm.mish(input), inplace=self.inplace) - def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' - return inplace_str + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str class Hardswish(Layer): - r"""Applies the Hardswish function, element-wise, as described in the paper: - `Searching for MobileNetV3 `_. 
+ r"""Applies the Hardswish function, element-wise, as described in the paper: + `Searching for MobileNetV3 `_. - Hardswish is defined as: + Hardswish is defined as: - .. math:: - \text{Hardswish}(x) = \begin{cases} - 0 & \text{if~} x \le -3, \\ - x & \text{if~} x \ge +3, \\ - x \cdot (x + 3) /6 & \text{otherwise} - \end{cases} + .. math:: + \text{Hardswish}(x) = \begin{cases} + 0 & \text{if~} x \le -3, \\ + x & \text{if~} x \ge +3, \\ + x \cdot (x + 3) /6 & \text{otherwise} + \end{cases} - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Hardswish() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['inplace'] + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Hardswish() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace'] - inplace: bool + inplace: bool - def __init__(self, inplace: bool = False) -> None: - super().__init__() - self.inplace = inplace + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - return _inplace(input, bm.hard_swish(input), self.inplace) + def update(self, input: ArrayType) -> ArrayType: + return _inplace(input, bm.hard_swish(input), self.inplace) class ELU(Layer): - r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described - in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear - Units (ELUs) `__. 
+ r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described + in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear + Units (ELUs) `__. - ELU is defined as: + ELU is defined as: - .. math:: - \text{ELU}(x) = \begin{cases} - x, & \text{ if } x > 0\\ - \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 - \end{cases} + .. math:: + \text{ELU}(x) = \begin{cases} + x, & \text{ if } x > 0\\ + \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 + \end{cases} - Args: - alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. 
- Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.ELU() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['alpha', 'inplace'] - alpha: float - inplace: bool + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.ELU() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['alpha', 'inplace'] + alpha: float + inplace: bool - def __init__(self, alpha: float = 1., inplace: bool = False) -> None: - super().__init__() - self.alpha = alpha - self.inplace = inplace + def __init__(self, alpha: float = 1., inplace: bool = False) -> None: + super().__init__() + self.alpha = alpha + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - return _inplace(input, bm.elu(input, self.alpha), self.inplace) + def update(self, input: ArrayType) -> ArrayType: + return _inplace(input, bm.elu(input, self.alpha), self.inplace) - def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return 'alpha={}{}'.format(self.alpha, inplace_str) + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'alpha={}{}'.format(self.alpha, inplace_str) class CELU(Layer): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) + .. math:: + \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1)) - More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . + More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . - Args: - alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + alpha: the :math:`\alpha` value for the CELU formulation. 
Default: 1.0 + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.CELU() - >>> input = bm.random.randn(2) - >>> output = m(input) + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.CELU() + >>> input = bm.random.randn(2) + >>> output = m(input) - .. _`Continuously Differentiable Exponential Linear Units`: - https://arxiv.org/abs/1704.07483 - """ - __constants__ = ['alpha', 'inplace'] - alpha: float - inplace: bool + .. _`Continuously Differentiable Exponential Linear Units`: + https://arxiv.org/abs/1704.07483 + """ + __constants__ = ['alpha', 'inplace'] + alpha: float + inplace: bool - def __init__(self, alpha: float = 1., inplace: bool = False) -> None: - super().__init__() - self.alpha = alpha - self.inplace = inplace + def __init__(self, alpha: float = 1., inplace: bool = False) -> None: + super().__init__() + self.alpha = alpha + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - return _inplace(input, bm.celu(input, self.alpha), self.inplace) + def update(self, input: ArrayType) -> ArrayType: + return _inplace(input, bm.celu(input, self.alpha), self.inplace) - def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return 'alpha={}{}'.format(self.alpha, inplace_str) + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'alpha={}{}'.format(self.alpha, inplace_str) class SELU(Layer): - r"""Applied element-wise, as: + r"""Applied element-wise, as: - .. math:: - \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) + .. 
math:: + \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1))) - with :math:`\alpha = 1.6732632423543772848170429916717` and - :math:`\text{scale} = 1.0507009873554804934193349852946`. + with :math:`\alpha = 1.6732632423543772848170429916717` and + :math:`\text{scale} = 1.0507009873554804934193349852946`. - More details can be found in the paper `Self-Normalizing Neural Networks`_ . + More details can be found in the paper `Self-Normalizing Neural Networks`_ . - Args: - inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` + Args: + inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.SELU() - >>> input = bm.random.randn(2) - >>> output = m(input) + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.SELU() + >>> input = bm.random.randn(2) + >>> output = m(input) - .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 - """ - __constants__ = ['inplace'] - inplace: bool + .. 
_Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515 + """ + __constants__ = ['inplace'] + inplace: bool - def __init__(self, inplace: bool = False) -> None: - super().__init__() - self.inplace = inplace + def __init__(self, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - return _inplace(input, bm.selu(input), self.inplace) + def update(self, input: ArrayType) -> ArrayType: + return _inplace(input, bm.selu(input), self.inplace) - def extra_repr(self) -> str: - inplace_str = 'inplace=True' if self.inplace else '' - return inplace_str + def extra_repr(self) -> str: + inplace_str = 'inplace=True' if self.inplace else '' + return inplace_str class GLU(Layer): - r"""Applies the gated linear unit function - :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half - of the input matrices and :math:`b` is the second half. + r"""Applies the gated linear unit function + :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half + of the input matrices and :math:`b` is the second half. - Args: - dim (int): the dimension on which to split the input. Default: -1 + Args: + dim (int): the dimension on which to split the input. 
Default: -1 - Shape: - - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional - dimensions - - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + Shape: + - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional + dimensions + - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.GLU() - >>> input = bm.random.randn(4, 2) - >>> output = m(input) - """ - __constants__ = ['dim'] - dim: int + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.GLU() + >>> input = bm.random.randn(4, 2) + >>> output = m(input) + """ + __constants__ = ['dim'] + dim: int - def __init__(self, dim: int = -1) -> None: - super().__init__() - self.dim = dim + def __init__(self, dim: int = -1) -> None: + super().__init__() + self.dim = dim - def update(self, input: ArrayType) -> ArrayType: - return bm.glu(input, self.dim) + def update(self, input: ArrayType) -> ArrayType: + return bm.glu(input, self.dim) - def extra_repr(self) -> str: - return 'dim={}'.format(self.dim) + def extra_repr(self) -> str: + return 'dim={}'.format(self.dim) class GELU(Layer): - r"""Applies the Gaussian Error Linear Units function: + r"""Applies the Gaussian Error Linear Units function: - .. math:: \text{GELU}(x) = x * \Phi(x) + .. math:: \text{GELU}(x) = x * \Phi(x) - where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. - When the approximate argument is 'tanh', Gelu is estimated with: + When the approximate argument is 'tanh', Gelu is estimated with: - .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) + .. 
math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) - Args: - approximate (str, optional): the gelu approximation algorithm to use: - ``'none'`` | ``'tanh'``. Default: ``'none'`` + Args: + approximate (str, optional): the gelu approximation algorithm to use: + ``'none'`` | ``'tanh'``. Default: ``'none'`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.GELU() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['approximate'] - approximate: bool + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.GELU() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['approximate'] + approximate: bool - def __init__(self, approximate: bool = False) -> None: - super().__init__() - self.approximate = approximate + def __init__(self, approximate: bool = False) -> None: + super().__init__() + self.approximate = approximate - def update(self, input: ArrayType) -> ArrayType: - return bm.gelu(input, approximate=self.approximate) + def update(self, input: ArrayType) -> ArrayType: + return bm.gelu(input, approximate=self.approximate) - def extra_repr(self) -> str: - return 'approximate={}'.format(repr(self.approximate)) + def extra_repr(self) -> str: + return 'approximate={}'.format(repr(self.approximate)) class Hardshrink(Layer): - r"""Applies the Hard Shrinkage (Hardshrink) function element-wise. + r"""Applies the Hard Shrinkage (Hardshrink) function element-wise. - Hardshrink is defined as: + Hardshrink is defined as: - .. 
math:: - \text{HardShrink}(x) = - \begin{cases} - x, & \text{ if } x > \lambda \\ - x, & \text{ if } x < -\lambda \\ - 0, & \text{ otherwise } - \end{cases} + .. math:: + \text{HardShrink}(x) = + \begin{cases} + x, & \text{ if } x > \lambda \\ + x, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} - Args: - lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + Args: + lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Hardshrink() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['lambd'] - lambd: float + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Hardshrink() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['lambd'] + lambd: float - def __init__(self, lambd: float = 0.5) -> None: - super().__init__() - self.lambd = lambd + def __init__(self, lambd: float = 0.5) -> None: + super().__init__() + self.lambd = lambd - def update(self, input: ArrayType) -> ArrayType: - return bm.hard_shrink(input, self.lambd) + def update(self, input: ArrayType) -> ArrayType: + return bm.hard_shrink(input, self.lambd) - def extra_repr(self) -> str: - return '{}'.format(self.lambd) + def extra_repr(self) -> str: + return '{}'.format(self.lambd) class LeakyReLU(Layer): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) + .. math:: + \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x) - or + or - .. 
math:: - \text{LeakyReLU}(x) = - \begin{cases} - x, & \text{ if } x \geq 0 \\ - \text{negative\_slope} \times x, & \text{ otherwise } - \end{cases} + .. math:: + \text{LeakyReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + \text{negative\_slope} \times x, & \text{ otherwise } + \end{cases} - Args: - negative_slope: Controls the angle of the negative slope (which is used for - negative input values). Default: 1e-2 - inplace: can optionally do the operation in-place. Default: ``False`` + Args: + negative_slope: Controls the angle of the negative slope (which is used for + negative input values). Default: 1e-2 + inplace: can optionally do the operation in-place. Default: ``False`` - Shape: - - Input: :math:`(*)` where `*` means, any number of additional - dimensions - - Output: :math:`(*)`, same shape as the input + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.LeakyReLU(0.1) - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['inplace', 'negative_slope'] - inplace: bool - negative_slope: float + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.LeakyReLU(0.1) + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['inplace', 'negative_slope'] + inplace: bool + negative_slope: float - def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: - super().__init__() - self.negative_slope = negative_slope - self.inplace = inplace + def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None: + super().__init__() + self.negative_slope = negative_slope + self.inplace = inplace - def update(self, input: ArrayType) -> ArrayType: - return _inplace(input, bm.leaky_relu(input, self.negative_slope), self.inplace) + def update(self, input: ArrayType) -> ArrayType: 
+ return _inplace(input, bm.leaky_relu(input, self.negative_slope), self.inplace) - def extra_repr(self) -> str: - inplace_str = ', inplace=True' if self.inplace else '' - return 'negative_slope={}{}'.format(self.negative_slope, inplace_str) + def extra_repr(self) -> str: + inplace_str = ', inplace=True' if self.inplace else '' + return 'negative_slope={}{}'.format(self.negative_slope, inplace_str) class LogSigmoid(Layer): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right) + .. math:: + \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right) - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.LogSigmoid() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.LogSigmoid() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ - def update(self, input: ArrayType) -> ArrayType: - return bm.log_sigmoid(input) + def update(self, input: ArrayType) -> ArrayType: + return bm.log_sigmoid(input) class Softplus(Layer): - r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} * - \log(1 + \exp(\beta * x))` element-wise. + r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} * + \log(1 + \exp(\beta * x))` element-wise. - SoftPlus is a smooth approximation to the ReLU function and can be used - to constrain the output of a machine to always be positive. + SoftPlus is a smooth approximation to the ReLU function and can be used + to constrain the output of a machine to always be positive. 
- For numerical stability the implementation reverts to the linear function - when :math:`input \times \beta > threshold`. + For numerical stability the implementation reverts to the linear function + when :math:`input \times \beta > threshold`. - Args: - beta: the :math:`\beta` value for the Softplus formulation. Default: 1 - threshold: values above this revert to a linear function. Default: 20 + Args: + beta: the :math:`\beta` value for the Softplus formulation. Default: 1 + threshold: values above this revert to a linear function. Default: 20 - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Softplus() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['beta', 'threshold'] - beta: float - threshold: float + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Softplus() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['beta', 'threshold'] + beta: float + threshold: float - def __init__(self, beta: float = 1, threshold: float = 20.) -> None: - super().__init__() - self.beta = beta - self.threshold = threshold + def __init__(self, beta: float = 1, threshold: float = 20.) 
-> None: + super().__init__() + self.beta = beta + self.threshold = threshold - def update(self, x: ArrayType) -> ArrayType: - return bm.softplus(x, self.beta, self.threshold) + def update(self, x: ArrayType) -> ArrayType: + return bm.softplus(x, self.beta, self.threshold) - def extra_repr(self) -> str: - return 'beta={}, threshold={}'.format(self.beta, self.threshold) + def extra_repr(self) -> str: + return 'beta={}, threshold={}'.format(self.beta, self.threshold) class Softshrink(Layer): - r"""Applies the soft shrinkage function elementwise: + r"""Applies the soft shrinkage function elementwise: - .. math:: - \text{SoftShrinkage}(x) = - \begin{cases} - x - \lambda, & \text{ if } x > \lambda \\ - x + \lambda, & \text{ if } x < -\lambda \\ - 0, & \text{ otherwise } - \end{cases} + .. math:: + \text{SoftShrinkage}(x) = + \begin{cases} + x - \lambda, & \text{ if } x > \lambda \\ + x + \lambda, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} - Args: - lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + Args: + lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. 
- Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Softshrink() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['lambd'] - lambd: float + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Softshrink() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['lambd'] + lambd: float - def __init__(self, lambd: float = 0.5) -> None: - super().__init__() - self.lambd = lambd + def __init__(self, lambd: float = 0.5) -> None: + super().__init__() + self.lambd = lambd - def update(self, input: ArrayType) -> ArrayType: - return bm.soft_shrink(input, self.lambd) + def update(self, input: ArrayType) -> ArrayType: + return bm.soft_shrink(input, self.lambd) - def extra_repr(self) -> str: - return str(self.lambd) + def extra_repr(self) -> str: + return str(self.lambd) class PReLU(Layer): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{PReLU}(x) = \max(0,x) + a * \min(0,x) + .. math:: + \text{PReLU}(x) = \max(0,x) + a * \min(0,x) - or + or - .. math:: - \text{PReLU}(x) = - \begin{cases} - x, & \text{ if } x \geq 0 \\ - ax, & \text{ otherwise } - \end{cases} + .. math:: + \text{PReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + ax, & \text{ otherwise } + \end{cases} - Here :math:`a` is a learnable parameter. When called without arguments, `bp.dnn.PReLU()` uses a single - parameter :math:`a` across all input channels. If called with `bp.dnn.PReLU(nChannels)`, - a separate :math:`a` is used for each input channel. + Here :math:`a` is a learnable parameter. When called without arguments, `bp.dnn.PReLU()` uses a single + parameter :math:`a` across all input channels. If called with `bp.dnn.PReLU(nChannels)`, + a separate :math:`a` is used for each input channel. - .. note:: - weight decay should not be used when learning :math:`a` for good performance. + .. 
note:: + weight decay should not be used when learning :math:`a` for good performance. - .. note:: - Channel dim is the 2nd dim of input. When input has dims < 2, then there is - no channel dim and the number of channels = 1. + .. note:: + Channel dim is the 2nd dim of input. When input has dims < 2, then there is + no channel dim and the number of channels = 1. - Args: - num_parameters (int): number of :math:`a` to learn. - Although it takes an int as input, there is only two values are legitimate: - 1, or the number of channels at input. Default: 1 - init (float): the initial value of :math:`a`. Default: 0.25 + Args: + num_parameters (int): number of :math:`a` to learn. + Although it takes an int as input, there is only two values are legitimate: + 1, or the number of channels at input. Default: 1 + init (float): the initial value of :math:`a`. Default: 0.25 - Shape: - - Input: :math:`( *)` where `*` means, any number of additional - dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`( *)` where `*` means, any number of additional + dimensions. + - Output: :math:`(*)`, same shape as the input. - Attributes: - weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). + Attributes: + weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). 
- Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.PReLU() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ - __constants__ = ['num_parameters'] - num_parameters: int + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.PReLU() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ + __constants__ = ['num_parameters'] + num_parameters: int - def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None: - self.num_parameters = num_parameters - super().__init__() - self.weight = bm.TrainVar(bm.ones(num_parameters, dtype=dtype) * init) + def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None: + self.num_parameters = num_parameters + super().__init__() + self.weight = bm.TrainVar(bm.ones(num_parameters, dtype=dtype) * init) - def update(self, input: ArrayType) -> ArrayType: - return bm.prelu(input, self.weight) + def update(self, input: ArrayType) -> ArrayType: + return bm.prelu(input, self.weight) - def extra_repr(self) -> str: - return 'num_parameters={}'.format(self.num_parameters) + def extra_repr(self) -> str: + return 'num_parameters={}'.format(self.num_parameters) class Softsign(Layer): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{SoftSign}(x) = \frac{x}{ 1 + |x|} + .. math:: + \text{SoftSign}(x) = \frac{x}{ 1 + |x|} - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. 
- Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Softsign() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Softsign() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ - def update(self, input: ArrayType) -> ArrayType: - return bm.soft_sign(input) + def update(self, input: ArrayType) -> ArrayType: + return bm.soft_sign(input) class Tanhshrink(Layer): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{Tanhshrink}(x) = x - \tanh(x) + .. math:: + \text{Tanhshrink}(x) = x - \tanh(x) - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Tanhshrink() - >>> input = bm.random.randn(2) - >>> output = m(input) - """ + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Tanhshrink() + >>> input = bm.random.randn(2) + >>> output = m(input) + """ - def update(self, input: ArrayType) -> ArrayType: - return bm.tanh_shrink(input) + def update(self, input: ArrayType) -> ArrayType: + return bm.tanh_shrink(input) class Softmin(Layer): - r"""Applies the Softmin function to an n-dimensional input Tensor - rescaling them so that the elements of the n-dimensional output Tensor - lie in the range `[0, 1]` and sum to 1. + r"""Applies the Softmin function to an n-dimensional input Tensor + rescaling them so that the elements of the n-dimensional output Tensor + lie in the range `[0, 1]` and sum to 1. - Softmin is defined as: + Softmin is defined as: - .. math:: - \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} + .. 
math:: + \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} - Shape: - - Input: :math:`(*)` where `*` means, any number of additional - dimensions - - Output: :math:`(*)`, same shape as the input + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input - Args: - dim (int): A dimension along which Softmin will be computed (so every slice - along dim will sum to 1). + Args: + dim (int): A dimension along which Softmin will be computed (so every slice + along dim will sum to 1). - Returns: - a Tensor of the same dimension and shape as the input, with - values in the range [0, 1] + Returns: + a Tensor of the same dimension and shape as the input, with + values in the range [0, 1] - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Softmin(dim=1) - >>> input = bm.random.randn(2, 3) - >>> output = m(input) - """ - __constants__ = ['dim'] - dim: Optional[int] + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Softmin(dim=1) + >>> input = bm.random.randn(2, 3) + >>> output = m(input) + """ + __constants__ = ['dim'] + dim: Optional[int] - def __init__(self, dim: Optional[int] = None) -> None: - super().__init__() - self.dim = dim + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim - def update(self, input: ArrayType) -> ArrayType: - return bm.softmin(input, self.dim) + def update(self, input: ArrayType) -> ArrayType: + return bm.softmin(input, self.dim) - def extra_repr(self): - return 'dim={dim}'.format(dim=self.dim) + def extra_repr(self): + return 'dim={dim}'.format(dim=self.dim) class Softmax(Layer): - r"""Applies the Softmax function to an n-dimensional input Tensor - rescaling them so that the elements of the n-dimensional output Tensor - lie in the range [0,1] and sum to 1. 
+ r"""Applies the Softmax function to an n-dimensional input Tensor + rescaling them so that the elements of the n-dimensional output Tensor + lie in the range [0,1] and sum to 1. - Softmax is defined as: + Softmax is defined as: - .. math:: - \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + .. math:: + \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} - When the input Tensor is a sparse tensor then the unspecified - values are treated as ``-inf``. + When the input Tensor is a sparse tensor then the unspecified + values are treated as ``-inf``. - Shape: - - Input: :math:`(*)` where `*` means, any number of additional - dimensions - - Output: :math:`(*)`, same shape as the input + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input - Returns: - a Tensor of the same dimension and shape as the input with - values in the range [0, 1] + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] - Args: - dim (int): A dimension along which Softmax will be computed (so every slice - along dim will sum to 1). + Args: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). - .. note:: - This module doesn't work directly with NLLLoss, - which expects the Log to be computed between the Softmax and itself. - Use `LogSoftmax` instead (it's faster and has better numerical properties). + .. note:: + This module doesn't work directly with NLLLoss, + which expects the Log to be computed between the Softmax and itself. + Use `LogSoftmax` instead (it's faster and has better numerical properties). 
- Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Softmax(dim=1) - >>> input = bm.random.randn(2, 3) - >>> output = m(input) + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Softmax(dim=1) + >>> input = bm.random.randn(2, 3) + >>> output = m(input) - """ - __constants__ = ['dim'] - dim: Optional[int] + """ + __constants__ = ['dim'] + dim: Optional[int] - def __init__(self, dim: Optional[int] = None) -> None: - super().__init__() - self.dim = dim + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim - def update(self, input: ArrayType) -> ArrayType: - return bm.softmax(input, self.dim) + def update(self, input: ArrayType) -> ArrayType: + return bm.softmax(input, self.dim) - def extra_repr(self) -> str: - return 'dim={dim}'.format(dim=self.dim) + def extra_repr(self) -> str: + return 'dim={dim}'.format(dim=self.dim) class Softmax2d(Layer): - r"""Applies SoftMax over features to each spatial location. + r"""Applies SoftMax over features to each spatial location. - When given an image of ``Channels x Height x Width``, it will - apply `Softmax` to each location :math:`(Channels, h_i, w_j)` + When given an image of ``Channels x Height x Width``, it will + apply `Softmax` to each location :math:`(Channels, h_i, w_j)` - Shape: - - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. - - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) + Shape: + - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. 
+ - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) - Returns: - a Tensor of the same dimension and shape as the input with - values in the range [0, 1] + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.Softmax2d() - >>> # you softmax over the 2nd dimension - >>> input = bm.random.randn(2, 3, 12, 13) - >>> output = m(input) - """ + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.Softmax2d() + >>> # you softmax over the 2nd dimension + >>> input = bm.random.randn(2, 3, 12, 13) + >>> output = m(input) + """ - def update(self, input: ArrayType) -> ArrayType: - assert input.ndim == 4 or input.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input' - return bm.softmax(input, -3) + def update(self, input: ArrayType) -> ArrayType: + assert input.ndim == 4 or input.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input' + return bm.softmax(input, -3) class LogSoftmax(Layer): - r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional - input Tensor. The LogSoftmax formulation can be simplified as: + r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional + input Tensor. The LogSoftmax formulation can be simplified as: - .. math:: - \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) + .. math:: + \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) - Shape: - - Input: :math:`(*)` where `*` means, any number of additional - dimensions - - Output: :math:`(*)`, same shape as the input + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input - Args: - dim (int): A dimension along which LogSoftmax will be computed. + Args: + dim (int): A dimension along which LogSoftmax will be computed. 
- Returns: - a Tensor of the same dimension and shape as the input with - values in the range [-inf, 0) + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [-inf, 0) - Examples:: + Examples:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> m = bp.dnn.LogSoftmax(dim=1) - >>> input = bm.random.randn(2, 3) - >>> output = m(input) - """ - __constants__ = ['dim'] - dim: Optional[int] + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> m = bp.dnn.LogSoftmax(dim=1) + >>> input = bm.random.randn(2, 3) + >>> output = m(input) + """ + __constants__ = ['dim'] + dim: Optional[int] - def __init__(self, dim: Optional[int] = None) -> None: - super().__init__() - self.dim = dim + def __init__(self, dim: Optional[int] = None) -> None: + super().__init__() + self.dim = dim - def update(self, input: ArrayType) -> ArrayType: - return bm.log_softmax(input, self.dim) + def update(self, input: ArrayType) -> ArrayType: + return bm.log_softmax(input, self.dim) - def extra_repr(self): - return 'dim={dim}'.format(dim=self.dim) + def extra_repr(self): + return 'dim={dim}'.format(dim=self.dim) diff --git a/brainpy/_src/dnn/base.py b/brainpy/_src/dnn/base.py index 40665956c..d1130dbce 100644 --- a/brainpy/_src/dnn/base.py +++ b/brainpy/_src/dnn/base.py @@ -1,14 +1,12 @@ from brainpy._src.dynsys import DynamicalSystem - __all__ = [ - 'Layer' + 'Layer' ] class Layer(DynamicalSystem): - """Base class for a layer of artificial neural network.""" - - def reset_state(self, *args, **kwargs): - pass + """Base class for a layer of artificial neural network.""" + def reset_state(self, *args, **kwargs): + pass diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index 759b012a0..4ba5508fc 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -10,468 +10,468 @@ from brainpy.types import ArrayType __all__ = [ - 'Conv1d', 'Conv2d', 'Conv3d', - 'Conv1D', 'Conv2D', 'Conv3D', - 'ConvTranspose1d', 
'ConvTranspose2d', 'ConvTranspose3d', + 'Conv1d', 'Conv2d', 'Conv3d', + 'Conv1D', 'Conv2D', 'Conv3D', + 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d', ] def to_dimension_numbers(num_spatial_dims: int, channels_last: bool, transpose: bool) -> lax.ConvDimensionNumbers: - """Create a `lax.ConvDimensionNumbers` for the given inputs.""" - num_dims = num_spatial_dims + 2 - if channels_last: - spatial_dims = tuple(range(1, num_dims - 1)) - image_dn = (0, num_dims - 1) + spatial_dims - else: - spatial_dims = tuple(range(2, num_dims)) - image_dn = (0, 1) + spatial_dims - if transpose: - kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) - else: - kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) - return lax.ConvDimensionNumbers(lhs_spec=image_dn, - rhs_spec=kernel_dn, - out_spec=image_dn) + """Create a `lax.ConvDimensionNumbers` for the given inputs.""" + num_dims = num_spatial_dims + 2 + if channels_last: + spatial_dims = tuple(range(1, num_dims - 1)) + image_dn = (0, num_dims - 1) + spatial_dims + else: + spatial_dims = tuple(range(2, num_dims)) + image_dn = (0, 1) + spatial_dims + if transpose: + kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2)) + else: + kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2)) + return lax.ConvDimensionNumbers(lhs_spec=image_dn, + rhs_spec=kernel_dn, + out_spec=image_dn) class _GeneralConv(Layer): - """Apply a convolution to the inputs. - - Parameters:: - - num_spatial_dims: int - The number of spatial dimensions of the input. - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. - For all other cases, it must be a sequence of integers. - stride: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). 
- padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. - lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Optional, Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. 
- """ - - supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode) - - def __init__( - self, - num_spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - mode: bm.Mode = None, - name: str = None, - ): - super(_GeneralConv, self).__init__(name=name, mode=mode) - - self.num_spatial_dims = num_spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.stride = tools.replicate(stride, num_spatial_dims, 'stride') - self.kernel_size = tools.replicate(kernel_size, num_spatial_dims, 'kernel_size') - self.lhs_dilation = tools.replicate(lhs_dilation, num_spatial_dims, 'lhs_dilation') - self.rhs_dilation = tools.replicate(rhs_dilation, num_spatial_dims, 'rhs_dilation') - self.groups = groups - self.w_initializer = w_initializer - self.b_initializer = b_initializer - self.mask = mask - self.dimension_numbers = to_dimension_numbers(num_spatial_dims, channels_last=True, transpose=False) - - if isinstance(padding, str): - assert padding in ['SAME', 'VALID'] - elif isinstance(padding, int): - padding = tuple((padding, padding) for _ in range(num_spatial_dims)) - elif isinstance(padding, (tuple, list)): - if isinstance(padding[0], int): - padding = (padding,) * num_spatial_dims - elif isinstance(padding[0], (tuple, list)): - if len(padding) == 1: - padding = tuple(padding) * num_spatial_dims + """Apply a convolution to the inputs. + + Parameters:: + + num_spatial_dims: int + The number of spatial dimensions of the input. 
+ in_channels: int + The number of input channels. + out_channels: int + The number of output channels. + kernel_size: int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. + stride: int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding: str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups: int + If specified, divides the input features into groups. default 1. + w_initializer: Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer: Optional, Callable, ArrayType, Initializer + The initializer for the bias. + mask: ArrayType, Optional + The optional mask of the weights. + mode: Mode + The computation mode of the current object. Default it is `training`. + name: str, Optional + The name of the object. 
+ """ + + supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode) + + def __init__( + self, + num_spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + mode: bm.Mode = None, + name: str = None, + ): + super(_GeneralConv, self).__init__(name=name, mode=mode) + + self.num_spatial_dims = num_spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = tools.replicate(stride, num_spatial_dims, 'stride') + self.kernel_size = tools.replicate(kernel_size, num_spatial_dims, 'kernel_size') + self.lhs_dilation = tools.replicate(lhs_dilation, num_spatial_dims, 'lhs_dilation') + self.rhs_dilation = tools.replicate(rhs_dilation, num_spatial_dims, 'rhs_dilation') + self.groups = groups + self.w_initializer = w_initializer + self.b_initializer = b_initializer + self.mask = mask + self.dimension_numbers = to_dimension_numbers(num_spatial_dims, channels_last=True, transpose=False) + + if isinstance(padding, str): + assert padding in ['SAME', 'VALID'] + elif isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(num_spatial_dims)) + elif isinstance(padding, (tuple, list)): + if isinstance(padding[0], int): + padding = (padding,) * num_spatial_dims + elif isinstance(padding[0], (tuple, list)): + if len(padding) == 1: + padding = tuple(padding) * num_spatial_dims + else: + if len(padding) != num_spatial_dims: + raise ValueError(f"Padding {padding} must be a Tuple[int, int], " + f"or sequence of Tuple[int, int] with length 1, " + f"or 
sequence of Tuple[int, int] with length {num_spatial_dims}.") + padding = tuple(padding) else: - if len(padding) != num_spatial_dims: - raise ValueError(f"Padding {padding} must be a Tuple[int, int], " - f"or sequence of Tuple[int, int] with length 1, " - f"or sequence of Tuple[int, int] with length {num_spatial_dims}.") - padding = tuple(padding) - else: - raise ValueError - self.padding = padding - - assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups' - assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups' - - kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels) - bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,) - self.w = parameter(self.w_initializer, kernel_shape, allow_none=False) - self.b = parameter(self.b_initializer, bias_shape, allow_none=True) - if isinstance(self.mode, bm.TrainingMode): - self.w = bm.TrainVar(self.w) - if self.b is not None: - self.b = bm.TrainVar(self.b) - - def _check_input_dim(self, x): - if x.ndim != self.num_spatial_dims + 2 and x.ndim != self.num_spatial_dims + 1: - raise ValueError( - f"expected {self.num_spatial_dims + 2}D or {self.num_spatial_dims + 1}D input (got {x.ndim}D input)") - if self.in_channels != x.shape[-1]: - raise ValueError(f"input channels={x.shape[-1]} needs to have " - f"the same size as in_channels={self.in_channels}.") - - def update(self, x): - self._check_input_dim(x) - nonbatching = False - if x.ndim == self.num_spatial_dims + 1: - nonbatching = True - x = bm.unsqueeze(x, 0) - w = self.w.value - if self.mask is not None: - try: - lax.broadcast_shapes(self.w.shape, self.mask.shape) - except: - raise ValueError(f"Mask needs to have the same shape as weights. 
{self.mask.shape} != {self.w.shape}") - w = w * self.mask - y = lax.conv_general_dilated(lhs=bm.as_jax(x), - rhs=bm.as_jax(w), - window_strides=self.stride, - padding=self.padding, - lhs_dilation=self.lhs_dilation, - rhs_dilation=self.rhs_dilation, - feature_group_count=self.groups, - dimension_numbers=self.dimension_numbers) - if nonbatching: - return y[0] if self.b is None else (y + self.b.value)[0] - else: - return y if self.b is None else (y + self.b.value) + raise ValueError + self.padding = padding + + assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups' + assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups' + + kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels) + bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,) + self.w = parameter(self.w_initializer, kernel_shape, allow_none=False) + self.b = parameter(self.b_initializer, bias_shape, allow_none=True) + if isinstance(self.mode, bm.TrainingMode): + self.w = bm.TrainVar(self.w) + if self.b is not None: + self.b = bm.TrainVar(self.b) + + def _check_input_dim(self, x): + if x.ndim != self.num_spatial_dims + 2 and x.ndim != self.num_spatial_dims + 1: + raise ValueError( + f"expected {self.num_spatial_dims + 2}D or {self.num_spatial_dims + 1}D input (got {x.ndim}D input)") + if self.in_channels != x.shape[-1]: + raise ValueError(f"input channels={x.shape[-1]} needs to have " + f"the same size as in_channels={self.in_channels}.") + + def update(self, x): + self._check_input_dim(x) + nonbatching = False + if x.ndim == self.num_spatial_dims + 1: + nonbatching = True + x = bm.unsqueeze(x, 0) + w = self.w.value + if self.mask is not None: + try: + lax.broadcast_shapes(self.w.shape, self.mask.shape) + except: + raise ValueError(f"Mask needs to have the same shape as weights. 
{self.mask.shape} != {self.w.shape}") + w = w * self.mask + y = lax.conv_general_dilated(lhs=bm.as_jax(x), + rhs=bm.as_jax(w), + window_strides=self.stride, + padding=self.padding, + lhs_dilation=self.lhs_dilation, + rhs_dilation=self.rhs_dilation, + feature_group_count=self.groups, + dimension_numbers=self.dimension_numbers) + if nonbatching: + return y[0] if self.b is None else (y + self.b.value)[0] + else: + return y if self.b is None else (y + self.b.value) - def __repr__(self): - return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' - f'out_channels={self.out_channels}, kernel_size={self.kernel_size}, ' - f'stride={self.stride}, padding={self.padding}, groups={self.groups})') + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, kernel_size={self.kernel_size}, ' + f'stride={self.stride}, padding={self.padding}, groups={self.groups})') class Conv1d(_GeneralConv): - """One-dimensional convolution. - - The input should a 2d array with the shape of ``[H, C]``, or - a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. - - Parameters:: - - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. - For all other cases, it must be a sequence of integers. - strides: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. 
- lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = None, - strides: Union[int, Tuple[int, ...]] = None, # deprecated - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - if stride is None: - if strides is None: - stride = 1 - else: - stride = strides - else: - if strides is not None: - raise ValueError('Cannot provide "stride" and "strides" both.') - - super(Conv1d, self).__init__(num_spatial_dims=1, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - 
stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - mask=mask, - mode=mode, - name=name) - - def _check_input_dim(self, x): - if x.ndim != 3 and x.ndim != 2: - raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)") - if self.in_channels != x.shape[-1]: - raise ValueError(f"input channels={x.shape[-1]} needs to have " - f"the same size as in_channels={self.in_channels}.") + """One-dimensional convolution. + + The input should a 2d array with the shape of ``[H, C]``, or + a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. + + Parameters:: + + in_channels: int + The number of input channels. + out_channels: int + The number of output channels. + kernel_size: int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. + strides: int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding: str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups: int + If specified, divides the input features into groups. default 1. 
+ w_initializer: Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer: Callable, ArrayType, Initializer + The initializer for the bias. + mask: ArrayType, Optional + The optional mask of the weights. + mode: Mode + The computation mode of the current object. Default it is `training`. + name: str, Optional + The name of the object. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = None, + strides: Union[int, Tuple[int, ...]] = None, # deprecated + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + if stride is None: + if strides is None: + stride = 1 + else: + stride = strides + else: + if strides is not None: + raise ValueError('Cannot provide "stride" and "strides" both.') + + super(Conv1d, self).__init__(num_spatial_dims=1, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + mask=mask, + mode=mode, + name=name) + + def _check_input_dim(self, x): + if x.ndim != 3 and x.ndim != 2: + raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)") + if self.in_channels != x.shape[-1]: + raise ValueError(f"input channels={x.shape[-1]} needs to have " + f"the same size as in_channels={self.in_channels}.") class Conv2d(_GeneralConv): - """Two-dimensional convolution. 
- - The input should a 3d array with the shape of ``[H, W, C]``, or - a 4d array with the shape of ``[B, H, W, C]``. - - Parameters:: - - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. - For all other cases, it must be a sequence of integers. - stride: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. - lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. 
- - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = None, - strides: Union[int, Tuple[int, ...]] = None, # deprecated - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - if stride is None: - if strides is None: - stride = 1 - else: - stride = strides - else: - if strides is not None: - raise ValueError('Cannot provide "stride" and "strides" both.') - - super(Conv2d, self).__init__(num_spatial_dims=2, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - mask=mask, - mode=mode, - name=name) - - def _check_input_dim(self, x): - if x.ndim != 4 and x.ndim != 3: - raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)") - if self.in_channels != x.shape[-1]: - raise ValueError(f"input channels={x.shape[-1]} needs to have " - f"the same size as in_channels={self.in_channels}.") + """Two-dimensional convolution. + + The input should a 3d array with the shape of ``[H, W, C]``, or + a 4d array with the shape of ``[B, H, W, C]``. + + Parameters:: + + in_channels: int + The number of input channels. + out_channels: int + The number of output channels. + kernel_size: int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. 
+ stride: int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding: str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups: int + If specified, divides the input features into groups. default 1. + w_initializer: Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer: Callable, ArrayType, Initializer + The initializer for the bias. + mask: ArrayType, Optional + The optional mask of the weights. + mode: Mode + The computation mode of the current object. Default it is `training`. + name: str, Optional + The name of the object. 
+ + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = None, + strides: Union[int, Tuple[int, ...]] = None, # deprecated + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + if stride is None: + if strides is None: + stride = 1 + else: + stride = strides + else: + if strides is not None: + raise ValueError('Cannot provide "stride" and "strides" both.') + + super(Conv2d, self).__init__(num_spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + mask=mask, + mode=mode, + name=name) + + def _check_input_dim(self, x): + if x.ndim != 4 and x.ndim != 3: + raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)") + if self.in_channels != x.shape[-1]: + raise ValueError(f"input channels={x.shape[-1]} needs to have " + f"the same size as in_channels={self.in_channels}.") class Conv3d(_GeneralConv): - """Three-dimensional convolution. - - The input should a 3d array with the shape of ``[H, W, D, C]``, or - a 4d array with the shape of ``[B, H, W, D, C]``. - - Parameters:: - - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. 
- For all other cases, it must be a sequence of integers. - stride: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. - lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. 
- - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = None, - strides: Union[int, Tuple[int, ...]] = None, # deprecated - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - if stride is None: - if strides is None: - stride = 1 - else: - stride = strides - else: - if strides is not None: - raise ValueError('Cannot provide "stride" and "strides" both.') - - super(Conv3d, self).__init__(num_spatial_dims=3, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - mask=mask, - mode=mode, - name=name) - - def _check_input_dim(self, x): - if x.ndim != 5 and x.ndim != 4: - raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)") - if self.in_channels != x.shape[-1]: - raise ValueError(f"input channels={x.shape[-1]} needs to have " - f"the same size as in_channels={self.in_channels}.") + """Three-dimensional convolution. + + The input should a 3d array with the shape of ``[H, W, D, C]``, or + a 4d array with the shape of ``[B, H, W, D, C]``. + + Parameters:: + + in_channels: int + The number of input channels. + out_channels: int + The number of output channels. + kernel_size: int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. 
+ stride: int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding: str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation: int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups: int + If specified, divides the input features into groups. default 1. + w_initializer: Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer: Callable, ArrayType, Initializer + The initializer for the bias. + mask: ArrayType, Optional + The optional mask of the weights. + mode: Mode + The computation mode of the current object. Default it is `training`. + name: str, Optional + The name of the object. 
+ + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = None, + strides: Union[int, Tuple[int, ...]] = None, # deprecated + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + if stride is None: + if strides is None: + stride = 1 + else: + stride = strides + else: + if strides is not None: + raise ValueError('Cannot provide "stride" and "strides" both.') + + super(Conv3d, self).__init__(num_spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + mask=mask, + mode=mode, + name=name) + + def _check_input_dim(self, x): + if x.ndim != 5 and x.ndim != 4: + raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)") + if self.in_channels != x.shape[-1]: + raise ValueError(f"input channels={x.shape[-1]} needs to have " + f"the same size as in_channels={self.in_channels}.") Conv1D = Conv1d @@ -480,265 +480,265 @@ def _check_input_dim(self, x): class _GeneralConvTranspose(Layer): - supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode) - - def __init__( - self, - num_spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - w_initializer: Union[Callable, ArrayType, 
Initializer] = XavierNormal(in_axis=-1, out_axis=-2), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - precision: Optional[lax.Precision] = None, - mode: bm.Mode = None, - name: str = None, - ): - super().__init__(name=name, mode=mode) - - self.num_spatial_dims = num_spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.stride = tools.replicate(stride, num_spatial_dims, 'stride') - self.kernel_size = tools.replicate(kernel_size, num_spatial_dims, 'kernel_size') - self.w_initializer = w_initializer - self.b_initializer = b_initializer - self.precision = precision - self.mask = mask - self.dimension_numbers = to_dimension_numbers(num_spatial_dims, channels_last=True, transpose=False) - - if isinstance(padding, str): - assert padding in ['SAME', 'VALID'] - elif isinstance(padding, int): - padding = tuple((padding, padding) for _ in range(num_spatial_dims)) - elif isinstance(padding, (tuple, list)): - if isinstance(padding[0], int): - padding = (padding,) * num_spatial_dims - elif isinstance(padding[0], (tuple, list)): - if len(padding) == 1: - padding = tuple(padding) * num_spatial_dims + supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode) + + def __init__( + self, + num_spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + precision: Optional[lax.Precision] = None, + mode: bm.Mode = None, + name: str = None, + ): + super().__init__(name=name, mode=mode) + + self.num_spatial_dims = num_spatial_dims + self.in_channels = in_channels + self.out_channels = 
out_channels + self.stride = tools.replicate(stride, num_spatial_dims, 'stride') + self.kernel_size = tools.replicate(kernel_size, num_spatial_dims, 'kernel_size') + self.w_initializer = w_initializer + self.b_initializer = b_initializer + self.precision = precision + self.mask = mask + self.dimension_numbers = to_dimension_numbers(num_spatial_dims, channels_last=True, transpose=False) + + if isinstance(padding, str): + assert padding in ['SAME', 'VALID'] + elif isinstance(padding, int): + padding = tuple((padding, padding) for _ in range(num_spatial_dims)) + elif isinstance(padding, (tuple, list)): + if isinstance(padding[0], int): + padding = (padding,) * num_spatial_dims + elif isinstance(padding[0], (tuple, list)): + if len(padding) == 1: + padding = tuple(padding) * num_spatial_dims + else: + if len(padding) != num_spatial_dims: + raise ValueError(f"Padding {padding} must be a Tuple[int, int], " + f"or sequence of Tuple[int, int] with length 1, " + f"or sequence of Tuple[int, int] with length {num_spatial_dims}.") + padding = tuple(padding) else: - if len(padding) != num_spatial_dims: - raise ValueError(f"Padding {padding} must be a Tuple[int, int], " - f"or sequence of Tuple[int, int] with length 1, " - f"or sequence of Tuple[int, int] with length {num_spatial_dims}.") - padding = tuple(padding) - else: - raise ValueError - self.padding = padding - - kernel_shape = tuple(self.kernel_size) + (self.in_channels, self.out_channels) - bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,) - self.w = parameter(self.w_initializer, kernel_shape, allow_none=False) - self.b = parameter(self.b_initializer, bias_shape, allow_none=True) - if isinstance(self.mode, bm.TrainingMode): - self.w = bm.TrainVar(self.w) - if self.b is not None: - self.b = bm.TrainVar(self.b) - - def _check_input_dim(self, x): - raise NotImplementedError - - def update(self, x): - self._check_input_dim(x) - nonbatching = False - if x.ndim == self.num_spatial_dims + 1: - nonbatching = True 
- x = x.unsqueeze(0) - w = self.w.value - if self.mask is not None: - try: - lax.broadcast_shapes(self.w.shape, self.mask.shape) - except: - raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}") - w = w * self.mask - y = lax.conv_transpose(lhs=bm.as_jax(x), - rhs=bm.as_jax(w), - strides=self.stride, - padding=self.padding, - precision=self.precision, - rhs_dilation=None, - dimension_numbers=self.dimension_numbers) - if nonbatching: - return y[0] if self.b is None else (y + self.b.value)[0] - else: - return y if self.b is None else (y + self.b.value) + raise ValueError + self.padding = padding + + kernel_shape = tuple(self.kernel_size) + (self.in_channels, self.out_channels) + bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,) + self.w = parameter(self.w_initializer, kernel_shape, allow_none=False) + self.b = parameter(self.b_initializer, bias_shape, allow_none=True) + if isinstance(self.mode, bm.TrainingMode): + self.w = bm.TrainVar(self.w) + if self.b is not None: + self.b = bm.TrainVar(self.b) + + def _check_input_dim(self, x): + raise NotImplementedError + + def update(self, x): + self._check_input_dim(x) + nonbatching = False + if x.ndim == self.num_spatial_dims + 1: + nonbatching = True + x = x.unsqueeze(0) + w = self.w.value + if self.mask is not None: + try: + lax.broadcast_shapes(self.w.shape, self.mask.shape) + except: + raise ValueError(f"Mask needs to have the same shape as weights. 
{self.mask.shape} != {self.w.shape}") + w = w * self.mask + y = lax.conv_transpose(lhs=bm.as_jax(x), + rhs=bm.as_jax(w), + strides=self.stride, + padding=self.padding, + precision=self.precision, + rhs_dilation=None, + dimension_numbers=self.dimension_numbers) + if nonbatching: + return y[0] if self.b is None else (y + self.b.value)[0] + else: + return y if self.b is None else (y + self.b.value) - def __repr__(self): - return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' - f'out_channels={self.out_channels}, kernel_size={self.kernel_size}, ' - f'stride={self.stride}, padding={self.padding})') + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' + f'out_channels={self.out_channels}, kernel_size={self.kernel_size}, ' + f'stride={self.stride}, padding={self.padding})') class ConvTranspose1d(_GeneralConvTranspose): - """One dimensional transposed convolution (aka. deconvolution).""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - precision: Optional[lax.Precision] = None, - mode: bm.Mode = None, - name: str = None, - ): - """Initializes the module. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: The shape of the kernel. Either an integer or a sequence of - length 1. - stride: Optional stride for the kernel. Either an integer or a sequence of - length 1. Defaults to 1. - padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. - Defaults to ``SAME``. See: - https://www.tensorflow.org/xla/operation_semantics#conv_convolution. 
- w_initializer: Optional weight initialization. By default, truncated normal. - b_initializer: Optional bias initialization. By default, zeros. - mask: Optional mask of the weights. - name: The name of the module. - """ - super().__init__( - num_spatial_dims=1, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - w_initializer=w_initializer, - b_initializer=b_initializer, - precision=precision, - mode=mode, - mask=mask, - name=name - ) - - def _check_input_dim(self, x): - if x.ndim != 3 and x.ndim != 2: - raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)") - if self.in_channels != x.shape[-1]: - raise ValueError(f"input channels={x.shape[-1]} needs to have " - f"the same size as in_channels={self.in_channels}.") + """One dimensional transposed convolution (aka. deconvolution).""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + precision: Optional[lax.Precision] = None, + mode: bm.Mode = None, + name: str = None, + ): + """Initializes the module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: The shape of the kernel. Either an integer or a sequence of + length 1. + stride: Optional stride for the kernel. Either an integer or a sequence of + length 1. Defaults to 1. + padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. + Defaults to ``SAME``. See: + https://www.tensorflow.org/xla/operation_semantics#conv_convolution. + w_initializer: Optional weight initialization. By default, truncated normal. 
+ b_initializer: Optional bias initialization. By default, zeros. + mask: Optional mask of the weights. + name: The name of the module. + """ + super().__init__( + num_spatial_dims=1, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + w_initializer=w_initializer, + b_initializer=b_initializer, + precision=precision, + mode=mode, + mask=mask, + name=name + ) + + def _check_input_dim(self, x): + if x.ndim != 3 and x.ndim != 2: + raise ValueError(f"expected 3D or 2D input (got {x.ndim}D input)") + if self.in_channels != x.shape[-1]: + raise ValueError(f"input channels={x.shape[-1]} needs to have " + f"the same size as in_channels={self.in_channels}.") class ConvTranspose2d(_GeneralConvTranspose): - """Two dimensional transposed convolution (aka. deconvolution).""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - precision: Optional[lax.Precision] = None, - mode: bm.Mode = None, - name: str = None, - ): - """Initializes the module. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: The shape of the kernel. Either an integer or a sequence of - length 2. - stride: Optional stride for the kernel. Either an integer or a sequence of - length 2. Defaults to 1. - padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. - Defaults to ``SAME``. See: - https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - w_initializer: Optional weight initialization. By default, truncated normal. - b_initializer: Optional bias initialization. 
By default, zeros. - mask: Optional mask of the weights. - name: The name of the module. - """ - super().__init__( - num_spatial_dims=2, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - w_initializer=w_initializer, - b_initializer=b_initializer, - precision=precision, - mode=mode, - mask=mask, - name=name - ) - - def _check_input_dim(self, x): - if x.ndim != 4 and x.ndim != 3: - raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)") - if self.in_channels != x.shape[-1]: - raise ValueError(f"input channels={x.shape[-1]} needs to have " - f"the same size as in_channels={self.in_channels}.") + """Two dimensional transposed convolution (aka. deconvolution).""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + precision: Optional[lax.Precision] = None, + mode: bm.Mode = None, + name: str = None, + ): + """Initializes the module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: The shape of the kernel. Either an integer or a sequence of + length 2. + stride: Optional stride for the kernel. Either an integer or a sequence of + length 2. Defaults to 1. + padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. + Defaults to ``SAME``. See: + https://www.tensorflow.org/xla/operation_semantics#conv_convolution. + w_initializer: Optional weight initialization. By default, truncated normal. + b_initializer: Optional bias initialization. By default, zeros. + mask: Optional mask of the weights. + name: The name of the module. 
+ """ + super().__init__( + num_spatial_dims=2, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + w_initializer=w_initializer, + b_initializer=b_initializer, + precision=precision, + mode=mode, + mask=mask, + name=name + ) + + def _check_input_dim(self, x): + if x.ndim != 4 and x.ndim != 3: + raise ValueError(f"expected 4D or 3D input (got {x.ndim}D input)") + if self.in_channels != x.shape[-1]: + raise ValueError(f"input channels={x.shape[-1]} needs to have " + f"the same size as in_channels={self.in_channels}.") class ConvTranspose3d(_GeneralConvTranspose): - """Three dimensional transposed convolution (aka. deconvolution).""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, ...]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - mask: Optional[ArrayType] = None, - precision: Optional[lax.Precision] = None, - mode: bm.Mode = None, - name: str = None, - ): - """Initializes the module. - - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: The shape of the kernel. Either an integer or a sequence of - length 3. - stride: Optional stride for the kernel. Either an integer or a sequence of - length 3. Defaults to 1. - padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. - Defaults to ``SAME``. See: - https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - w_initializer: Optional weight initialization. By default, truncated normal. - b_initializer: Optional bias initialization. By default, zeros. - mask: Optional mask of the weights. - name: The name of the module. 
- """ - super().__init__( - num_spatial_dims=3, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - w_initializer=w_initializer, - b_initializer=b_initializer, - precision=precision, - mode=mode, - mask=mask, - name=name - ) - - def _check_input_dim(self, x): - if x.ndim != 5 and x.ndim != 4: - raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)") - if self.in_channels != x.shape[-1]: - raise ValueError(f"input channels={x.shape[-1]} needs to have " - f"the same size as in_channels={self.in_channels}.") + """Three dimensional transposed convolution (aka. deconvolution).""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(in_axis=-1, out_axis=-2), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + mask: Optional[ArrayType] = None, + precision: Optional[lax.Precision] = None, + mode: bm.Mode = None, + name: str = None, + ): + """Initializes the module. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: The shape of the kernel. Either an integer or a sequence of + length 3. + stride: Optional stride for the kernel. Either an integer or a sequence of + length 3. Defaults to 1. + padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. + Defaults to ``SAME``. See: + https://www.tensorflow.org/xla/operation_semantics#conv_convolution. + w_initializer: Optional weight initialization. By default, truncated normal. + b_initializer: Optional bias initialization. By default, zeros. + mask: Optional mask of the weights. + name: The name of the module. 
+ """ + super().__init__( + num_spatial_dims=3, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + w_initializer=w_initializer, + b_initializer=b_initializer, + precision=precision, + mode=mode, + mask=mask, + name=name + ) + + def _check_input_dim(self, x): + if x.ndim != 5 and x.ndim != 4: + raise ValueError(f"expected 5D or 4D input (got {x.ndim}D input)") + if self.in_channels != x.shape[-1]: + raise ValueError(f"input channels={x.shape[-1]} needs to have " + f"the same size as in_channels={self.in_channels}.") diff --git a/brainpy/_src/dnn/dropout.py b/brainpy/_src/dnn/dropout.py index 6bd8bde7a..3c506c9c8 100644 --- a/brainpy/_src/dnn/dropout.py +++ b/brainpy/_src/dnn/dropout.py @@ -2,50 +2,49 @@ from typing import Optional -from brainpy._src.context import share from brainpy import math as bm, check +from brainpy._src.context import share from brainpy._src.dnn.base import Layer __all__ = [ - 'Dropout' + 'Dropout' ] class Dropout(Layer): - """A layer that stochastically ignores a subset of inputs each training step. - - In training, to compensate for the fraction of input values dropped (`rate`), - all surviving values are multiplied by `1 / (1 - rate)`. - - This layer is active only during training (``mode=brainpy.math.training_mode``). In other - circumstances it is a no-op. - - .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent - neural networks from overfitting." The journal of machine learning - research 15.1 (2014): 1929-1958. - - Args: - prob: Probability to keep element of the tensor. - mode: Mode. The computation mode of the object. - name: str. The name of the dynamic system. - - """ - - def __init__( - self, - prob: float, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None - ): - super(Dropout, self).__init__(mode=mode, name=name) - self.prob = check.is_float(prob, min_bound=0., max_bound=1.) 
- - def update(self, x, fit: Optional[bool] = None): - if fit is None: - fit = share['fit'] - if fit: - keep_mask = bm.random.bernoulli(self.prob, x.shape) - return bm.where(keep_mask, x / self.prob, 0.) - else: - return x - + """A layer that stochastically ignores a subset of inputs each training step. + + In training, to compensate for the fraction of input values dropped (`rate`), + all surviving values are multiplied by `1 / (1 - rate)`. + + This layer is active only during training (``mode=brainpy.math.training_mode``). In other + circumstances it is a no-op. + + .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent + neural networks from overfitting." The journal of machine learning + research 15.1 (2014): 1929-1958. + + Args: + prob: Probability to keep element of the tensor. + mode: Mode. The computation mode of the object. + name: str. The name of the dynamic system. + + """ + + def __init__( + self, + prob: float, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None + ): + super(Dropout, self).__init__(mode=mode, name=name) + self.prob = check.is_float(prob, min_bound=0., max_bound=1.) + + def update(self, x, fit: Optional[bool] = None): + if fit is None: + fit = share['fit'] + if fit: + keep_mask = bm.random.bernoulli(self.prob, x.shape) + return bm.where(keep_mask, x / self.prob, 0.) + else: + return x diff --git a/brainpy/_src/dnn/function.py b/brainpy/_src/dnn/function.py index 5f33552ed..8a1456bf8 100644 --- a/brainpy/_src/dnn/function.py +++ b/brainpy/_src/dnn/function.py @@ -6,171 +6,171 @@ from brainpy._src.dnn.base import Layer __all__ = [ - 'Activation', - 'Flatten', - 'Unflatten', - 'FunAsLayer', + 'Activation', + 'Flatten', + 'Unflatten', + 'FunAsLayer', ] class Activation(Layer): - r"""Applies an activation function to the inputs - - Parameters: - ---------- - activate_fun: Callable, function - The function of Activation - name: str, Optional - The name of the object - mode: Mode - Enable training this node or not. 
(default True). - """ - update_style = 'x' - - def __init__( - self, - activate_fun: Callable, - name: Optional[str] = None, - mode: bm.Mode = None, - **kwargs, - ): - super().__init__(name, mode) - self.activate_fun = activate_fun - self.kwargs = kwargs - - def update(self, *args, **kwargs): - return self.activate_fun(*args, **kwargs, **self.kwargs) + r"""Applies an activation function to the inputs + + Parameters: + ---------- + activate_fun: Callable, function + The function of Activation + name: str, Optional + The name of the object + mode: Mode + Enable training this node or not. (default True). + """ + update_style = 'x' + + def __init__( + self, + activate_fun: Callable, + name: Optional[str] = None, + mode: bm.Mode = None, + **kwargs, + ): + super().__init__(name, mode) + self.activate_fun = activate_fun + self.kwargs = kwargs + + def update(self, *args, **kwargs): + return self.activate_fun(*args, **kwargs, **self.kwargs) class Flatten(Layer): - r""" - Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`. - - Shape: - - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' - where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any - number of dimensions including none. - - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. - - Args: - start_dim: first dim to flatten (default = 1). - end_dim: last dim to flatten (default = -1). - name: str, Optional. The name of the object. - mode: Mode. Enable training this node or not. (default True). 
- - Examples:: - >>> import brainpy.math as bm - >>> inp = bm.random.randn(32, 1, 5, 5) - >>> # With default parameters - >>> m = Flatten() - >>> output = m(inp) - >>> output.shape - (32, 25) - >>> # With non-default parameters - >>> m = Flatten(0, 2) - >>> output = m(inp) - >>> output.shape - (160, 5) - """ - - def __init__( - self, - start_dim: int = 0, - end_dim: int = -1, - name: Optional[str] = None, - mode: bm.Mode = None, - ): - super().__init__(name, mode) - - self.start_dim = start_dim - self.end_dim = end_dim - - def update(self, x): - if self.mode.is_child_of(bm.BatchingMode): - start_dim = (self.start_dim + 1) if self.start_dim >= 0 else (x.ndim + self.start_dim + 1) - else: - start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim - return bm.flatten(x, start_dim, self.end_dim) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})' + r""" + Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`. + + Shape: + - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,' + where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any + number of dimensions including none. + - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. + + Args: + start_dim: first dim to flatten (default = 1). + end_dim: last dim to flatten (default = -1). + name: str, Optional. The name of the object. + mode: Mode. Enable training this node or not. (default True). 
+ + Examples:: + >>> import brainpy.math as bm + >>> inp = bm.random.randn(32, 1, 5, 5) + >>> # With default parameters + >>> m = Flatten() + >>> output = m(inp) + >>> output.shape + (32, 25) + >>> # With non-default parameters + >>> m = Flatten(0, 2) + >>> output = m(inp) + >>> output.shape + (160, 5) + """ + + def __init__( + self, + start_dim: int = 0, + end_dim: int = -1, + name: Optional[str] = None, + mode: bm.Mode = None, + ): + super().__init__(name, mode) + + self.start_dim = start_dim + self.end_dim = end_dim + + def update(self, x): + if self.mode.is_child_of(bm.BatchingMode): + start_dim = (self.start_dim + 1) if self.start_dim >= 0 else (x.ndim + self.start_dim + 1) + else: + start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim + return bm.flatten(x, start_dim, self.end_dim) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})' class Unflatten(Layer): - r""" - Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`. - - * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can - be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. - - * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be - a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` - (tuple of `(name, size)` tuples) for `NamedTensor` input. - - Shape: - - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at - dimension :attr:`dim` and :math:`*` means any number of dimensions including none. - - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and - :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`. - - Args: - dim: int, Dimension to be unflattened. - sizes: Sequence of int. New shape of the unflattened dimension. 
- - Examples: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> input = bm.random.randn(2, 50) - >>> # With tuple of ints - >>> m = bp.Sequential( - >>> bp.dnn.Linear(50, 50), - >>> Unflatten(1, (2, 5, 5)) - >>> ) - >>> output = m(input) - >>> output.shape - (2, 2, 5, 5) - >>> # With torch.Size - >>> m = bp.Sequential( - >>> bp.dnn.Linear(50, 50), - >>> Unflatten(1, [2, 5, 5]) - >>> ) - >>> output = m(input) - >>> output.shape - (2, 2, 5, 5) - """ - - def __init__(self, dim: int, sizes: Sequence[int], mode: bm.Mode = None, name: str = None) -> None: - super().__init__(mode=mode, name=name) - - self.dim = dim - self.sizes = sizes - if isinstance(sizes, (tuple, list)): - for idx, elem in enumerate(sizes): - if not isinstance(elem, int): - raise TypeError("unflattened_size must be tuple of ints, " + - "but found element of type {} at pos {}".format(type(elem).__name__, idx)) - else: - raise TypeError("unflattened_size must be tuple or list, but found type {}".format(type(sizes).__name__)) - - def update(self, x): - dim = self.dim + 1 if self.mode.is_batch_mode() else self.dim - return bm.unflatten(x, dim, self.sizes) - - def __repr__(self): - return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})' + r""" + Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`. + + * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can + be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively. + + * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be + a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape` + (tuple of `(name, size)` tuples) for `NamedTensor` input. + + Shape: + - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at + dimension :attr:`dim` and :math:`*` means any number of dimensions including none. 
+ - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and + :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`. + + Args: + dim: int, Dimension to be unflattened. + sizes: Sequence of int. New shape of the unflattened dimension. + + Examples: + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> input = bm.random.randn(2, 50) + >>> # With tuple of ints + >>> m = bp.Sequential( + >>> bp.dnn.Linear(50, 50), + >>> Unflatten(1, (2, 5, 5)) + >>> ) + >>> output = m(input) + >>> output.shape + (2, 2, 5, 5) + >>> # With torch.Size + >>> m = bp.Sequential( + >>> bp.dnn.Linear(50, 50), + >>> Unflatten(1, [2, 5, 5]) + >>> ) + >>> output = m(input) + >>> output.shape + (2, 2, 5, 5) + """ + + def __init__(self, dim: int, sizes: Sequence[int], mode: bm.Mode = None, name: str = None) -> None: + super().__init__(mode=mode, name=name) + + self.dim = dim + self.sizes = sizes + if isinstance(sizes, (tuple, list)): + for idx, elem in enumerate(sizes): + if not isinstance(elem, int): + raise TypeError("unflattened_size must be tuple of ints, " + + "but found element of type {} at pos {}".format(type(elem).__name__, idx)) + else: + raise TypeError("unflattened_size must be tuple or list, but found type {}".format(type(sizes).__name__)) + + def update(self, x): + dim = self.dim + 1 if self.mode.is_batch_mode() else self.dim + return bm.unflatten(x, dim, self.sizes) + + def __repr__(self): + return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})' class FunAsLayer(Layer): - def __init__( - self, - fun: Callable, - name: Optional[str] = None, - mode: bm.Mode = None, - **kwargs, - ): - super().__init__(name, mode) - self._fun = fun - self.kwargs = kwargs - - def update(self, *args, **kwargs): - return self._fun(*args, **kwargs, **self.kwargs) + def __init__( + self, + fun: Callable, + name: Optional[str] = None, + mode: bm.Mode = None, + **kwargs, + ): + super().__init__(name, mode) + self._fun = fun + self.kwargs = kwargs + + def 
update(self, *args, **kwargs): + return self._fun(*args, **kwargs, **self.kwargs) diff --git a/brainpy/_src/dnn/interoperation_flax.py b/brainpy/_src/dnn/interoperation_flax.py index d1bcec068..5b00333bb 100644 --- a/brainpy/_src/dnn/interoperation_flax.py +++ b/brainpy/_src/dnn/interoperation_flax.py @@ -1,148 +1,144 @@ - -import jax import dataclasses from typing import Dict, Tuple + +import jax from jax.tree_util import tree_flatten, tree_map, tree_unflatten from brainpy import math as bm -from brainpy._src.dynsys import DynamicalSystem from brainpy._src.context import share from brainpy._src.dnn.base import Layer +from brainpy._src.dynsys import DynamicalSystem try: - import flax # noqa - from flax.linen.recurrent import RNNCellBase + import flax # noqa + from flax.linen.recurrent import RNNCellBase except: - flax = None - RNNCellBase = object - + flax = None + RNNCellBase = object __all__ = [ - 'FromFlax', - 'ToFlaxRNNCell', - 'ToFlax', + 'FromFlax', + 'ToFlaxRNNCell', + 'ToFlax', ] def _as_jax(a): - if isinstance(a, bm.BaseArray): - return a.value - else: - return a + if isinstance(a, bm.BaseArray): + return a.value + else: + return a def _is_bp(a): - return isinstance(a, bm.BaseArray) + return isinstance(a, bm.BaseArray) class FromFlax(Layer): - """ - Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`. - - Parameters:: - - flax_module: Any - The flax Module. - module_args: Any - The module arguments, used to initialize model parameters. - module_kwargs: Any - The module arguments, used to initialize model parameters. 
- """ - def __init__(self, flax_module, *module_args, **module_kwargs): - super().__init__() - self.flax_module = flax_module - params = self.flax_module.init(bm.random.split_key(), - *tree_map(_as_jax, module_args, is_leaf=_is_bp), - **tree_map(_as_jax, module_kwargs, is_leaf=_is_bp)) - leaves, self._tree = tree_flatten(params) - self.variables = bm.VarList(tree_map(bm.TrainVar, leaves)) - - def update(self, *args, **kwargs): - params = tree_unflatten(self._tree, [v.value for v in self.variables]) - return self.flax_module.apply(params, - *tree_map(_as_jax, args, is_leaf=_is_bp), - **tree_map(_as_jax, kwargs, is_leaf=_is_bp)) - - def reset_state(self, *args, **kwargs): - pass + """ + Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`. + + Parameters:: + + flax_module: Any + The flax Module. + module_args: Any + The module arguments, used to initialize model parameters. + module_kwargs: Any + The module arguments, used to initialize model parameters. + """ + + def __init__(self, flax_module, *module_args, **module_kwargs): + super().__init__() + self.flax_module = flax_module + params = self.flax_module.init(bm.random.split_key(), + *tree_map(_as_jax, module_args, is_leaf=_is_bp), + **tree_map(_as_jax, module_kwargs, is_leaf=_is_bp)) + leaves, self._tree = tree_flatten(params) + self.variables = bm.VarList(tree_map(bm.TrainVar, leaves)) + + def update(self, *args, **kwargs): + params = tree_unflatten(self._tree, [v.value for v in self.variables]) + return self.flax_module.apply(params, + *tree_map(_as_jax, args, is_leaf=_is_bp), + **tree_map(_as_jax, kwargs, is_leaf=_is_bp)) + + def reset_state(self, *args, **kwargs): + pass to_flax_doc = """Transform a BrainPy :py:class:`~.DynamicalSystem` into a Flax recurrent module.""" - if flax is not None: - class ToFlaxRNNCell(RNNCellBase): - __doc__ = to_flax_doc - - model: DynamicalSystem - train_params: Dict[str, jax.Array] = dataclasses.field(init=False) - - def initialize_carry(self, rng, input_shape: 
Tuple[int, ...]): - batch_dims = input_shape[:-1] - if len(batch_dims) == 1: - batch_dims = 1 - elif len(batch_dims) == 0: - batch_dims = None - else: - raise ValueError(f'Invalid input shape: {input_shape}') - _state_vars = self.model.vars().unique().not_subset(bm.TrainVar) - self.model.reset(batch_dims) - return [_state_vars.dict(), 0, 0.] - - def setup(self): - _vars = self.model.vars().unique() - _train_vars = _vars.subset(bm.TrainVar) - self.train_params = self.param(self.model.name, lambda rng, a: a.dict(), _train_vars) - - def __call__(self, carry, *inputs): - """A recurrent cell that transformed from a BrainPy :py:class:`~.DynamicalSystem`. - - Args: - carry: the hidden state of the transformed recurrent cell, initialized using - `.initialize_carry()` function in which the original `.reset_state()` is called. - inputs: an ndarray with the input for the current time step. All - dimensions except the final are considered batch dimensions. - - Returns: - A tuple with the new carry and the output. 
- """ - # shared arguments - i, t = carry[1], carry[2] - old_i = share.load('i', i) - old_t = share.load('t', t) - share.save(i=i, t=t) - - # carry - _vars = self.model.vars().unique() - _state_vars = _vars.not_subset(bm.TrainVar) - for k, v in carry[0].items(): - _state_vars[k].value = v - - # train parameters - _train_vars = _vars.subset(bm.TrainVar) - for k, v in self.train_params.items(): - _train_vars[k].value = v - - # recurrent cell - out = self.model(*inputs) - - # shared arguments - share.save(i=old_i, t=old_t) - # carray and output - return [_state_vars.dict(), i + 1, t + share.dt], out - - @property - def num_feature_axes(self) -> int: - return 1 + class ToFlaxRNNCell(RNNCellBase): + __doc__ = to_flax_doc + + model: DynamicalSystem + train_params: Dict[str, jax.Array] = dataclasses.field(init=False) + + def initialize_carry(self, rng, input_shape: Tuple[int, ...]): + batch_dims = input_shape[:-1] + if len(batch_dims) == 1: + batch_dims = 1 + elif len(batch_dims) == 0: + batch_dims = None + else: + raise ValueError(f'Invalid input shape: {input_shape}') + _state_vars = self.model.vars().unique().not_subset(bm.TrainVar) + self.model.reset(batch_dims) + return [_state_vars.dict(), 0, 0.] + + def setup(self): + _vars = self.model.vars().unique() + _train_vars = _vars.subset(bm.TrainVar) + self.train_params = self.param(self.model.name, lambda rng, a: a.dict(), _train_vars) + + def __call__(self, carry, *inputs): + """A recurrent cell that transformed from a BrainPy :py:class:`~.DynamicalSystem`. + + Args: + carry: the hidden state of the transformed recurrent cell, initialized using + `.initialize_carry()` function in which the original `.reset_state()` is called. + inputs: an ndarray with the input for the current time step. All + dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. 
+ """ + # shared arguments + i, t = carry[1], carry[2] + old_i = share.load('i', i) + old_t = share.load('t', t) + share.save(i=i, t=t) + + # carry + _vars = self.model.vars().unique() + _state_vars = _vars.not_subset(bm.TrainVar) + for k, v in carry[0].items(): + _state_vars[k].value = v + + # train parameters + _train_vars = _vars.subset(bm.TrainVar) + for k, v in self.train_params.items(): + _train_vars[k].value = v + + # recurrent cell + out = self.model(*inputs) + + # shared arguments + share.save(i=old_i, t=old_t) + # carray and output + return [_state_vars.dict(), i + 1, t + share.dt], out + + @property + def num_feature_axes(self) -> int: + return 1 else: - class ToFlaxRNNCell(object): - __doc__ = to_flax_doc - - def __init__(self, *args, **kwargs): - raise ModuleNotFoundError('"flax" is not installed, or importing "flax" has errors. Please check.') + class ToFlaxRNNCell(object): + __doc__ = to_flax_doc + def __init__(self, *args, **kwargs): + raise ModuleNotFoundError('"flax" is not installed, or importing "flax" has errors. 
Please check.') ToFlax = ToFlaxRNNCell - - diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 6fa75f9fa..9f9627504 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- -import importlib.util import numbers from typing import Dict, Optional, Union, Callable @@ -16,11 +15,10 @@ from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP from brainpy.check import is_initializer from brainpy.connect import csr2csc -from brainpy.errors import MathError, PackageMissingError +from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding - __all__ = [ 'Dense', 'Linear', 'Identity', @@ -42,8 +40,8 @@ class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline): y = x \cdot weight + b - Parameters:: - + Parameters:: + num_in: int The number of the input feature. A positive integer. num_out: int diff --git a/brainpy/_src/dnn/normalization.py b/brainpy/_src/dnn/normalization.py index b56a6b699..5fa13901a 100644 --- a/brainpy/_src/dnn/normalization.py +++ b/brainpy/_src/dnn/normalization.py @@ -4,634 +4,634 @@ from jax import lax, numpy as jnp +from brainpy import math as bm from brainpy._src.context import share -from brainpy import math as bm, check +from brainpy._src.dnn.base import Layer from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter from brainpy.types import ArrayType -from brainpy._src.dnn.base import Layer __all__ = [ - 'BatchNorm1d', - 'BatchNorm2d', - 'BatchNorm3d', - 'BatchNorm1D', - 'BatchNorm2D', - 'BatchNorm3D', - - 'LayerNorm', - 'GroupNorm', - 'InstanceNorm', + 'BatchNorm1d', + 'BatchNorm2d', + 'BatchNorm3d', + 'BatchNorm1D', + 'BatchNorm2D', + 'BatchNorm3D', + + 'LayerNorm', + 'GroupNorm', + 'InstanceNorm', ] def _square(x): - """Computes the elementwise square of the absolute value |x|^2.""" - if jnp.iscomplexobj(x): - return 
lax.square(lax.real(x)) + lax.square(lax.imag(x)) - else: - return lax.square(x) + """Computes the elementwise square of the absolute value |x|^2.""" + if jnp.iscomplexobj(x): + return lax.square(lax.real(x)) + lax.square(lax.imag(x)) + else: + return lax.square(x) class BatchNorm(Layer): - r"""Batch Normalization layer [1]_. - - This layer aims to reduce the internal covariant shift of data. It - normalizes a batch of data by fixing the mean and variance of inputs - on each feature (channel). Most commonly, the first axis of the data - is the batch, and the last is the channel. However, users can specify - the axes to be normalized. - - .. math:: - y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta - - .. note:: - This :attr:`momentum` argument is different from one used in optimizer - classes and the conventional notion of momentum. Mathematically, the - update rule for running statistics here is - :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, - where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the - new observed value. - - Parameters:: - - num_features: int - ``C`` from an expected input of size ``(..., C)``. - axis: int, tuple, list - Axes where the data will be normalized. The feature (channel) axis should be excluded. - momentum: float - The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - epsilon: float - A value added to the denominator for numerical stability. Default: 1e-5 - affine: bool - A boolean value that when set to ``True``, this module has - learnable affine parameters. 
Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable - An initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable - An initializer generating the original scaling matrix - axis_name: optional, str, sequence of str - If not ``None``, it should be a string (or sequence of - strings) representing the axis name(s) over which this module is being - run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this - argument means that batch statistics are calculated across all replicas - on the named axes. - axis_index_groups: optional, sequence - Specifies how devices are grouped. Valid - only within ``jax.pmap`` collectives. - - References:: - - .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. - - """ - supported_modes = (bm.BatchingMode, bm.TrainingMode) - - def __init__( - self, - num_features: int, - axis: Union[int, Sequence[int]], - epsilon: float = 1e-5, - momentum: float = 0.99, - affine: bool = True, - bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), - scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), - axis_name: Optional[Union[str, Sequence[str]]] = None, - axis_index_groups: Optional[Sequence[Sequence[int]]] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(BatchNorm, self).__init__(name=name, mode=mode) - # check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name) - - # parameters - self.num_features = num_features - self.epsilon = epsilon - self.momentum = momentum - self.affine = affine - self.bias_initializer = bias_initializer - self.scale_initializer = scale_initializer - self.axis = (axis,) if jnp.isscalar(axis) else axis - self.axis_name = axis_name - self.axis_index_groups = axis_index_groups - - # variables - self.running_mean = 
bm.Variable(jnp.zeros(self.num_features)) - self.running_var = bm.Variable(jnp.ones(self.num_features)) - if self.affine: - assert isinstance(self.mode, bm.TrainingMode) - self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_features)) - self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_features)) - - def _check_input_dim(self, x): - raise NotImplementedError - - def update(self, x): - self._check_input_dim(x) - - x = bm.as_jax(x) - - if share.load('fit'): - mean = jnp.mean(x, self.axis) - mean_of_square = jnp.mean(_square(x), self.axis) - if self.axis_name is not None: - mean, mean_of_square = jnp.split( - lax.pmean(jnp.concatenate([mean, mean_of_square]), - axis_name=self.axis_name, - axis_index_groups=self.axis_index_groups), - 2 - ) - var = jnp.maximum(0., mean_of_square - _square(mean)) - self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean) - self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var) - else: - mean = self.running_mean.value - var = self.running_var.value - stats_shape = [(1 if i in self.axis else x.shape[i]) for i in range(x.ndim)] - mean = mean.reshape(stats_shape) - var = var.reshape(stats_shape) - - y = x - mean - mul = lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype)) - if self.affine: - mul *= self.scale - y *= mul - if self.affine: - y += self.bias - return y + r"""Batch Normalization layer [1]_. + + This layer aims to reduce the internal covariant shift of data. It + normalizes a batch of data by fixing the mean and variance of inputs + on each feature (channel). Most commonly, the first axis of the data + is the batch, and the last is the channel. However, users can specify + the axes to be normalized. + + .. math:: + y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta + + .. 
note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Parameters:: + + num_features: int + ``C`` from an expected input of size ``(..., C)``. + axis: int, tuple, list + Axes where the data will be normalized. The feature (channel) axis should be excluded. + momentum: float + The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 + epsilon: float + A value added to the denominator for numerical stability. Default: 1e-5 + affine: bool + A boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + bias_initializer: Initializer, ArrayType, Callable + An initializer generating the original translation matrix + scale_initializer: Initializer, ArrayType, Callable + An initializer generating the original scaling matrix + axis_name: optional, str, sequence of str + If not ``None``, it should be a string (or sequence of + strings) representing the axis name(s) over which this module is being + run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this + argument means that batch statistics are calculated across all replicas + on the named axes. + axis_index_groups: optional, sequence + Specifies how devices are grouped. Valid + only within ``jax.pmap`` collectives. + + References:: + + .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. 
+ + """ + supported_modes = (bm.BatchingMode, bm.TrainingMode) + + def __init__( + self, + num_features: int, + axis: Union[int, Sequence[int]], + epsilon: float = 1e-5, + momentum: float = 0.99, + affine: bool = True, + bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), + scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), + axis_name: Optional[Union[str, Sequence[str]]] = None, + axis_index_groups: Optional[Sequence[Sequence[int]]] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(BatchNorm, self).__init__(name=name, mode=mode) + # check.is_subclass(self.mode, (bm.BatchingMode, bm.TrainingMode), self.name) + + # parameters + self.num_features = num_features + self.epsilon = epsilon + self.momentum = momentum + self.affine = affine + self.bias_initializer = bias_initializer + self.scale_initializer = scale_initializer + self.axis = (axis,) if jnp.isscalar(axis) else axis + self.axis_name = axis_name + self.axis_index_groups = axis_index_groups + + # variables + self.running_mean = bm.Variable(jnp.zeros(self.num_features)) + self.running_var = bm.Variable(jnp.ones(self.num_features)) + if self.affine: + assert isinstance(self.mode, bm.TrainingMode) + self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_features)) + self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_features)) + + def _check_input_dim(self, x): + raise NotImplementedError + + def update(self, x): + self._check_input_dim(x) + + x = bm.as_jax(x) + + if share.load('fit'): + mean = jnp.mean(x, self.axis) + mean_of_square = jnp.mean(_square(x), self.axis) + if self.axis_name is not None: + mean, mean_of_square = jnp.split( + lax.pmean(jnp.concatenate([mean, mean_of_square]), + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups), + 2 + ) + var = jnp.maximum(0., mean_of_square - _square(mean)) + self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean) 
+ self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var) + else: + mean = self.running_mean.value + var = self.running_var.value + stats_shape = [(1 if i in self.axis else x.shape[i]) for i in range(x.ndim)] + mean = mean.reshape(stats_shape) + var = var.reshape(stats_shape) + + y = x - mean + mul = lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype)) + if self.affine: + mul *= self.scale + y *= mul + if self.affine: + y += self.bias + return y class BatchNorm1d(BatchNorm): - r"""1-D batch normalization [1]_. - - The data should be of `(b, l, c)`, where `b` is the batch dimension, - `l` is the layer dimension, and `c` is the channel dimension. - - .. math:: - y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta - - .. note:: - This :attr:`momentum` argument is different from one used in optimizer - classes and the conventional notion of momentum. Mathematically, the - update rule for running statistics here is - :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, - where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the - new observed value. - - Parameters:: - - num_features: int - ``C`` from an expected input of size ``(B, L, C)``. - axis: int, tuple, list - axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float - A value added to the denominator for numerical stability. Default: 1e-5 - momentum: float - The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - affine: bool - A boolean value that when set to ``True``, this module has - learnable affine parameters. 
Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable - an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable - an initializer generating the original scaling matrix - axis_name: optional, str, sequence of str - If not ``None``, it should be a string (or sequence of - strings) representing the axis name(s) over which this module is being - run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this - argument means that batch statistics are calculated across all replicas - on the named axes. - axis_index_groups: optional, sequence - Specifies how devices are grouped. Valid - only within ``jax.pmap`` collectives. - - References:: - - .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. - - """ - - def __init__( - self, - num_features: int, - axis: Union[int, Sequence[int]] = (0, 1), - epsilon: float = 1e-5, - momentum: float = 0.99, - affine: bool = True, - bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), - scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), - axis_name: Optional[Union[str, Sequence[str]]] = None, - axis_index_groups: Optional[Sequence[Sequence[int]]] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(BatchNorm1d, self).__init__(num_features=num_features, - axis=axis, - epsilon=epsilon, - momentum=momentum, - affine=affine, - bias_initializer=bias_initializer, - scale_initializer=scale_initializer, - axis_name=axis_name, - axis_index_groups=axis_index_groups, - mode=mode, - name=name) - - def _check_input_dim(self, x): - if x.ndim != 3: - raise ValueError(f"expected 3D input (got {x.ndim}D input)") - assert x.shape[-1] == self.num_features + r"""1-D batch normalization [1]_. 
+ + The data should be of `(b, l, c)`, where `b` is the batch dimension, + `l` is the layer dimension, and `c` is the channel dimension. + + .. math:: + y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Parameters:: + + num_features: int + ``C`` from an expected input of size ``(B, L, C)``. + axis: int, tuple, list + axes where the data will be normalized. The feature (channel) axis should be excluded. + epsilon: float + A value added to the denominator for numerical stability. Default: 1e-5 + momentum: float + The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 + affine: bool + A boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + bias_initializer: Initializer, ArrayType, Callable + an initializer generating the original translation matrix + scale_initializer: Initializer, ArrayType, Callable + an initializer generating the original scaling matrix + axis_name: optional, str, sequence of str + If not ``None``, it should be a string (or sequence of + strings) representing the axis name(s) over which this module is being + run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this + argument means that batch statistics are calculated across all replicas + on the named axes. + axis_index_groups: optional, sequence + Specifies how devices are grouped. Valid + only within ``jax.pmap`` collectives. + + References:: + + .. [1] Ioffe, Sergey and Christian Szegedy. 
“Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. + + """ + + def __init__( + self, + num_features: int, + axis: Union[int, Sequence[int]] = (0, 1), + epsilon: float = 1e-5, + momentum: float = 0.99, + affine: bool = True, + bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), + scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), + axis_name: Optional[Union[str, Sequence[str]]] = None, + axis_index_groups: Optional[Sequence[Sequence[int]]] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(BatchNorm1d, self).__init__(num_features=num_features, + axis=axis, + epsilon=epsilon, + momentum=momentum, + affine=affine, + bias_initializer=bias_initializer, + scale_initializer=scale_initializer, + axis_name=axis_name, + axis_index_groups=axis_index_groups, + mode=mode, + name=name) + + def _check_input_dim(self, x): + if x.ndim != 3: + raise ValueError(f"expected 3D input (got {x.ndim}D input)") + assert x.shape[-1] == self.num_features class BatchNorm2d(BatchNorm): - r"""2-D batch normalization [1]_. - - The data should be of `(b, h, w, c)`, where `b` is the batch dimension, - `h` is the height dimension, `w` is the width dimension, and `c` is the - channel dimension. - - .. math:: - y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta - - .. note:: - This :attr:`momentum` argument is different from one used in optimizer - classes and the conventional notion of momentum. Mathematically, the - update rule for running statistics here is - :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, - where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the - new observed value. - - Parameters:: - - num_features: int - ``C`` from an expected input of size ``(B, H, W, C)``. - axis: int, tuple, list - axes where the data will be normalized. 
The feature (channel) axis should be excluded. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - momentum: float - The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - affine: bool - A boolean value that when set to ``True``, this module has - learnable affine parameters. Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable - an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable - an initializer generating the original scaling matrix - axis_name: optional, str, sequence of str - If not ``None``, it should be a string (or sequence of - strings) representing the axis name(s) over which this module is being - run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this - argument means that batch statistics are calculated across all replicas - on the named axes. - axis_index_groups: optional, sequence - Specifies how devices are grouped. Valid - only within ``jax.pmap`` collectives. - - References:: - - .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. 
- - """ - - def __init__( - self, - num_features: int, - axis: Union[int, Sequence[int]] = (0, 1, 2), - epsilon: float = 1e-5, - momentum: float = 0.99, - affine: bool = True, - bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), - scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), - axis_name: Optional[Union[str, Sequence[str]]] = None, - axis_index_groups: Optional[Sequence[Sequence[int]]] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(BatchNorm2d, self).__init__(num_features=num_features, - axis=axis, - epsilon=epsilon, - momentum=momentum, - affine=affine, - bias_initializer=bias_initializer, - scale_initializer=scale_initializer, - axis_name=axis_name, - axis_index_groups=axis_index_groups, - mode=mode, - name=name) - - def _check_input_dim(self, x): - if x.ndim != 4: - raise ValueError(f"expected 4D input (got {x.ndim}D input)") - assert x.shape[-1] == self.num_features + r"""2-D batch normalization [1]_. + + The data should be of `(b, h, w, c)`, where `b` is the batch dimension, + `h` is the height dimension, `w` is the width dimension, and `c` is the + channel dimension. + + .. math:: + y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Parameters:: + + num_features: int + ``C`` from an expected input of size ``(B, H, W, C)``. + axis: int, tuple, list + axes where the data will be normalized. The feature (channel) axis should be excluded. + epsilon: float + a value added to the denominator for numerical stability. 
Default: 1e-5 + momentum: float + The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 + affine: bool + A boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + bias_initializer: Initializer, ArrayType, Callable + an initializer generating the original translation matrix + scale_initializer: Initializer, ArrayType, Callable + an initializer generating the original scaling matrix + axis_name: optional, str, sequence of str + If not ``None``, it should be a string (or sequence of + strings) representing the axis name(s) over which this module is being + run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this + argument means that batch statistics are calculated across all replicas + on the named axes. + axis_index_groups: optional, sequence + Specifies how devices are grouped. Valid + only within ``jax.pmap`` collectives. + + References:: + + .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. 
+ + """ + + def __init__( + self, + num_features: int, + axis: Union[int, Sequence[int]] = (0, 1, 2), + epsilon: float = 1e-5, + momentum: float = 0.99, + affine: bool = True, + bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), + scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), + axis_name: Optional[Union[str, Sequence[str]]] = None, + axis_index_groups: Optional[Sequence[Sequence[int]]] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(BatchNorm2d, self).__init__(num_features=num_features, + axis=axis, + epsilon=epsilon, + momentum=momentum, + affine=affine, + bias_initializer=bias_initializer, + scale_initializer=scale_initializer, + axis_name=axis_name, + axis_index_groups=axis_index_groups, + mode=mode, + name=name) + + def _check_input_dim(self, x): + if x.ndim != 4: + raise ValueError(f"expected 4D input (got {x.ndim}D input)") + assert x.shape[-1] == self.num_features class BatchNorm3d(BatchNorm): - r"""3-D batch normalization [1]_. - - The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension, - `h` is the height dimension, `w` is the width dimension, `d` is the depth - dimension, and `c` is the channel dimension. - - .. math:: - y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta - - .. note:: - This :attr:`momentum` argument is different from one used in optimizer - classes and the conventional notion of momentum. Mathematically, the - update rule for running statistics here is - :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, - where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the - new observed value. - - Parameters:: - - num_features: int - ``C`` from an expected input of size ``(B, H, W, D, C)``. - axis: int, tuple, list - axes where the data will be normalized. The feature (channel) axis should be excluded. 
- epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - momentum: float - The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - affine: bool - A boolean value that when set to ``True``, this module has - learnable affine parameters. Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable - an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable - an initializer generating the original scaling matrix - axis_name: optional, str, sequence of str - If not ``None``, it should be a string (or sequence of - strings) representing the axis name(s) over which this module is being - run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this - argument means that batch statistics are calculated across all replicas - on the named axes. - axis_index_groups: optional, sequence - Specifies how devices are grouped. Valid - only within ``jax.pmap`` collectives. - - References:: - - .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. 
- - """ - - def __init__( - self, - num_features: int, - axis: Union[int, Sequence[int]] = (0, 1, 2, 3), - epsilon: float = 1e-5, - momentum: float = 0.99, - affine: bool = True, - bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), - scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), - axis_name: Optional[Union[str, Sequence[str]]] = None, - axis_index_groups: Optional[Sequence[Sequence[int]]] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(BatchNorm3d, self).__init__(num_features=num_features, - axis=axis, - epsilon=epsilon, - momentum=momentum, - affine=affine, - bias_initializer=bias_initializer, - scale_initializer=scale_initializer, - axis_name=axis_name, - axis_index_groups=axis_index_groups, - mode=mode, - name=name) - - def _check_input_dim(self, x): - if x.ndim != 5: - raise ValueError(f"expected 5D input (got {x.ndim}D input)") - assert x.shape[-1] == self.num_features + r"""3-D batch normalization [1]_. + + The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension, + `h` is the height dimension, `w` is the width dimension, `d` is the depth + dimension, and `c` is the channel dimension. + + .. math:: + y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta + + .. note:: + This :attr:`momentum` argument is different from one used in optimizer + classes and the conventional notion of momentum. Mathematically, the + update rule for running statistics here is + :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`, + where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the + new observed value. + + Parameters:: + + num_features: int + ``C`` from an expected input of size ``(B, H, W, D, C)``. + axis: int, tuple, list + axes where the data will be normalized. The feature (channel) axis should be excluded. + epsilon: float + a value added to the denominator for numerical stability. 
Default: 1e-5 + momentum: float + The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 + affine: bool + A boolean value that when set to ``True``, this module has + learnable affine parameters. Default: ``True`` + bias_initializer: Initializer, ArrayType, Callable + an initializer generating the original translation matrix + scale_initializer: Initializer, ArrayType, Callable + an initializer generating the original scaling matrix + axis_name: optional, str, sequence of str + If not ``None``, it should be a string (or sequence of + strings) representing the axis name(s) over which this module is being + run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this + argument means that batch statistics are calculated across all replicas + on the named axes. + axis_index_groups: optional, sequence + Specifies how devices are grouped. Valid + only within ``jax.pmap`` collectives. + + References:: + + .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. 
+ + """ + + def __init__( + self, + num_features: int, + axis: Union[int, Sequence[int]] = (0, 1, 2, 3), + epsilon: float = 1e-5, + momentum: float = 0.99, + affine: bool = True, + bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), + scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), + axis_name: Optional[Union[str, Sequence[str]]] = None, + axis_index_groups: Optional[Sequence[Sequence[int]]] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(BatchNorm3d, self).__init__(num_features=num_features, + axis=axis, + epsilon=epsilon, + momentum=momentum, + affine=affine, + bias_initializer=bias_initializer, + scale_initializer=scale_initializer, + axis_name=axis_name, + axis_index_groups=axis_index_groups, + mode=mode, + name=name) + + def _check_input_dim(self, x): + if x.ndim != 5: + raise ValueError(f"expected 5D input (got {x.ndim}D input)") + assert x.shape[-1] == self.num_features class LayerNorm(Layer): - r"""Layer normalization (https://arxiv.org/abs/1607.06450). + r"""Layer normalization (https://arxiv.org/abs/1607.06450). - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + This layer normalizes data on each example, independently of the batch. More + specifically, it normalizes data of shape (b, d1, d2, ..., c) on the axes of + the data dimensions and the channel (d1, d2, ..., c). Different from batch + normalization, scale and bias are assigned to each position (elementwise + operation) instead of the whole channel. If users want to assign a single + scale and bias to a whole example/whole channel, please use GroupNorm/ + InstanceNorm. + + Parameters:: + + normalized_shape: int, sequence of int + The input shape from an expected input of size + + .. 
math:: + [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] + \times \ldots \times \text{normalized\_shape}[-1]] + + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + epsilon: float + a value added to the denominator for numerical stability. Default: 1e-5 + bias_initializer: Initializer, ArrayType, Callable + an initializer generating the original translation matrix + scale_initializer: Initializer, ArrayType, Callable + an initializer generating the original scaling matrix + elementwise_affine: bool + A boolean value that when set to ``True``, this module + has learnable per-element affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Examples:: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> + >>> # NLP Example + >>> batch, sentence_length, embedding_dim = 20, 5, 10 + >>> embedding = bm.random.randn(batch, sentence_length, embedding_dim) + >>> layer_norm = bp.layers.LayerNorm(embedding_dim) + >>> # Activate module + >>> layer_norm(embedding) + >>> + >>> # Image Example + >>> N, C, H, W = 20, 5, 10, 10 + >>> input = bm.random.randn(N, H, W, C) + >>> # Normalize over the last three dimensions (i.e. 
the channel and spatial dimensions) + >>> # as shown in the image below + >>> layer_norm = bp.layers.LayerNorm([H, W, C]) + >>> output = layer_norm(input) + + """ + + def __init__( + self, + normalized_shape: Union[int, Sequence[int]], + epsilon: float = 1e-5, + bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), + scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), + elementwise_affine: bool = True, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None + ): + super(LayerNorm, self).__init__(name=name, mode=mode) + + self.epsilon = epsilon + self.bias_initializer = bias_initializer + self.scale_initializer = scale_initializer + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + assert all([isinstance(s, int) for s in normalized_shape]), 'Must be a sequence of integer.' + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + assert isinstance(self.mode, bm.TrainingMode) + self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape)) + self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape)) + + def update(self, x): + if x.shape[-len(self.normalized_shape):] != self.normalized_shape: + raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), ' + f'but we got {x.shape}') + axis = tuple(range(0, x.ndim - len(self.normalized_shape))) + mean = jnp.mean(bm.as_jax(x), axis=axis, keepdims=True) + variance = jnp.var(bm.as_jax(x), axis=axis, keepdims=True) + inv = lax.rsqrt(variance + lax.convert_element_type(self.epsilon, x.dtype)) + out = (x - mean) * inv + if self.elementwise_affine: + out = self.scale * out + self.bias + return out - This layer normalizes data on each example, independently of the batch. More - specifically, it normalizes data of shape (b, d1, d2, ..., c) on the axes of - the data dimensions and the channel (d1, d2, ..., c). 
Different from batch - normalization, scale and bias are assigned to each position (elementwise - operation) instead of the whole channel. If users want to assign a single - scale and bias to a whole example/whole channel, please use GroupNorm/ - InstanceNorm. - Parameters:: - - normalized_shape: int, sequence of int - The input shape from an expected input of size +class GroupNorm(Layer): + r"""Group normalization layer. .. math:: - [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] - \times \ldots \times \text{normalized\_shape}[-1]] - - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - bias_initializer: Initializer, ArrayType, Callable - an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable - an initializer generating the original scaling matrix - elementwise_affine: bool - A boolean value that when set to ``True``, this module - has learnable per-element affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. - - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> - >>> # NLP Example - >>> batch, sentence_length, embedding_dim = 20, 5, 10 - >>> embedding = bm.random.randn(batch, sentence_length, embedding_dim) - >>> layer_norm = bp.layers.LayerNorm(embedding_dim) - >>> # Activate module - >>> layer_norm(embedding) - >>> - >>> # Image Example - >>> N, C, H, W = 20, 5, 10, 10 - >>> input = bm.random.randn(N, H, W, C) - >>> # Normalize over the last three dimensions (i.e. 
the channel and spatial dimensions) - >>> # as shown in the image below - >>> layer_norm = bp.layers.LayerNorm([H, W, C]) - >>> output = layer_norm(input) - - """ - - def __init__( - self, - normalized_shape: Union[int, Sequence[int]], - epsilon: float = 1e-5, - bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), - scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), - elementwise_affine: bool = True, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None - ): - super(LayerNorm, self).__init__(name=name, mode=mode) - - self.epsilon = epsilon - self.bias_initializer = bias_initializer - self.scale_initializer = scale_initializer - if isinstance(normalized_shape, int): - normalized_shape = (normalized_shape,) - self.normalized_shape = tuple(normalized_shape) - assert all([isinstance(s, int) for s in normalized_shape]), 'Must be a sequence of integer.' - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - assert isinstance(self.mode, bm.TrainingMode) - self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape)) - self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape)) - - def update(self, x): - if x.shape[-len(self.normalized_shape):] != self.normalized_shape: - raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), ' - f'but we got {x.shape}') - axis = tuple(range(0, x.ndim - len(self.normalized_shape))) - mean = jnp.mean(bm.as_jax(x), axis=axis, keepdims=True) - variance = jnp.var(bm.as_jax(x), axis=axis, keepdims=True) - inv = lax.rsqrt(variance + lax.convert_element_type(self.epsilon, x.dtype)) - out = (x - mean) * inv - if self.elementwise_affine: - out = self.scale * out + self.bias - return out + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta -class GroupNorm(Layer): - r"""Group normalization layer. - - .. 
math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta - - - This layer divides channels into groups and normalizes the features within each - group. Its computation is also independent of the batch size. The feature size - must be multiple of the group size. - - The shape of the data should be (b, d1, d2, ..., c), where `d` denotes the batch - size and `c` denotes the feature (channel) size. - - Parameters:: - - num_groups: int - The number of groups. It should be a factor of the number of channels. - num_channels: int - The number of channels expected in input. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - affine: bool - A boolean value that when set to ``True``, this module - has learnable per-channel affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. - bias_initializer: Initializer, ArrayType, Callable - An initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable - An initializer generating the original scaling matrix - - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> input = bm.random.randn(20, 10, 10, 6) - >>> # Separate 6 channels into 3 groups - >>> m = bp.layers.GroupNorm(3, 6) - >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) - >>> m = bp.layers.GroupNorm(6, 6) - >>> # Put all 6 channels into a single group (equivalent with LayerNorm) - >>> m = bp.layers.GroupNorm(1, 6) - >>> # Activating the module - >>> output = m(input) - """ - - def __init__( - self, - num_groups: int, - num_channels: int, - epsilon: float = 1e-5, - affine: bool = True, - bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), - scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(GroupNorm, self).__init__(name=name, mode=mode) - if 
num_channels % num_groups != 0: - raise ValueError('num_channels must be divisible by num_groups') - self.num_groups = num_groups - self.num_channels = num_channels - self.epsilon = epsilon - self.affine = affine - self.bias_initializer = bias_initializer - self.scale_initializer = scale_initializer - if self.affine: - assert isinstance(self.mode, bm.TrainingMode) - self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_channels)) - self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_channels)) - - def update(self, x): - assert x.shape[-1] == self.num_channels - origin_shape, origin_dim = x.shape, x.ndim - group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups) - x = bm.as_jax(x.reshape(group_shape)) - reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,) - mean = jnp.mean(x, reduction_axes, keepdims=True) - var = jnp.var(x, reduction_axes, keepdims=True) - x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype)) - x = x.reshape(origin_shape) - if self.affine: - x = x * lax.broadcast_to_rank(self.scale.value, origin_dim) - x = x + lax.broadcast_to_rank(self.bias.value, origin_dim) - return x + This layer divides channels into groups and normalizes the features within each + group. Its computation is also independent of the batch size. The feature size + must be multiple of the group size. + + The shape of the data should be (b, d1, d2, ..., c), where `d` denotes the batch + size and `c` denotes the feature (channel) size. + + Parameters:: + + num_groups: int + The number of groups. It should be a factor of the number of channels. + num_channels: int + The number of channels expected in input. + epsilon: float + a value added to the denominator for numerical stability. Default: 1e-5 + affine: bool + A boolean value that when set to ``True``, this module + has learnable per-channel affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. 
+ bias_initializer: Initializer, ArrayType, Callable + An initializer generating the original translation matrix + scale_initializer: Initializer, ArrayType, Callable + An initializer generating the original scaling matrix + + Examples:: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> input = bm.random.randn(20, 10, 10, 6) + >>> # Separate 6 channels into 3 groups + >>> m = bp.layers.GroupNorm(3, 6) + >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) + >>> m = bp.layers.GroupNorm(6, 6) + >>> # Put all 6 channels into a single group (equivalent with LayerNorm) + >>> m = bp.layers.GroupNorm(1, 6) + >>> # Activating the module + >>> output = m(input) + """ + + def __init__( + self, + num_groups: int, + num_channels: int, + epsilon: float = 1e-5, + affine: bool = True, + bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), + scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(GroupNorm, self).__init__(name=name, mode=mode) + if num_channels % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + self.num_groups = num_groups + self.num_channels = num_channels + self.epsilon = epsilon + self.affine = affine + self.bias_initializer = bias_initializer + self.scale_initializer = scale_initializer + if self.affine: + assert isinstance(self.mode, bm.TrainingMode) + self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_channels)) + self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_channels)) + + def update(self, x): + assert x.shape[-1] == self.num_channels + origin_shape, origin_dim = x.shape, x.ndim + group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups) + x = bm.as_jax(x.reshape(group_shape)) + reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,) + mean = jnp.mean(x, reduction_axes, keepdims=True) + var = jnp.var(x, 
reduction_axes, keepdims=True) + x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype)) + x = x.reshape(origin_shape) + if self.affine: + x = x * lax.broadcast_to_rank(self.scale.value, origin_dim) + x = x + lax.broadcast_to_rank(self.bias.value, origin_dim) + return x class InstanceNorm(GroupNorm): - r"""Instance normalization layer. - - This layer normalizes the data within each feature. It can be regarded as - a group normalization layer in which `group_size` equals to 1. - - Parameters:: - - num_channels: int - The number of channels expected in input. - epsilon: float - a value added to the denominator for numerical stability. Default: 1e-5 - affine: bool - A boolean value that when set to ``True``, this module - has learnable per-channel affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. - bias_initializer: Initializer, ArrayType, Callable - an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable - an initializer generating the original scaling matrix - """ - - def __init__( - self, - num_channels: int, - epsilon: float = 1e-5, - affine: bool = True, - bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), - scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(InstanceNorm, self).__init__(num_channels=num_channels, - num_groups=num_channels, - epsilon=epsilon, - affine=affine, - bias_initializer=bias_initializer, - scale_initializer=scale_initializer, - mode=mode, - name=name) + r"""Instance normalization layer. + + This layer normalizes the data within each feature. It can be regarded as + a group normalization layer in which `group_size` equals to 1. + + Parameters:: + + num_channels: int + The number of channels expected in input. + epsilon: float + a value added to the denominator for numerical stability. 
Default: 1e-5 + affine: bool + A boolean value that when set to ``True``, this module + has learnable per-channel affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + bias_initializer: Initializer, ArrayType, Callable + an initializer generating the original translation matrix + scale_initializer: Initializer, ArrayType, Callable + an initializer generating the original scaling matrix + """ + + def __init__( + self, + num_channels: int, + epsilon: float = 1e-5, + affine: bool = True, + bias_initializer: Union[Initializer, ArrayType, Callable] = ZeroInit(), + scale_initializer: Union[Initializer, ArrayType, Callable] = OneInit(), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(InstanceNorm, self).__init__(num_channels=num_channels, + num_groups=num_channels, + epsilon=epsilon, + affine=affine, + bias_initializer=bias_initializer, + scale_initializer=scale_initializer, + mode=mode, + name=name) BatchNorm1D = BatchNorm1d diff --git a/brainpy/_src/dnn/pooling.py b/brainpy/_src/dnn/pooling.py index b5375c04e..6429d6296 100644 --- a/brainpy/_src/dnn/pooling.py +++ b/brainpy/_src/dnn/pooling.py @@ -10,430 +10,386 @@ from brainpy._src.dnn.base import Layer __all__ = [ - 'MaxPool', - 'MinPool', - 'AvgPool', - 'AvgPool1d', - 'AvgPool2d', - 'AvgPool3d', - 'MaxPool1d', - 'MaxPool2d', - 'MaxPool3d', - 'AdaptiveAvgPool1d', - 'AdaptiveAvgPool2d', - 'AdaptiveAvgPool3d', - 'AdaptiveMaxPool1d', - 'AdaptiveMaxPool2d', - 'AdaptiveMaxPool3d', + 'MaxPool', + 'MinPool', + 'AvgPool', + 'AvgPool1d', + 'AvgPool2d', + 'AvgPool3d', + 'MaxPool1d', + 'MaxPool2d', + 'MaxPool3d', + 'AdaptiveAvgPool1d', + 'AdaptiveAvgPool2d', + 'AdaptiveAvgPool3d', + 'AdaptiveMaxPool1d', + 'AdaptiveMaxPool2d', + 'AdaptiveMaxPool3d', ] class Pool(Layer): - """Pooling functions are implemented using the ReduceWindow XLA op. 
- - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. - - """ - - def __init__( - self, - init_value, - computation, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]], - padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = None, - mode: bm.Mode = None, - name: Optional[str] = None, - ): - super(Pool, self).__init__(mode=mode, name=name) - - self.init_value = init_value - self.computation = computation - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.channel_axis = channel_axis - if isinstance(padding, str): - if padding not in ("SAME", "VALID"): - raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.") - else: - assert all([isinstance(x, (tuple, list)) for x in padding]), \ - f'padding should be sequence of Tuple[int, int]. {padding}' - assert all([len(x) == 2 for x in padding]), f"each entry in padding {padding} must be length 2" + """Pooling functions are implemented using the ReduceWindow XLA op. 
- def update(self, x): - x = bm.as_jax(x) - window_shape = self._infer_shape(x.ndim, self.kernel_size) - stride = self._infer_shape(x.ndim, self.stride) - padding = (self.padding - if isinstance(self.padding, str) else - self._infer_shape(x.ndim, self.padding, element=(0, 0), element_type=(tuple, list))) - r = jax.lax.reduce_window(bm.as_jax(x), - init_value=self.init_value, - computation=self.computation, - window_dimensions=window_shape, - window_strides=stride, - padding=padding) - return r - - def _infer_shape(self, - x_dim: int, - size: Union[Any, Sequence[Any]], - element: Any = 1, - element_type: Union[type, Sequence[type]] = int): - """Infer shape for pooling window or stride.""" - - # channel axis - channel_axis = self.channel_axis - if channel_axis and not 0 <= abs(channel_axis) < x_dim: - raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") - if channel_axis and channel_axis < 0: - channel_axis = x_dim + channel_axis - - if isinstance(size, (tuple, list)) and isinstance(size[0], element_type): - size = tuple(size) - if len(size) > x_dim: - raise ValueError(f'Invalid size {size}. Its dimension is bigger than its input.') - elif len(size) == x_dim: - return size - else: - if isinstance(self.mode, bm.BatchingMode): - size = (element,) + size - if len(size) + 1 == x_dim: - if channel_axis is None: - raise ValueError('"channel_axis" should be provided.') - size = size[:channel_axis] + (element,) + size[channel_axis:] + Parameters:: + + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding: str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. 
+ channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode: Mode + The computation mode. + name: optional, str + The object name. + + """ + + def __init__( + self, + init_value, + computation, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]], + padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = None, + mode: bm.Mode = None, + name: Optional[str] = None, + ): + super(Pool, self).__init__(mode=mode, name=name) + + self.init_value = init_value + self.computation = computation + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.channel_axis = channel_axis + if isinstance(padding, str): + if padding not in ("SAME", "VALID"): + raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.") else: - raise ValueError(f'size {size} is invalid. Please provide more elements.') - return size - else: - if isinstance(self.mode, bm.BatchingMode): - return (element,) + tuple((size if d != channel_axis else element) for d in range(1, x_dim)) - else: - return tuple((size if d != channel_axis else element) for d in range(0, x_dim)) + assert all([isinstance(x, (tuple, list)) for x in padding]), \ + f'padding should be sequence of Tuple[int, int]. 
{padding}' + assert all([len(x) == 2 for x in padding]), f"each entry in padding {padding} must be length 2" + + def update(self, x): + x = bm.as_jax(x) + window_shape = self._infer_shape(x.ndim, self.kernel_size) + stride = self._infer_shape(x.ndim, self.stride) + padding = (self.padding + if isinstance(self.padding, str) else + self._infer_shape(x.ndim, self.padding, element=(0, 0), element_type=(tuple, list))) + r = jax.lax.reduce_window(bm.as_jax(x), + init_value=self.init_value, + computation=self.computation, + window_dimensions=window_shape, + window_strides=stride, + padding=padding) + return r + + def _infer_shape(self, + x_dim: int, + size: Union[Any, Sequence[Any]], + element: Any = 1, + element_type: Union[type, Sequence[type]] = int): + """Infer shape for pooling window or stride.""" + + # channel axis + channel_axis = self.channel_axis + if channel_axis and not 0 <= abs(channel_axis) < x_dim: + raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") + if channel_axis and channel_axis < 0: + channel_axis = x_dim + channel_axis + + if isinstance(size, (tuple, list)) and isinstance(size[0], element_type): + size = tuple(size) + if len(size) > x_dim: + raise ValueError(f'Invalid size {size}. Its dimension is bigger than its input.') + elif len(size) == x_dim: + return size + else: + if isinstance(self.mode, bm.BatchingMode): + size = (element,) + size + if len(size) + 1 == x_dim: + if channel_axis is None: + raise ValueError('"channel_axis" should be provided.') + size = size[:channel_axis] + (element,) + size[channel_axis:] + else: + raise ValueError(f'size {size} is invalid. 
Please provide more elements.') + return size + else: + if isinstance(self.mode, bm.BatchingMode): + return (element,) + tuple((size if d != channel_axis else element) for d in range(1, x_dim)) + else: + return tuple((size if d != channel_axis else element) for d in range(0, x_dim)) class MaxPool(Pool): - """Pools the input by taking the maximum over a window. - - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. - - """ - - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = None, - mode: bm.Mode = None, - name: Optional[str] = None, - ): - super(MaxPool, self).__init__(init_value=-jax.numpy.inf, - computation=jax.lax.max, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - mode=mode, - name=name) + """Pools the input by taking the maximum over a window. + + Parameters:: + + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). 
+ padding: str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode: Mode + The computation mode. + name: optional, str + The object name. + + """ + + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = 1, + padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = None, + mode: bm.Mode = None, + name: Optional[str] = None, + ): + super(MaxPool, self).__init__(init_value=-jax.numpy.inf, + computation=jax.lax.max, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + mode=mode, + name=name) class MinPool(Pool): - """Pools the input by taking the minimum over a window. - - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. 
- - """ - - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = None, - mode: bm.Mode = None, - name: Optional[str] = None, - ): - super(MinPool, self).__init__(init_value=jax.numpy.inf, - computation=jax.lax.min, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - mode=mode, - name=name) + """Pools the input by taking the minimum over a window. + + Parameters:: + + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding: str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode: Mode + The computation mode. + name: optional, str + The object name. + + """ + + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = 1, + padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = None, + mode: bm.Mode = None, + name: Optional[str] = None, + ): + super(MinPool, self).__init__(init_value=jax.numpy.inf, + computation=jax.lax.min, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + mode=mode, + name=name) class AvgPool(Pool): - """Pools the input by taking the average over a window. - - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. 
- stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. - """ - - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = None, - mode: bm.Mode = None, - name: Optional[str] = None, - ): - super(AvgPool, self).__init__(init_value=0., - computation=jax.lax.add, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - mode=mode, - name=name) - - def update(self, x): - x = bm.as_jax(x) - window_shape = self._infer_shape(x.ndim, self.kernel_size) - strides = self._infer_shape(x.ndim, self.stride) - padding = (self.padding if isinstance(self.padding, str) else - self._infer_shape(x.ndim, self.padding, element=(0, 0), element_type=(tuple, list))) - pooled = jax.lax.reduce_window(bm.as_jax(x), - init_value=self.init_value, - computation=self.computation, - window_dimensions=window_shape, - window_strides=strides, - padding=padding) - if padding == "VALID": - # Avoid the extra reduce_window. - return pooled / np.prod(window_shape) - else: - # Count the number of valid entries at each input point, then use that for - # computing average. Assumes that any two arrays of same shape will be - # padded the same. 
- window_counts = jax.lax.reduce_window(jnp.ones_like(bm.as_jax(x)), - init_value=self.init_value, - computation=self.computation, - window_dimensions=window_shape, - window_strides=strides, - padding=padding) - assert pooled.shape == window_counts.shape - return pooled / window_counts + """Pools the input by taking the average over a window. + Parameters:: -class _MaxPoolNd(Layer): - def __init__( - self, - init_value, - computation, - pool_dim: int, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = None, - padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = -1, - mode: bm.Mode = None, - name: Optional[str] = None - ): - super().__init__(name=name, mode=mode) - - self.init_value = init_value - self.computation = computation - self.pool_dim = pool_dim - - # kernel_size - if isinstance(kernel_size, int): - kernel_size = (kernel_size,) * pool_dim - elif isinstance(kernel_size, Sequence): - check.is_sequence(kernel_size, elem_type=int) - if len(kernel_size) != pool_dim: - raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}') - else: - raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.') - self.kernel_size = kernel_size - - # stride - if stride is None: - stride = kernel_size - if isinstance(stride, int): - stride = (stride,) * pool_dim - elif isinstance(stride, Sequence): - check.is_sequence(stride, elem_type=int) - if len(stride) != pool_dim: - raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(kernel_size)}') - else: - raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.') - self.stride = stride - - # padding - if isinstance(padding, str): - if padding not in ("SAME", "VALID"): - raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.") - elif isinstance(padding, int): - padding = [(padding, padding) for _ in range(pool_dim)] - elif 
isinstance(padding, (list, tuple)): - if isinstance(padding[0], int): - if len(padding) == pool_dim: - padding = [(x, x) for x in padding] + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding: str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode: Mode + The computation mode. + name: optional, str + The object name. + """ + + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = 1, + padding: Union[str, Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = None, + mode: bm.Mode = None, + name: Optional[str] = None, + ): + super(AvgPool, self).__init__(init_value=0., + computation=jax.lax.add, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + mode=mode, + name=name) + + def update(self, x): + x = bm.as_jax(x) + window_shape = self._infer_shape(x.ndim, self.kernel_size) + strides = self._infer_shape(x.ndim, self.stride) + padding = (self.padding if isinstance(self.padding, str) else + self._infer_shape(x.ndim, self.padding, element=(0, 0), element_type=(tuple, list))) + pooled = jax.lax.reduce_window(bm.as_jax(x), + init_value=self.init_value, + computation=self.computation, + window_dimensions=window_shape, + window_strides=strides, + padding=padding) + if padding == "VALID": + # Avoid the extra reduce_window. 
+ return pooled / np.prod(window_shape) else: - raise ValueError(f'If padding is a sequence of ints, it ' - f'should has the length of {pool_dim}.') - else: - if not all([isinstance(x, (tuple, list)) for x in padding]): - raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}') - if not all([len(x) == 2 for x in padding]): - raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ") - if len(padding) == 1: - padding = tuple(padding) * pool_dim - assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}' - else: - raise ValueError - self.padding = padding + # Count the number of valid entries at each input point, then use that for + # computing average. Assumes that any two arrays of same shape will be + # padded the same. + window_counts = jax.lax.reduce_window(jnp.ones_like(bm.as_jax(x)), + init_value=self.init_value, + computation=self.computation, + window_dimensions=window_shape, + window_strides=strides, + padding=padding) + assert pooled.shape == window_counts.shape + return pooled / window_counts - # channel_axis - self.channel_axis = check.is_integer(channel_axis, allow_none=True) - def update(self, x): - x = bm.as_jax(x) - x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) - if x.ndim < x_dim: - raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.') - window_shape = self._infer_shape(x.ndim, self.kernel_size, 1) - stride = self._infer_shape(x.ndim, self.stride, 1) - padding = (self.padding - if isinstance(self.padding, str) else - self._infer_shape(x.ndim, self.padding, element=(0, 0))) - r = jax.lax.reduce_window(bm.as_jax(x), - init_value=self.init_value, - computation=self.computation, - window_dimensions=window_shape, - window_strides=stride, - padding=padding) - return r - - def _infer_shape(self, x_dim, inputs, element): - channel_axis = self.channel_axis - if channel_axis and not 0 <= abs(channel_axis) < x_dim: - raise 
ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") - if channel_axis and channel_axis < 0: - channel_axis = x_dim + channel_axis - all_dims = list(range(x_dim)) - if channel_axis is not None: - all_dims.pop(channel_axis) - pool_dims = all_dims[-self.pool_dim:] - results = [element] * x_dim - for i, dim in enumerate(pool_dims): - results[dim] = inputs[i] - return results +class _MaxPoolNd(Layer): + def __init__( + self, + init_value, + computation, + pool_dim: int, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = None, + padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = -1, + mode: bm.Mode = None, + name: Optional[str] = None + ): + super().__init__(name=name, mode=mode) + + self.init_value = init_value + self.computation = computation + self.pool_dim = pool_dim + + # kernel_size + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * pool_dim + elif isinstance(kernel_size, Sequence): + check.is_sequence(kernel_size, elem_type=int) + if len(kernel_size) != pool_dim: + raise ValueError(f'kernel_size should a tuple with {pool_dim} ints, but got {len(kernel_size)}') + else: + raise TypeError(f'kernel_size should be a int or a tuple with {pool_dim} ints.') + self.kernel_size = kernel_size + + # stride + if stride is None: + stride = kernel_size + if isinstance(stride, int): + stride = (stride,) * pool_dim + elif isinstance(stride, Sequence): + check.is_sequence(stride, elem_type=int) + if len(stride) != pool_dim: + raise ValueError(f'stride should a tuple with {pool_dim} ints, but got {len(kernel_size)}') + else: + raise TypeError(f'stride should be a int or a tuple with {pool_dim} ints.') + self.stride = stride + + # padding + if isinstance(padding, str): + if padding not in ("SAME", "VALID"): + raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.") + elif isinstance(padding, int): + padding = [(padding, 
padding) for _ in range(pool_dim)] + elif isinstance(padding, (list, tuple)): + if isinstance(padding[0], int): + if len(padding) == pool_dim: + padding = [(x, x) for x in padding] + else: + raise ValueError(f'If padding is a sequence of ints, it ' + f'should has the length of {pool_dim}.') + else: + if not all([isinstance(x, (tuple, list)) for x in padding]): + raise ValueError(f'padding should be sequence of Tuple[int, int]. {padding}') + if not all([len(x) == 2 for x in padding]): + raise ValueError(f"Each entry in padding must be tuple of 2 ints. {padding} ") + if len(padding) == 1: + padding = tuple(padding) * pool_dim + assert len(padding) == pool_dim, f'padding should has the length of {pool_dim}. {padding}' + else: + raise ValueError + self.padding = padding + + # channel_axis + self.channel_axis = check.is_integer(channel_axis, allow_none=True) + + def update(self, x): + x = bm.as_jax(x) + x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) + if x.ndim < x_dim: + raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.') + window_shape = self._infer_shape(x.ndim, self.kernel_size, 1) + stride = self._infer_shape(x.ndim, self.stride, 1) + padding = (self.padding + if isinstance(self.padding, str) else + self._infer_shape(x.ndim, self.padding, element=(0, 0))) + r = jax.lax.reduce_window(bm.as_jax(x), + init_value=self.init_value, + computation=self.computation, + window_dimensions=window_shape, + window_strides=stride, + padding=padding) + return r + + def _infer_shape(self, x_dim, inputs, element): + channel_axis = self.channel_axis + if channel_axis and not 0 <= abs(channel_axis) < x_dim: + raise ValueError(f"Invalid channel axis {channel_axis} for input with {x_dim} dimensions") + if channel_axis and channel_axis < 0: + channel_axis = x_dim + channel_axis + all_dims = list(range(x_dim)) + if channel_axis is not None: + all_dims.pop(channel_axis) + pool_dims = all_dims[-self.pool_dim:] + results = [element] * x_dim + for 
i, dim in enumerate(pool_dims): + results[dim] = inputs[i] + return results class MaxPool1d(_MaxPoolNd): - """Applies a 1D max pooling over an input signal composed of several input - planes. - - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. - - """ - - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = None, - padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = -1, - mode: bm.Mode = None, - name: Optional[str] = None - ): - super().__init__(init_value=-jax.numpy.inf, - computation=jax.lax.max, - pool_dim=1, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - name=name, - mode=mode) - - -class MaxPool2d(_MaxPoolNd): - """Applies a 1D max pooling over an input signal composed of several input + """Applies a 1D max pooling over an input signal composed of several input planes. 
Parameters:: @@ -456,27 +412,148 @@ class MaxPool2d(_MaxPoolNd): """ - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = None, - padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = -1, - mode: bm.Mode = None, - name: Optional[str] = None - ): - super().__init__(init_value=-jax.numpy.inf, - computation=jax.lax.max, - pool_dim=2, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - name=name, mode=mode) + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = None, + padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = -1, + mode: bm.Mode = None, + name: Optional[str] = None + ): + super().__init__(init_value=-jax.numpy.inf, + computation=jax.lax.max, + pool_dim=1, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + name=name, + mode=mode) + + +class MaxPool2d(_MaxPoolNd): + """Applies a 1D max pooling over an input signal composed of several input + planes. + + Parameters:: + + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding: str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode: Mode + The computation mode. + name: optional, str + The object name. 
+ + """ + + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = None, + padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = -1, + mode: bm.Mode = None, + name: Optional[str] = None + ): + super().__init__(init_value=-jax.numpy.inf, + computation=jax.lax.max, + pool_dim=2, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + name=name, mode=mode) class MaxPool3d(_MaxPoolNd): - """Applies a 1D max pooling over an input signal composed of several input + """Applies a 1D max pooling over an input signal composed of several input + planes. + + Parameters:: + + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding: str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode: Mode + The computation mode. + name: optional, str + The object name. 
+ + """ + + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = None, + padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = -1, + mode: bm.Mode = None, + name: Optional[str] = None + ): + super().__init__(init_value=-jax.numpy.inf, + computation=jax.lax.max, + pool_dim=3, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + name=name, mode=mode) + + +class _AvgPoolNd(_MaxPoolNd): + def update(self, x): + x = bm.as_jax(x) + x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) + if x.ndim < x_dim: + raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.') + dims = self._infer_shape(x.ndim, self.kernel_size, 1) + stride = self._infer_shape(x.ndim, self.stride, 1) + padding = (self.padding + if isinstance(self.padding, str) else + self._infer_shape(x.ndim, self.padding, element=(0, 0))) + pooled = jax.lax.reduce_window(bm.as_jax(x), + init_value=self.init_value, + computation=self.computation, + window_dimensions=dims, + window_strides=stride, + padding=padding) + if padding == "VALID": + # Avoid the extra reduce_window. + return pooled / np.prod(dims) + else: + # Count the number of valid entries at each input point, then use that for + # computing average. Assumes that any two arrays of same shape will be + # padded the same. + window_counts = jax.lax.reduce_window(jnp.ones_like(bm.as_jax(x)), + init_value=self.init_value, + computation=self.computation, + window_dimensions=dims, + window_strides=stride, + padding=padding) + assert pooled.shape == window_counts.shape + return pooled / window_counts + + +class AvgPool1d(_AvgPoolNd): + """Applies a 1D average pooling over an input signal composed of several input planes. 
Parameters:: @@ -499,471 +576,394 @@ class MaxPool3d(_MaxPoolNd): """ - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = None, - padding: Union[str, int, Tuple[int], Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = -1, - mode: bm.Mode = None, - name: Optional[str] = None - ): - super().__init__(init_value=-jax.numpy.inf, - computation=jax.lax.max, - pool_dim=3, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - name=name, mode=mode) - + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = 1, + padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = -1, + mode: bm.Mode = None, + name: Optional[str] = None + ): + super().__init__(init_value=0., + computation=jax.lax.add, + pool_dim=1, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + name=name, + mode=mode) -class _AvgPoolNd(_MaxPoolNd): - def update(self, x): - x = bm.as_jax(x) - x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) - if x.ndim < x_dim: - raise ValueError(f'Excepted input with >= {x_dim} dimensions, but got {x.ndim}.') - dims = self._infer_shape(x.ndim, self.kernel_size, 1) - stride = self._infer_shape(x.ndim, self.stride, 1) - padding = (self.padding - if isinstance(self.padding, str) else - self._infer_shape(x.ndim, self.padding, element=(0, 0))) - pooled = jax.lax.reduce_window(bm.as_jax(x), - init_value=self.init_value, - computation=self.computation, - window_dimensions=dims, - window_strides=stride, - padding=padding) - if padding == "VALID": - # Avoid the extra reduce_window. - return pooled / np.prod(dims) - else: - # Count the number of valid entries at each input point, then use that for - # computing average. Assumes that any two arrays of same shape will be - # padded the same. 
- window_counts = jax.lax.reduce_window(jnp.ones_like(bm.as_jax(x)), - init_value=self.init_value, - computation=self.computation, - window_dimensions=dims, - window_strides=stride, - padding=padding) - assert pooled.shape == window_counts.shape - return pooled / window_counts +class AvgPool2d(_AvgPoolNd): + """Applies a 2D average pooling over an input signal composed of several input + planes. -class AvgPool1d(_AvgPoolNd): - """Applies a 1D average pooling over an input signal composed of several input - planes. - - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. - - """ - - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = -1, - mode: bm.Mode = None, - name: Optional[str] = None - ): - super().__init__(init_value=0., - computation=jax.lax.add, - pool_dim=1, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - name=name, - mode=mode) + Parameters:: + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). 
+ padding: str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode: Mode + The computation mode. + name: optional, str + The object name. + """ -class AvgPool2d(_AvgPoolNd): - """Applies a 2D average pooling over an input signal composed of several input - planes. - - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. 
- """ - - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = -1, - mode: bm.Mode = None, - name: Optional[str] = None - ): - super().__init__(init_value=0., - computation=jax.lax.add, - pool_dim=2, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - name=name, - mode=mode) + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = 1, + padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = -1, + mode: bm.Mode = None, + name: Optional[str] = None + ): + super().__init__(init_value=0., + computation=jax.lax.add, + pool_dim=2, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + name=name, + mode=mode) class AvgPool3d(_AvgPoolNd): - """Applies a 3D average pooling over an input signal composed of several input - planes. - - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. 
- - """ - - def __init__( - self, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Sequence[int]] = 1, - padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", - channel_axis: Optional[int] = -1, - mode: bm.Mode = None, - name: Optional[str] = None - ): - super().__init__(init_value=0., - computation=jax.lax.add, - pool_dim=3, - kernel_size=kernel_size, - stride=stride, - padding=padding, - channel_axis=channel_axis, - name=name, - mode=mode) + """Applies a 3D average pooling over an input signal composed of several input + planes. + + Parameters:: + + kernel_size: int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride: int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding: str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode: Mode + The computation mode. + name: optional, str + The object name. + + """ + + def __init__( + self, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]] = 1, + padding: Union[str, int, Tuple[int, ...], Sequence[Tuple[int, int]]] = "VALID", + channel_axis: Optional[int] = -1, + mode: bm.Mode = None, + name: Optional[str] = None + ): + super().__init__(init_value=0., + computation=jax.lax.add, + pool_dim=3, + kernel_size=kernel_size, + stride=stride, + padding=padding, + channel_axis=channel_axis, + name=name, + mode=mode) def _adaptive_pool1d(x, target_size: int, operation: Callable): - """Adaptive pool 1D. - - Args: - x: The input. Should be a JAX array of shape `(dim,)`. - target_size: The shape of the output after the pooling operation `(target_size,)`. 
- operation: The pooling operation to be performed on the input array. - - Returns: - A JAX array of shape `(target_size, )`. - """ - x = bm.as_jax(x) - size = jnp.size(x) - num_head_arrays = size % target_size - num_block = size // target_size - if num_head_arrays != 0: - head_end_index = num_head_arrays * (num_block + 1) - heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1)) - tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block)) - outs = jnp.concatenate([heads, tails]) - else: - outs = jax.vmap(operation)(x.reshape(-1, num_block)) - return outs + """Adaptive pool 1D. + + Args: + x: The input. Should be a JAX array of shape `(dim,)`. + target_size: The shape of the output after the pooling operation `(target_size,)`. + operation: The pooling operation to be performed on the input array. + + Returns: + A JAX array of shape `(target_size, )`. + """ + x = bm.as_jax(x) + size = jnp.size(x) + num_head_arrays = size % target_size + num_block = size // target_size + if num_head_arrays != 0: + head_end_index = num_head_arrays * (num_block + 1) + heads = jax.vmap(operation)(x[:head_end_index].reshape(num_head_arrays, -1)) + tails = jax.vmap(operation)(x[head_end_index:].reshape(-1, num_block)) + outs = jnp.concatenate([heads, tails]) + else: + outs = jax.vmap(operation)(x.reshape(-1, num_block)) + return outs def _generate_vmap(fun: Callable, map_axes: List[int]): - map_axes = sorted(map_axes) - for axis in map_axes: - fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis) - return fun + map_axes = sorted(map_axes) + for axis in map_axes: + fun = jax.vmap(fun, in_axes=(axis, None, None), out_axes=axis) + return fun class AdaptivePool(Layer): - """General N dimensional adaptive down-sampling to a target shape. - - Parameters:: - - target_shape: int, sequence of int - The target output shape. - num_spatial_dims: int - The number of spatial dimensions. 
- channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - operation: Callable - The down-sampling operation. - name: str - The class name. - mode: Mode - The computing mode. - """ - - def __init__( - self, - target_shape: Union[int, Sequence[int]], - num_spatial_dims: int, - operation: Callable, - channel_axis: Optional[int] = -1, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - self.channel_axis = channel_axis - self.operation = operation - if isinstance(target_shape, int): - self.target_shape = (target_shape,) * num_spatial_dims - elif isinstance(target_shape, Sequence) and (len(target_shape) == num_spatial_dims): - self.target_shape = target_shape - else: - raise ValueError("`target_size` must either be an int or tuple of length " - f"{num_spatial_dims} containing ints.") - - def update(self, x): - """Input-output mapping. + """General N dimensional adaptive down-sampling to a target shape. Parameters:: - x: Array - Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)` - or `(..., dim_1, dim_2)`. + target_shape: int, sequence of int + The target output shape. + num_spatial_dims: int + The number of spatial dimensions. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + operation: Callable + The down-sampling operation. + name: str + The class name. + mode: Mode + The computing mode. """ - x = bm.as_jax(x) - # channel axis - channel_axis = self.channel_axis - - if channel_axis: - if not 0 <= abs(channel_axis) < x.ndim: - raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}") - if channel_axis < 0: - channel_axis = x.ndim + channel_axis - # input dimension - if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape): - raise ValueError(f"Invalid input dimension. 
Except >={len(self.target_shape)} " - f"dimensions (channel_axis={self.channel_axis}). " - f"But got {x.ndim} dimensions.") - # pooling dimensions - pool_dims = list(range(x.ndim)) - if channel_axis: - pool_dims.pop(channel_axis) - - # pooling - for i, di in enumerate(pool_dims[-len(self.target_shape):]): - poo_axes = [j for j in range(x.ndim) if j != di] - op = _generate_vmap(_adaptive_pool1d, poo_axes) - x = op(x, self.target_shape[i], self.operation) - return x + def __init__( + self, + target_shape: Union[int, Sequence[int]], + num_spatial_dims: int, + operation: Callable, + channel_axis: Optional[int] = -1, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + self.channel_axis = channel_axis + self.operation = operation + if isinstance(target_shape, int): + self.target_shape = (target_shape,) * num_spatial_dims + elif isinstance(target_shape, Sequence) and (len(target_shape) == num_spatial_dims): + self.target_shape = target_shape + else: + raise ValueError("`target_size` must either be an int or tuple of length " + f"{num_spatial_dims} containing ints.") + + def update(self, x): + """Input-output mapping. + + Parameters:: + + x: Array + Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)` + or `(..., dim_1, dim_2)`. + """ + x = bm.as_jax(x) + + # channel axis + channel_axis = self.channel_axis + + if channel_axis: + if not 0 <= abs(channel_axis) < x.ndim: + raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}") + if channel_axis < 0: + channel_axis = x.ndim + channel_axis + # input dimension + if (x.ndim - (0 if channel_axis is None else 1)) < len(self.target_shape): + raise ValueError(f"Invalid input dimension. Except >={len(self.target_shape)} " + f"dimensions (channel_axis={self.channel_axis}). 
" + f"But got {x.ndim} dimensions.") + # pooling dimensions + pool_dims = list(range(x.ndim)) + if channel_axis: + pool_dims.pop(channel_axis) + + # pooling + for i, di in enumerate(pool_dims[-len(self.target_shape):]): + poo_axes = [j for j in range(x.ndim) if j != di] + op = _generate_vmap(_adaptive_pool1d, poo_axes) + x = op(x, self.target_shape[i], self.operation) + return x class AdaptiveAvgPool1d(AdaptivePool): - """Adaptive one-dimensional average down-sampling. - - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. - """ - - def __init__(self, - target_shape: Union[int, Sequence[int]], - channel_axis: Optional[int] = -1, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None): - super().__init__(target_shape, - channel_axis=channel_axis, - num_spatial_dims=1, - operation=jnp.mean, - name=name, - mode=mode) + """Adaptive one-dimensional average down-sampling. + + Parameters:: + + target_shape: int, sequence of int + The target output shape. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name: str + The class name. + mode: Mode + The computing mode. + """ + + def __init__(self, + target_shape: Union[int, Sequence[int]], + channel_axis: Optional[int] = -1, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None): + super().__init__(target_shape, + channel_axis=channel_axis, + num_spatial_dims=1, + operation=jnp.mean, + name=name, + mode=mode) class AdaptiveAvgPool2d(AdaptivePool): - """Adaptive two-dimensional average down-sampling. - - - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. 
- If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. - """ - - def __init__(self, - target_shape: Union[int, Sequence[int]], - channel_axis: Optional[int] = -1, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None): - super().__init__(target_shape, - channel_axis=channel_axis, - num_spatial_dims=2, - operation=jnp.mean, - name=name, - mode=mode) + """Adaptive two-dimensional average down-sampling. + + + Parameters:: + + target_shape: int, sequence of int + The target output shape. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name: str + The class name. + mode: Mode + The computing mode. + """ + + def __init__(self, + target_shape: Union[int, Sequence[int]], + channel_axis: Optional[int] = -1, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None): + super().__init__(target_shape, + channel_axis=channel_axis, + num_spatial_dims=2, + operation=jnp.mean, + name=name, + mode=mode) class AdaptiveAvgPool3d(AdaptivePool): - """Adaptive three-dimensional average down-sampling. - - - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. - """ - - def __init__(self, - target_shape: Union[int, Sequence[int]], - channel_axis: Optional[int] = -1, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None): - super().__init__(target_shape, - channel_axis=channel_axis, - num_spatial_dims=3, - operation=jnp.mean, - name=name, - mode=mode) + """Adaptive three-dimensional average down-sampling. + + + Parameters:: + + target_shape: int, sequence of int + The target output shape. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. 
+ If ``None``, there is no channel axis. + name: str + The class name. + mode: Mode + The computing mode. + """ + + def __init__(self, + target_shape: Union[int, Sequence[int]], + channel_axis: Optional[int] = -1, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None): + super().__init__(target_shape, + channel_axis=channel_axis, + num_spatial_dims=3, + operation=jnp.mean, + name=name, + mode=mode) class AdaptiveMaxPool1d(AdaptivePool): - """Adaptive one-dimensional maximum down-sampling. - - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. - """ - - def __init__(self, - target_shape: Union[int, Sequence[int]], - channel_axis: Optional[int] = -1, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None): - super().__init__(target_shape, - channel_axis=channel_axis, - num_spatial_dims=1, - operation=jnp.max, - name=name, - mode=mode) + """Adaptive one-dimensional maximum down-sampling. + + Parameters:: + + target_shape: int, sequence of int + The target output shape. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name: str + The class name. + mode: Mode + The computing mode. + """ + + def __init__(self, + target_shape: Union[int, Sequence[int]], + channel_axis: Optional[int] = -1, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None): + super().__init__(target_shape, + channel_axis=channel_axis, + num_spatial_dims=1, + operation=jnp.max, + name=name, + mode=mode) class AdaptiveMaxPool2d(AdaptivePool): - """Adaptive two-dimensional maximum down-sampling. - - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. 
- If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. - """ - - def __init__(self, - target_shape: Union[int, Sequence[int]], - channel_axis: Optional[int] = -1, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None): - super().__init__(target_shape, - channel_axis=channel_axis, - num_spatial_dims=2, - operation=jnp.max, - name=name, - mode=mode) + """Adaptive two-dimensional maximum down-sampling. + + Parameters:: + + target_shape: int, sequence of int + The target output shape. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name: str + The class name. + mode: Mode + The computing mode. + """ + + def __init__(self, + target_shape: Union[int, Sequence[int]], + channel_axis: Optional[int] = -1, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None): + super().__init__(target_shape, + channel_axis=channel_axis, + num_spatial_dims=2, + operation=jnp.max, + name=name, + mode=mode) class AdaptiveMaxPool3d(AdaptivePool): - """Adaptive three-dimensional maximum down-sampling. - - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. - """ - - def __init__(self, - target_shape: Union[int, Sequence[int]], - channel_axis: Optional[int] = -1, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None): - super().__init__(target_shape, - channel_axis=channel_axis, - num_spatial_dims=3, - operation=jnp.max, - name=name, - mode=mode) + """Adaptive three-dimensional maximum down-sampling. + + Parameters:: + + target_shape: int, sequence of int + The target output shape. + channel_axis: int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. 
+ name: str + The class name. + mode: Mode + The computing mode. + """ + + def __init__(self, + target_shape: Union[int, Sequence[int]], + channel_axis: Optional[int] = -1, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None): + super().__init__(target_shape, + channel_axis=channel_axis, + num_spatial_dims=3, + operation=jnp.max, + name=name, + mode=mode) diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index 98befaa71..b4355bae0 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,293 +1,265 @@ from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm - class Test_Activation(parameterized.TestCase): - @parameterized.product( - inplace=[True, False] - ) - def test_Threshold(self, inplace): - bm.random.seed() - threshold_layer = bp.dnn.Threshold(5, 20, inplace) - input = bm.random.randn(2) - if inplace == True: - threshold_layer(input) - elif inplace == False: - output = threshold_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_ReLU(self, inplace): - bm.random.seed() - ReLU_layer = bp.dnn.ReLU(inplace) - input = bm.random.randn(2) - if inplace == True: - ReLU_layer(input) - elif inplace == False: - output = ReLU_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_RReLU(self, inplace): - bm.random.seed() - RReLU_layer = bp.dnn.RReLU(lower=0, upper=1, inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - RReLU_layer(input) - elif inplace == False: - output = RReLU_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_Hardtanh(self, inplace): - bm.random.seed() - Hardtanh_layer = bp.dnn.Hardtanh(min_val=0, max_val=1, inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - Hardtanh_layer(input) - elif inplace == False: - output = Hardtanh_layer(input) - - - 
@parameterized.product( - inplace=[True, False] - ) - def test_ReLU6(self, inplace): - bm.random.seed() - ReLU6_layer = bp.dnn.ReLU6(inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - ReLU6_layer(input) - elif inplace == False: - output = ReLU6_layer(input) - - - def test_Sigmoid(self): - bm.random.seed() - Sigmoid_layer = bp.dnn.Sigmoid() - input = bm.random.randn(2) - output = Sigmoid_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_Hardsigmoid(self, inplace): - bm.random.seed() - Hardsigmoid_layer = bp.dnn.Hardsigmoid(inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - Hardsigmoid_layer(input) - elif inplace == False: - output = Hardsigmoid_layer(input) - - - def test_Tanh(self): - bm.random.seed() - Tanh_layer = bp.dnn.Tanh() - input = bm.random.randn(2) - output = Tanh_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_SiLU(self, inplace): - bm.random.seed() - SiLU_layer = bp.dnn.SiLU(inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - SiLU_layer(input) - elif inplace == False: - output = SiLU_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_Mish(self, inplace): - bm.random.seed() - Mish_layer = bp.dnn.Mish(inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - Mish_layer(input) - elif inplace == False: - output = Mish_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_Hardswish(self, inplace): - bm.random.seed() - Hardswish_layer = bp.dnn.Hardswish(inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - Hardswish_layer(input) - elif inplace == False: - output = Hardswish_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_ELU(self, inplace): - bm.random.seed() - ELU_layer = bp.dnn.ELU(alpha=0.5, inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - ELU_layer(input) - elif inplace == False: - output = 
ELU_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_CELU(self, inplace): - bm.random.seed() - CELU_layer = bp.dnn.CELU(alpha=0.5, inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - CELU_layer(input) - elif inplace == False: - output = CELU_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_SELU(self, inplace): - bm.random.seed() - SELU_layer = bp.dnn.SELU(inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - SELU_layer(input) - elif inplace == False: - output = SELU_layer(input) - - - def test_GLU(self): - bm.random.seed() - GLU_layer = bp.dnn.GLU() - input = bm.random.randn(4, 2) - output = GLU_layer(input) - - - @parameterized.product( - approximate=['tanh', 'none'] - ) - def test_GELU(self, approximate): - bm.random.seed() - GELU_layer = bp.dnn.GELU() - input = bm.random.randn(2) - output = GELU_layer(input) - - - def test_Hardshrink(self): - bm.random.seed() - Hardshrink_layer = bp.dnn.Hardshrink(lambd=1) - input = bm.random.randn(2) - output = Hardshrink_layer(input) - - - @parameterized.product( - inplace=[True, False] - ) - def test_LeakyReLU(self, inplace): - bm.random.seed() - LeakyReLU_layer = bp.dnn.LeakyReLU(inplace=inplace) - input = bm.random.randn(2) - if inplace == True: - LeakyReLU_layer(input) - elif inplace == False: - output = LeakyReLU_layer(input) - - - def test_LogSigmoid(self): - bm.random.seed() - LogSigmoid_layer = bp.dnn.LogSigmoid() - input = bm.random.randn(2) - output = LogSigmoid_layer(input) - - - @parameterized.product( - beta=[1, 2, 3], - threshold=[20, 21, 22] - ) - def test_Softplus(self, beta, threshold): - bm.random.seed() - Softplus_layer = bp.dnn.Softplus(beta=beta, threshold=threshold) - input = bm.random.randn(2) - output = Softplus_layer(input) - - - def test_Softshrink(self): - bm.random.seed() - Softshrink_layer = bp.dnn.Softshrink(lambd=1) - input = bm.random.randn(2) - output = Softshrink_layer(input) - - - def 
test_PReLU(self): - bm.random.seed() - PReLU_layer = bp.dnn.PReLU(num_parameters=2, init=0.5) - input = bm.random.randn(2) - output = PReLU_layer(input) - - - def test_Softsign(self): - bm.random.seed() - Softsign_layer = bp.dnn.Softsign() - input = bm.random.randn(2) - output = Softsign_layer(input) - - - def test_Tanhshrink(self): - bm.random.seed() - Tanhshrink_layer = bp.dnn.Tanhshrink() - input = bm.random.randn(2) - output = Tanhshrink_layer(input) - - - def test_Softmin(self): - bm.random.seed() - Softmin_layer = bp.dnn.Softmin(dim=2) - input = bm.random.randn(2, 3, 4) - output = Softmin_layer(input) - - - def test_Softmax(self): - bm.random.seed() - Softmax_layer = bp.dnn.Softmax(dim=2) - input = bm.random.randn(2, 3, 4) - output = Softmax_layer(input) - - - def test_Softmax2d(self): - bm.random.seed() - Softmax2d_layer = bp.dnn.Softmax2d() - input = bm.random.randn(2, 3, 12, 13) - output = Softmax2d_layer(input) - - - def test_LogSoftmax(self): - bm.random.seed() - LogSoftmax_layer = bp.dnn.LogSoftmax(dim=2) - input = bm.random.randn(2, 3, 4) - output = LogSoftmax_layer(input) - + @parameterized.product( + inplace=[True, False] + ) + def test_Threshold(self, inplace): + bm.random.seed() + threshold_layer = bp.dnn.Threshold(5, 20, inplace) + input = bm.random.randn(2) + if inplace == True: + threshold_layer(input) + elif inplace == False: + output = threshold_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_ReLU(self, inplace): + bm.random.seed() + ReLU_layer = bp.dnn.ReLU(inplace) + input = bm.random.randn(2) + if inplace == True: + ReLU_layer(input) + elif inplace == False: + output = ReLU_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_RReLU(self, inplace): + bm.random.seed() + RReLU_layer = bp.dnn.RReLU(lower=0, upper=1, inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + RReLU_layer(input) + elif inplace == False: + output = RReLU_layer(input) + + @parameterized.product( + 
inplace=[True, False] + ) + def test_Hardtanh(self, inplace): + bm.random.seed() + Hardtanh_layer = bp.dnn.Hardtanh(min_val=0, max_val=1, inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + Hardtanh_layer(input) + elif inplace == False: + output = Hardtanh_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_ReLU6(self, inplace): + bm.random.seed() + ReLU6_layer = bp.dnn.ReLU6(inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + ReLU6_layer(input) + elif inplace == False: + output = ReLU6_layer(input) + + def test_Sigmoid(self): + bm.random.seed() + Sigmoid_layer = bp.dnn.Sigmoid() + input = bm.random.randn(2) + output = Sigmoid_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_Hardsigmoid(self, inplace): + bm.random.seed() + Hardsigmoid_layer = bp.dnn.Hardsigmoid(inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + Hardsigmoid_layer(input) + elif inplace == False: + output = Hardsigmoid_layer(input) + + def test_Tanh(self): + bm.random.seed() + Tanh_layer = bp.dnn.Tanh() + input = bm.random.randn(2) + output = Tanh_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_SiLU(self, inplace): + bm.random.seed() + SiLU_layer = bp.dnn.SiLU(inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + SiLU_layer(input) + elif inplace == False: + output = SiLU_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_Mish(self, inplace): + bm.random.seed() + Mish_layer = bp.dnn.Mish(inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + Mish_layer(input) + elif inplace == False: + output = Mish_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_Hardswish(self, inplace): + bm.random.seed() + Hardswish_layer = bp.dnn.Hardswish(inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + Hardswish_layer(input) + elif inplace == False: + output = 
Hardswish_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_ELU(self, inplace): + bm.random.seed() + ELU_layer = bp.dnn.ELU(alpha=0.5, inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + ELU_layer(input) + elif inplace == False: + output = ELU_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_CELU(self, inplace): + bm.random.seed() + CELU_layer = bp.dnn.CELU(alpha=0.5, inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + CELU_layer(input) + elif inplace == False: + output = CELU_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_SELU(self, inplace): + bm.random.seed() + SELU_layer = bp.dnn.SELU(inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + SELU_layer(input) + elif inplace == False: + output = SELU_layer(input) + + def test_GLU(self): + bm.random.seed() + GLU_layer = bp.dnn.GLU() + input = bm.random.randn(4, 2) + output = GLU_layer(input) + + @parameterized.product( + approximate=['tanh', 'none'] + ) + def test_GELU(self, approximate): + bm.random.seed() + GELU_layer = bp.dnn.GELU() + input = bm.random.randn(2) + output = GELU_layer(input) + + def test_Hardshrink(self): + bm.random.seed() + Hardshrink_layer = bp.dnn.Hardshrink(lambd=1) + input = bm.random.randn(2) + output = Hardshrink_layer(input) + + @parameterized.product( + inplace=[True, False] + ) + def test_LeakyReLU(self, inplace): + bm.random.seed() + LeakyReLU_layer = bp.dnn.LeakyReLU(inplace=inplace) + input = bm.random.randn(2) + if inplace == True: + LeakyReLU_layer(input) + elif inplace == False: + output = LeakyReLU_layer(input) + + def test_LogSigmoid(self): + bm.random.seed() + LogSigmoid_layer = bp.dnn.LogSigmoid() + input = bm.random.randn(2) + output = LogSigmoid_layer(input) + + @parameterized.product( + beta=[1, 2, 3], + threshold=[20, 21, 22] + ) + def test_Softplus(self, beta, threshold): + bm.random.seed() + Softplus_layer = 
bp.dnn.Softplus(beta=beta, threshold=threshold) + input = bm.random.randn(2) + output = Softplus_layer(input) + + def test_Softshrink(self): + bm.random.seed() + Softshrink_layer = bp.dnn.Softshrink(lambd=1) + input = bm.random.randn(2) + output = Softshrink_layer(input) + + def test_PReLU(self): + bm.random.seed() + PReLU_layer = bp.dnn.PReLU(num_parameters=2, init=0.5) + input = bm.random.randn(2) + output = PReLU_layer(input) + + def test_Softsign(self): + bm.random.seed() + Softsign_layer = bp.dnn.Softsign() + input = bm.random.randn(2) + output = Softsign_layer(input) + + def test_Tanhshrink(self): + bm.random.seed() + Tanhshrink_layer = bp.dnn.Tanhshrink() + input = bm.random.randn(2) + output = Tanhshrink_layer(input) + + def test_Softmin(self): + bm.random.seed() + Softmin_layer = bp.dnn.Softmin(dim=2) + input = bm.random.randn(2, 3, 4) + output = Softmin_layer(input) + + def test_Softmax(self): + bm.random.seed() + Softmax_layer = bp.dnn.Softmax(dim=2) + input = bm.random.randn(2, 3, 4) + output = Softmax_layer(input) + + def test_Softmax2d(self): + bm.random.seed() + Softmax2d_layer = bp.dnn.Softmax2d() + input = bm.random.randn(2, 3, 12, 13) + output = Softmax2d_layer(input) + + def test_LogSoftmax(self): + bm.random.seed() + LogSoftmax_layer = bp.dnn.LogSoftmax(dim=2) + input = bm.random.randn(2, 3, 4) + output = LogSoftmax_layer(input) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 29b5128eb..a9178eb26 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -10,267 +10,254 @@ import brainpy.math as bm if platform.system() == 'Darwin': - pytest.skip('skip Mac OS', allow_module_level=True) - + pytest.skip('skip Mac OS', allow_module_level=True) + class TestConv(parameterized.TestCase): - def test_Conv2D_img(self): - img = jnp.zeros((2, 200, 198, 4)) - for k in range(4): - x = 30 + 
60 * k - y = 20 + 60 * k - img = img.at[0, x:x + 10, y:y + 10, k].set(1.0) - img = img.at[1, x:x + 20, y:y + 20, k].set(3.0) + def test_Conv2D_img(self): + img = jnp.zeros((2, 200, 198, 4)) + for k in range(4): + x = 30 + 60 * k + y = 20 + 60 * k + img = img.at[0, x:x + 10, y:y + 10, k].set(1.0) + img = img.at[1, x:x + 20, y:y + 20, k].set(3.0) - with bp.math.training_environment(): - net = bp.layers.Conv2d(in_channels=4, out_channels=32, kernel_size=(3, 3), - strides=(2, 1), padding='VALID', groups=4) - out = net(img) - print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 99, 196, 32)) - # print("First output channel:") - # plt.figure(figsize=(10, 10)) - # plt.imshow(np.array(img)[0, :, :, 0]) - # plt.show() - + with bp.math.training_environment(): + net = bp.layers.Conv2d(in_channels=4, out_channels=32, kernel_size=(3, 3), + strides=(2, 1), padding='VALID', groups=4) + out = net(img) + print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 99, 196, 32)) + # print("First output channel:") + # plt.figure(figsize=(10, 10)) + # plt.imshow(np.array(img)[0, :, :, 0]) + # plt.show() - def test_conv1D(self): - with bp.math.training_environment(): - model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) + def test_conv1D(self): + with bp.math.training_environment(): + model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) - input = bp.math.ones((2, 5, 3)) + input = bp.math.ones((2, 5, 3)) - out = model(input) - print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 32)) - # print("First output channel:") - # plt.figure(figsize=(10, 10)) - # plt.imshow(np.array(out)[0, :, :]) - # plt.show() - + out = model(input) + print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 32)) + # print("First output channel:") + # plt.figure(figsize=(10, 10)) + # plt.imshow(np.array(out)[0, :, :]) + # plt.show() - def test_conv2D(self): - bm.random.seed() - with bp.math.training_environment(): - model 
= bp.layers.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3)) + def test_conv2D(self): + bm.random.seed() + with bp.math.training_environment(): + model = bp.layers.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3)) - input = bp.math.ones((2, 5, 5, 3)) + input = bp.math.ones((2, 5, 5, 3)) - out = model(input) - print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 32)) - # print("First output channel:") - # plt.figure(figsize=(10, 10)) - # plt.imshow(np.array(out)[0, :, :, 31]) - # plt.show() - + out = model(input) + print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 32)) + # print("First output channel:") + # plt.figure(figsize=(10, 10)) + # plt.imshow(np.array(out)[0, :, :, 31]) + # plt.show() - def test_conv3D(self): - bm.random.seed() - with bp.math.training_environment(): - model = bp.layers.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3)) - input = bp.math.ones((2, 5, 5, 5, 3)) - out = model(input) - print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 5, 32)) - + def test_conv3D(self): + bm.random.seed() + with bp.math.training_environment(): + model = bp.layers.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3)) + input = bp.math.ones((2, 5, 5, 5, 3)) + out = model(input) + print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 5, 32)) class TestConvTranspose1d(parameterized.TestCase): - def test_conv_transpose(self): - bm.random.seed() - x = bm.ones((1, 8, 3)) - for use_bias in [True, False]: - conv_transpose_module = bp.layers.ConvTranspose1d( - in_channels=3, - out_channels=4, - kernel_size=(3,), - padding='VALID', - w_initializer=bp.init.OneInit(), - b_initializer=bp.init.OneInit() if use_bias else None, - mode=bm.training_mode - ) - self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4)) - y = conv_transpose_module(x) - print(y.shape) - correct_ans = jnp.array([[[4., 4., 4., 4.], - [7., 7., 7., 7.], - [10., 10., 10., 10.], - [10., 10., 
10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [10., 10., 10., 10.], - [7., 7., 7., 7.], - [4., 4., 4., 4.]]]) - if not use_bias: - correct_ans -= 1. - self.assertTrue(bm.allclose(y, correct_ans)) - + def test_conv_transpose(self): + bm.random.seed() + x = bm.ones((1, 8, 3)) + for use_bias in [True, False]: + conv_transpose_module = bp.layers.ConvTranspose1d( + in_channels=3, + out_channels=4, + kernel_size=(3,), + padding='VALID', + w_initializer=bp.init.OneInit(), + b_initializer=bp.init.OneInit() if use_bias else None, + mode=bm.training_mode + ) + self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4)) + y = conv_transpose_module(x) + print(y.shape) + correct_ans = jnp.array([[[4., 4., 4., 4.], + [7., 7., 7., 7.], + [10., 10., 10., 10.], + [10., 10., 10., 10.], + [10., 10., 10., 10.], + [10., 10., 10., 10.], + [10., 10., 10., 10.], + [10., 10., 10., 10.], + [7., 7., 7., 7.], + [4., 4., 4., 4.]]]) + if not use_bias: + correct_ans -= 1. + self.assertTrue(bm.allclose(y, correct_ans)) - def test_single_input_masked_conv_transpose(self): - bm.random.seed() - x = jnp.ones((1, 8, 3)) - m = jnp.tril(jnp.ones((3, 3, 4))) - conv_transpose_module = bp.layers.ConvTranspose1d( - in_channels=3, - out_channels=4, - kernel_size=(3,), - padding='VALID', - mask=m, - w_initializer=bp.init.OneInit(), - b_initializer=bp.init.OneInit(), - mode=bm.batching_mode - ) - self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4)) - y = conv_transpose_module(x) - print(y.shape) - correct_ans = jnp.array([[[4., 3., 2., 1.], - [7., 5., 3., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [10., 7., 4., 1.], - [7., 5., 3., 1.], - [4., 3., 2., 1.]]]) - self.assertTrue(bm.allclose(y, correct_ans)) - + def test_single_input_masked_conv_transpose(self): + bm.random.seed() + x = jnp.ones((1, 8, 3)) + m = jnp.tril(jnp.ones((3, 3, 4))) + conv_transpose_module = bp.layers.ConvTranspose1d( + 
in_channels=3, + out_channels=4, + kernel_size=(3,), + padding='VALID', + mask=m, + w_initializer=bp.init.OneInit(), + b_initializer=bp.init.OneInit(), + mode=bm.batching_mode + ) + self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4)) + y = conv_transpose_module(x) + print(y.shape) + correct_ans = jnp.array([[[4., 3., 2., 1.], + [7., 5., 3., 1.], + [10., 7., 4., 1.], + [10., 7., 4., 1.], + [10., 7., 4., 1.], + [10., 7., 4., 1.], + [10., 7., 4., 1.], + [10., 7., 4., 1.], + [7., 5., 3., 1.], + [4., 3., 2., 1.]]]) + self.assertTrue(bm.allclose(y, correct_ans)) - def test_computation_padding_same(self): - bm.random.seed() - data = jnp.ones([1, 3, 1]) - for use_bias in [True, False]: - net = bp.layers.ConvTranspose1d( - in_channels=1, - out_channels=1, - kernel_size=3, - stride=1, - padding="SAME", - w_initializer=bp.init.OneInit(), - b_initializer=bp.init.OneInit() if use_bias else None, - mode=bm.batching_mode - ) - out = net(data) - self.assertEqual(out.shape, (1, 3, 1)) - out = jnp.squeeze(out, axis=(0, 2)) - expected_out = bm.as_jax([2, 3, 2]) - if use_bias: - expected_out += 1 - self.assertTrue(bm.allclose(out, expected_out, rtol=1e-5)) - + def test_computation_padding_same(self): + bm.random.seed() + data = jnp.ones([1, 3, 1]) + for use_bias in [True, False]: + net = bp.layers.ConvTranspose1d( + in_channels=1, + out_channels=1, + kernel_size=3, + stride=1, + padding="SAME", + w_initializer=bp.init.OneInit(), + b_initializer=bp.init.OneInit() if use_bias else None, + mode=bm.batching_mode + ) + out = net(data) + self.assertEqual(out.shape, (1, 3, 1)) + out = jnp.squeeze(out, axis=(0, 2)) + expected_out = bm.as_jax([2, 3, 2]) + if use_bias: + expected_out += 1 + self.assertTrue(bm.allclose(out, expected_out, rtol=1e-5)) class TestConvTranspose2d(parameterized.TestCase): - def test_conv_transpose(self): - bm.random.seed() - x = bm.ones((1, 8, 8, 3)) - for use_bias in [True, False]: - conv_transpose_module = bp.layers.ConvTranspose2d( - in_channels=3, - 
out_channels=4, - kernel_size=(3, 3), - padding='VALID', - w_initializer=bp.init.OneInit(), - b_initializer=bp.init.OneInit() if use_bias else None, - mode=bm.training_mode - ) - self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4)) - y = conv_transpose_module(x) - print(y.shape) - + def test_conv_transpose(self): + bm.random.seed() + x = bm.ones((1, 8, 8, 3)) + for use_bias in [True, False]: + conv_transpose_module = bp.layers.ConvTranspose2d( + in_channels=3, + out_channels=4, + kernel_size=(3, 3), + padding='VALID', + w_initializer=bp.init.OneInit(), + b_initializer=bp.init.OneInit() if use_bias else None, + mode=bm.training_mode + ) + self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4)) + y = conv_transpose_module(x) + print(y.shape) - def test_single_input_masked_conv_transpose(self): - bm.random.seed() - x = jnp.ones((1, 8, 8, 3)) - m = jnp.tril(jnp.ones((3, 3, 3, 4))) - conv_transpose_module = bp.layers.ConvTranspose2d( - in_channels=3, - out_channels=4, - kernel_size=(3, 3), - padding='VALID', - mask=m, - w_initializer=bp.init.OneInit(), - mode=bm.training_mode - ) - y = conv_transpose_module(x) - print(y.shape) - + def test_single_input_masked_conv_transpose(self): + bm.random.seed() + x = jnp.ones((1, 8, 8, 3)) + m = jnp.tril(jnp.ones((3, 3, 3, 4))) + conv_transpose_module = bp.layers.ConvTranspose2d( + in_channels=3, + out_channels=4, + kernel_size=(3, 3), + padding='VALID', + mask=m, + w_initializer=bp.init.OneInit(), + mode=bm.training_mode + ) + y = conv_transpose_module(x) + print(y.shape) - def test_computation_padding_same(self): - bm.random.seed() - x = bm.ones((1, 8, 8, 3)) - for use_bias in [True, False]: - conv_transpose_module = bp.layers.ConvTranspose2d( - in_channels=3, - out_channels=4, - kernel_size=(3, 3), - stride=1, - padding='SAME', - w_initializer=bp.init.OneInit(), - b_initializer=bp.init.OneInit() if use_bias else None, - mode=bm.training_mode, - # mode=bm.nonbatching_mode, - ) - y = conv_transpose_module(x) - 
print(y.shape) - + def test_computation_padding_same(self): + bm.random.seed() + x = bm.ones((1, 8, 8, 3)) + for use_bias in [True, False]: + conv_transpose_module = bp.layers.ConvTranspose2d( + in_channels=3, + out_channels=4, + kernel_size=(3, 3), + stride=1, + padding='SAME', + w_initializer=bp.init.OneInit(), + b_initializer=bp.init.OneInit() if use_bias else None, + mode=bm.training_mode, + # mode=bm.nonbatching_mode, + ) + y = conv_transpose_module(x) + print(y.shape) class TestConvTranspose3d(parameterized.TestCase): - def test_conv_transpose(self): - bm.random.seed() - x = bm.ones((1, 8, 8, 8, 3)) - for use_bias in [True, False]: - conv_transpose_module = bp.layers.ConvTranspose3d( - in_channels=3, - out_channels=4, - kernel_size=(3, 3, 3), - padding='VALID', - w_initializer=bp.init.OneInit(), - b_initializer=bp.init.OneInit() if use_bias else None, - mode=bm.training_mode - ) - y = conv_transpose_module(x) - print(y.shape) - + def test_conv_transpose(self): + bm.random.seed() + x = bm.ones((1, 8, 8, 8, 3)) + for use_bias in [True, False]: + conv_transpose_module = bp.layers.ConvTranspose3d( + in_channels=3, + out_channels=4, + kernel_size=(3, 3, 3), + padding='VALID', + w_initializer=bp.init.OneInit(), + b_initializer=bp.init.OneInit() if use_bias else None, + mode=bm.training_mode + ) + y = conv_transpose_module(x) + print(y.shape) - def test_single_input_masked_conv_transpose(self): - bm.random.seed() - x = jnp.ones((1, 8, 8, 8, 3)) - m = jnp.tril(jnp.ones((3, 3, 3, 3, 4))) - conv_transpose_module = bp.layers.ConvTranspose3d( - in_channels=3, - out_channels=4, - kernel_size=(3, 3, 3), - padding='VALID', - mask=m, - w_initializer=bp.init.OneInit(), - mode=bm.training_mode - ) - y = conv_transpose_module(x) - print(y.shape) - + def test_single_input_masked_conv_transpose(self): + bm.random.seed() + x = jnp.ones((1, 8, 8, 8, 3)) + m = jnp.tril(jnp.ones((3, 3, 3, 3, 4))) + conv_transpose_module = bp.layers.ConvTranspose3d( + in_channels=3, + out_channels=4, 
+ kernel_size=(3, 3, 3), + padding='VALID', + mask=m, + w_initializer=bp.init.OneInit(), + mode=bm.training_mode + ) + y = conv_transpose_module(x) + print(y.shape) - def test_computation_padding_same(self): - bm.random.seed() - x = bm.ones((1, 8, 8, 8, 3)) - for use_bias in [True, False]: - conv_transpose_module = bp.layers.ConvTranspose3d( - in_channels=3, - out_channels=4, - kernel_size=(3, 3, 3), - stride=1, - padding='SAME', - w_initializer=bp.init.OneInit(), - b_initializer=bp.init.OneInit() if use_bias else None, - mode=bm.training_mode - ) - y = conv_transpose_module(x) - print(y.shape) - + def test_computation_padding_same(self): + bm.random.seed() + x = bm.ones((1, 8, 8, 8, 3)) + for use_bias in [True, False]: + conv_transpose_module = bp.layers.ConvTranspose3d( + in_channels=3, + out_channels=4, + kernel_size=(3, 3, 3), + stride=1, + padding='SAME', + w_initializer=bp.init.OneInit(), + b_initializer=bp.init.OneInit() if use_bias else None, + mode=bm.training_mode + ) + y = conv_transpose_module(x) + print(y.shape) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_flax.py b/brainpy/_src/dnn/tests/test_flax.py index b452d7829..4e8ef1034 100644 --- a/brainpy/_src/dnn/tests/test_flax.py +++ b/brainpy/_src/dnn/tests/test_flax.py @@ -33,12 +33,12 @@ class myRNN(nn.Module): - @nn.compact - def __call__(self, x): # x:(batch, time, features) - x = nn.RNN(cell)(x) # Use nn.RNN to unfold the recurrent cell - return x + @nn.compact + def __call__(self, x): # x:(batch, time, features) + x = nn.RNN(cell)(x) # Use nn.RNN to unfold the recurrent cell + return x def test_init(): - model = myRNN() - model.init(jax.random.PRNGKey(0), jnp.ones([1, 10, 1])) # batch,time,feature + model = myRNN() + model.init(jax.random.PRNGKey(0), jnp.ones([1, 10, 1])) # batch,time,feature diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 7e645224e..2cf2c6b23 100644 --- 
a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -9,37 +9,34 @@ class TestFunction(parameterized.TestCase): - def test_flatten_batching_mode(self): - bm.random.seed() - layer = bp.dnn.Flatten(mode=bm.BatchingMode()) - input = bm.random.randn(20, 10, 10, 6) + def test_flatten_batching_mode(self): + bm.random.seed() + layer = bp.dnn.Flatten(mode=bm.BatchingMode()) + input = bm.random.randn(20, 10, 10, 6) - output = layer.update(input) + output = layer.update(input) - expected_shape = (20, 600) - self.assertEqual(output.shape, expected_shape) - + expected_shape = (20, 600) + self.assertEqual(output.shape, expected_shape) - def test_flatten_non_batching_mode(self): - bm.random.seed() - layer = bp.dnn.Flatten(mode=bm.NonBatchingMode()) - input = bm.random.randn(10, 10, 6) + def test_flatten_non_batching_mode(self): + bm.random.seed() + layer = bp.dnn.Flatten(mode=bm.NonBatchingMode()) + input = bm.random.randn(10, 10, 6) - output = layer.update(input) + output = layer.update(input) - expected_shape = (600,) - self.assertEqual(output.shape, expected_shape) - + expected_shape = (600,) + self.assertEqual(output.shape, expected_shape) - def test_unflatten(self): - bm.random.seed() - layer = bp.dnn.Unflatten(1, (10, 6), mode=bm.NonBatchingMode()) - input = bm.random.randn(5, 60) - output = layer.update(input) - expected_shape = (5, 10, 6) - self.assertEqual(output.shape, expected_shape) - + def test_unflatten(self): + bm.random.seed() + layer = bp.dnn.Unflatten(1, (10, 6), mode=bm.NonBatchingMode()) + input = bm.random.randn(5, 60) + output = layer.update(input) + expected_shape = (5, 10, 6) + self.assertEqual(output.shape, expected_shape) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index f951d1002..b79890a7a 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,243 +1,228 @@ 
-import pytest +import jax.numpy as jnp from absl.testing import absltest from absl.testing import parameterized -import jax.numpy as jnp import brainpy as bp import brainpy.math as bm - class TestLinear(parameterized.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bm.random.seed() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - num_out=[20,] - ) - def test_Dense1(self, size, num_out): - bm.random.seed() - f = bp.dnn.Linear(10, num_out) - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size[:-1] + (num_out,)) - - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - ) - def test_Identity(self, size): - bm.random.seed() - f = bp.dnn.Identity() - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size) - - - def test_AllToAll1(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((8, 10)) - y = f(x) - expected = bm.sum(x, axis=1, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((10,)) - y = f(x) - expected = bm.sum(x, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - - def test_OneToOne(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((8, 10)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((10,)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - - @parameterized.product( - conn=[ - # bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def 
test_MaskedLinear(self, conn): - bm.random.seed() - bm.random.DEFAULT.seed(123) - f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_CSRLinear(self, conn): - bm.random.seed() - f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(jnp.asarray(x)) - self.assertTrue(y.shape == (16, 100)) - - x = bm.random.random((100,)) - y = f(jnp.asarray(x)) - self.assertTrue(y.shape == (100,)) - - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_EventCSRLinear(self, conn): - bm.random.seed() - f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(jnp.asarray(x)) - self.assertTrue(y.shape == (16, 100)) - x = bm.random.random((100,)) - y = f(jnp.asarray(x)) - self.assertTrue(y.shape == (100,)) - - - @parameterized.product( - prob=[0.1], - weight=[0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - - conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) - # print(conn_matrix.shape) - # self.assertTrue(conn_matrix.shape == (200, 100)) - - - @parameterized.product( - prob=[0.1], - w_low=[-0.01, ], - w_high=[0.01, ], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, 
seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - - conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) - - - @parameterized.product( - prob=[0.1], - w_mu=[-0.01], - w_sigma=[0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - - conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) - - - @parameterized.product( - prob=[0.1], - weight=[0.01,], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) - x = bm.random.random(shape + (100,)) < 0.1 - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - - conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) - - - @parameterized.product( - prob=[0.1], - w_low=[-0.01], - w_high=[0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - x = bm.random.random(shape + (100,)) < 0.1 - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - - conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) - - - @parameterized.product( - prob=[0.1], - w_mu=[-0.01], - w_sigma=[0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def 
test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - x = bm.random.random(shape + (100,)) < 0.1 - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - - conn_matrix = f.get_conn_matrix() - self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) - + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bm.random.seed() + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + num_out=[20, ] + ) + def test_Dense1(self, size, num_out): + bm.random.seed() + f = bp.dnn.Linear(10, num_out) + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size[:-1] + (num_out,)) + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + ) + def test_Identity(self, size): + bm.random.seed() + f = bp.dnn.Identity() + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size) + + def test_AllToAll1(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((8, 10)) + y = f(x) + expected = bm.sum(x, axis=1, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((10,)) + y = f(x) + expected = bm.sum(x, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + def test_OneToOne(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((8, 10)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((10,)) + y = f(x) + expected = x * 0.1 + 
self.assertTrue(bm.allclose(y, expected)) + + @parameterized.product( + conn=[ + # bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_MaskedLinear(self, conn): + bm.random.seed() + bm.random.DEFAULT.seed(123) + f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_CSRLinear(self, conn): + bm.random.seed() + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(jnp.asarray(x)) + self.assertTrue(y.shape == (16, 100)) + + x = bm.random.random((100,)) + y = f(jnp.asarray(x)) + self.assertTrue(y.shape == (100,)) + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_EventCSRLinear(self, conn): + bm.random.seed() + f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(jnp.asarray(x)) + self.assertTrue(y.shape == (16, 100)) + x = bm.random.random((100,)) + y = f(jnp.asarray(x)) + self.assertTrue(y.shape == (100,)) + + @parameterized.product( + prob=[0.1], + weight=[0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + # print(conn_matrix.shape) + # self.assertTrue(conn_matrix.shape == (200, 100)) + + @parameterized.product( + prob=[0.1], + w_low=[-0.01, ], 
+ w_high=[0.01, ], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + + @parameterized.product( + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + + @parameterized.product( + prob=[0.1], + weight=[0.01, ], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) < 0.1 + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + + @parameterized.product( + prob=[0.1], + w_low=[-0.01], + w_high=[0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) < 0.1 + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + 
self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) + + @parameterized.product( + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) < 0.1 + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + + conn_matrix = f.get_conn_matrix() + self.assertTrue(bm.allclose(y, x @ conn_matrix.T)) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index c6eff3df8..aa52a1baa 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -1,4 +1,3 @@ -import pytest from absl.testing import absltest from absl.testing import parameterized @@ -7,795 +6,783 @@ class Test_Conv(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode) - output = layer(input) - - - def test_Conv1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode) - output = layer(input) - - - def 
test_Conv2_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode) - output = layer(input) - - - def test_Conv3_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode - ) - output = layer(input) - - - def test_ConvTranspose1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode - ) - output = layer(input) - - - def test_ConvTranspose2d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - - - 
@parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode - ) - output = layer(input) - - - def test_ConvTranspose3d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode) + output = layer(input) + + def test_Conv1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode) + output = layer(input) + + def test_Conv2_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.Conv3d(in_channels=3, + out_channels=4, + 
kernel_size=(5, 5, 5), + mode=mode) + output = layer(input) + + def test_Conv3_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode + ) + output = layer(input) + + def test_ConvTranspose1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode + ) + output = layer(input) + + def test_ConvTranspose2d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=mode + ) + output = layer(input) + + def test_ConvTranspose3d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + 
out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) class TestPool(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MinPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AvgPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def 
test_AvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.MaxPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - 
@parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool2d(self, mode): - bm.random.seed() - input = 
bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MinPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), 
+ bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AvgPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.MaxPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in 
[bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + 
bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) class Test_Dropout(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Dropout(self, mode): - bp.share.save(fit=False) - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.Dropout(prob=0.2, - mode=mode) - output = layer(input) + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Dropout(self, mode): + bp.share.save(fit=False) + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.Dropout(prob=0.2, + mode=mode) + output = layer(input) class Test_function(parameterized.TestCase): - 
@parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Flatten(self, mode): - bm.random.seed() - layer = bp.dnn.Flatten(mode=mode) - input = bm.random.randn(10, 5, 5, 5, 4) - output = layer(input) + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Flatten(self, mode): + bm.random.seed() + layer = bp.dnn.Flatten(mode=mode) + input = bm.random.randn(10, 5, 5, 5, 4) + output = layer(input) class Test_linear(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_linear(self, mode): - bm.random.seed() - input = bm.random.randn(10, 9, 8, 7) - layer = bp.dnn.Linear(num_in=7, - num_out=6, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AllToAll(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.AllToAll(num_pre=10, - num_post=20, - weight=0.1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_OneToOne(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.OneToOne(num=10, - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaskedLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = 
bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_linear(self, mode): + bm.random.seed() + input = bm.random.randn(10, 9, 8, 7) + layer = bp.dnn.Linear(num_in=7, + num_out=6, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AllToAll(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.AllToAll(num_pre=10, + num_post=20, weight=0.1, mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_CSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventCSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - 
def test_JitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_OneToOne(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.OneToOne(num=10, + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaskedLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_CSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventCSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPNormalLinear(self, mode): - bm.random.seed() - layer = 
bp.dnn.JitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPNormalLinear(num_in=100, + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPUniformLinear(num_in=100, num_out=200, prob=0.1, - w_mu=-0.01, - w_sigma=0.01, + w_low=-0.01, + w_high=0.01, seed=100, mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - 
seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPUniformLinear(num_in=100, + num_out=200, + prob=0.1, + w_low=-0.01, + w_high=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) class Test_Normalization(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm1d(self, fit, mode): - 
bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm1d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm2d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm2d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm3d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm3d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 7, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_LayerNorm(self, mode): - bm.random.seed() - layer = bp.dnn.LayerNorm(normalized_shape=3, - mode=mode, - elementwise_affine=False - ) - input = bm.random.randn(10, 5, 3) - outout = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_GroupNorm(self, mode): - bm.random.seed() - layer = bp.dnn.GroupNorm(num_groups=2, - num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_InstanceNorm(self, mode): - bm.random.seed() - layer = bp.dnn.InstanceNorm(num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) + 
@parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm1d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm1d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm2d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm2d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm3d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm3d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 7, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_LayerNorm(self, mode): + bm.random.seed() + layer = bp.dnn.LayerNorm(normalized_shape=3, + mode=mode, + elementwise_affine=False + ) + input = bm.random.randn(10, 5, 3) + outout = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_GroupNorm(self, mode): + bm.random.seed() + layer = bp.dnn.GroupNorm(num_groups=2, + num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def 
test_InstanceNorm(self, mode): + bm.random.seed() + layer = bp.dnn.InstanceNorm(num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index b61c7d15c..585cb225a 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,71 +1,65 @@ -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm class Test_Normalization(parameterized.TestCase): - @parameterized.product( - fit=[True, False], - ) - def test_BatchNorm1d(self, fit): - bm.random.seed() - net = bp.dnn.BatchNorm1d(num_features=10, mode=bm.training_mode) - bp.share.save(fit=fit) - input = bm.random.randn(1, 3, 10) - output = net(input) - + @parameterized.product( + fit=[True, False], + ) + def test_BatchNorm1d(self, fit): + bm.random.seed() + net = bp.dnn.BatchNorm1d(num_features=10, mode=bm.training_mode) + bp.share.save(fit=fit) + input = bm.random.randn(1, 3, 10) + output = net(input) - @parameterized.product( - fit=[True, False] - ) - def test_BatchNorm2d(self, fit): - bm.random.seed() - net = bp.dnn.BatchNorm2d(10, mode=bm.training_mode) - bp.share.save(fit=fit) - input = bm.random.randn(1, 3, 4, 10) - output = net(input) - + @parameterized.product( + fit=[True, False] + ) + def test_BatchNorm2d(self, fit): + bm.random.seed() + net = bp.dnn.BatchNorm2d(10, mode=bm.training_mode) + bp.share.save(fit=fit) + input = bm.random.randn(1, 3, 4, 10) + output = net(input) - @parameterized.product( - fit=[True, False] - ) - def test_BatchNorm3d(self, fit): - bm.random.seed() - net = bp.dnn.BatchNorm3d(10, mode=bm.training_mode) - bp.share.save(fit=fit) - input = bm.random.randn(1, 3, 4, 5, 10) - output = net(input) - + 
@parameterized.product( + fit=[True, False] + ) + def test_BatchNorm3d(self, fit): + bm.random.seed() + net = bp.dnn.BatchNorm3d(10, mode=bm.training_mode) + bp.share.save(fit=fit) + input = bm.random.randn(1, 3, 4, 5, 10) + output = net(input) - @parameterized.product( - normalized_shape=(10, [5, 10]) - ) - def test_LayerNorm(self, normalized_shape): - bm.random.seed() - net = bp.dnn.LayerNorm(normalized_shape, mode=bm.training_mode) - input = bm.random.randn(20, 5, 10) - output = net(input) - + @parameterized.product( + normalized_shape=(10, [5, 10]) + ) + def test_LayerNorm(self, normalized_shape): + bm.random.seed() + net = bp.dnn.LayerNorm(normalized_shape, mode=bm.training_mode) + input = bm.random.randn(20, 5, 10) + output = net(input) - @parameterized.product( - num_groups=[1, 2, 3, 6] - ) - def test_GroupNorm(self, num_groups): - bm.random.seed() - input = bm.random.randn(20, 10, 10, 6) - net = bp.dnn.GroupNorm(num_groups=num_groups, num_channels=6, mode=bm.training_mode) - output = net(input) - + @parameterized.product( + num_groups=[1, 2, 3, 6] + ) + def test_GroupNorm(self, num_groups): + bm.random.seed() + input = bm.random.randn(20, 10, 10, 6) + net = bp.dnn.GroupNorm(num_groups=num_groups, num_channels=6, mode=bm.training_mode) + output = net(input) - def test_InstanceNorm(self): - bm.random.seed() - input = bm.random.randn(20, 10, 10, 6) - net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode) - output = net(input) - + def test_InstanceNorm(self): + bm.random.seed() + input = bm.random.randn(20, 10, 10, 6) + net = bp.dnn.InstanceNorm(num_channels=6, mode=bm.training_mode) + output = net(input) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 7febe71c7..56522a03b 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -11,241 +11,224 @@ class 
TestPool(parameterized.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def test_maxpool(self): - bm.random.seed() - x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) - print(jnp.arange(9).reshape(3, 3)) - print(x) - print(x.shape) - shared = {'fit': False} - with bm.training_environment(): - net = bp.dnn.MaxPool((2, 2), 1, channel_axis=-1) - y = net(shared, x) - print("out shape: ", y.shape) - expected_y = jnp.array([[4., 5.], - [7., 8.]]).reshape((1, 2, 2, 1)) - np.testing.assert_allclose(y, expected_y) - - - def test_maxpool2(self): - bm.random.seed() - x = bm.random.rand(10, 20, 20, 4) - with bm.training_environment(): - net = bp.dnn.MaxPool((2, 2), (2, 2), channel_axis=-1) - y = net(x) - print("out shape: ", y.shape) - - - def test_minpool(self): - bm.random.seed() - x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) - shared = {'fit': False} - with bm.training_environment(): - net = bp.dnn.MinPool((2, 2), 1, channel_axis=-1) - y = net(shared, x) - print("out shape: ", y.shape) - expected_y = jnp.array([ - [0., 1.], - [3., 4.], - ]).reshape((1, 2, 2, 1)) - np.testing.assert_allclose(y, expected_y) - - - def test_avgpool(self): - bm.random.seed() - x = jnp.full((1, 3, 3, 1), 2.) 
- with bm.training_environment(): - net = bp.dnn.AvgPool((2, 2), 1, channel_axis=-1) - y = net(x) - print("out shape: ", y.shape) - np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.)) - - - def test_MaxPool2d_v1(self): - bm.random.seed() - arr = bm.random.rand(16, 32, 32, 8) - - out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1)(arr) - self.assertTrue(out.shape == (16, 16, 16, 8)) - - out = bp.dnn.MaxPool2d(2, 2, channel_axis=None)(arr) - self.assertTrue(out.shape == (16, 32, 16, 4)) - - out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr) - self.assertTrue(out.shape == (16, 32, 17, 5)) - - out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) - self.assertTrue(out.shape == (16, 32, 18, 5)) - - out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) - self.assertTrue(out.shape == (16, 17, 17, 8)) - - out = bp.dnn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) - self.assertTrue(out.shape == (16, 17, 32, 5)) - - - def test_AvgPool2d_v1(self): - bm.random.seed() - arr = bm.random.rand(16, 32, 32, 8) - - out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1)(arr) - self.assertTrue(out.shape == (16, 16, 16, 8)) - - out = bp.dnn.AvgPool2d(2, 2, channel_axis=None)(arr) - self.assertTrue(out.shape == (16, 32, 16, 4)) - - out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr) - self.assertTrue(out.shape == (16, 32, 17, 5)) - - out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) - self.assertTrue(out.shape == (16, 32, 18, 5)) - - out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) - self.assertTrue(out.shape == (16, 17, 17, 8)) - - out = bp.dnn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) - self.assertTrue(out.shape == (16, 17, 32, 5)) - - - @parameterized.named_parameters( - dict(testcase_name=f'target_size={target_size}', - target_size=target_size) - for target_size in [10, 9, 8, 7, 6] - ) - def test_adaptive_pool1d(self, target_size): - bm.random.seed() - from brainpy._src.dnn.pooling 
import _adaptive_pool1d - - arr = bm.random.rand(100) - op = jax.numpy.mean - - out = _adaptive_pool1d(arr, target_size, op) - print(out.shape) - self.assertTrue(out.shape == (target_size,)) - - out = _adaptive_pool1d(arr, target_size, op) - print(out.shape) - self.assertTrue(out.shape == (target_size,)) - - - def test_AdaptiveAvgPool2d_v1(self): - bm.random.seed() - input = bm.random.randn(64, 8, 9) - - output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) - self.assertTrue(output.shape == (64, 5, 7)) - - output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input) - self.assertTrue(output.shape == (64, 2, 3)) - - output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) - self.assertTrue(output.shape == (2, 3, 9)) - - output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) - self.assertTrue(output.shape == (2, 8, 3)) - - output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input) - self.assertTrue(output.shape == (64, 2, 3)) - - - def test_AdaptiveAvgPool2d_v2(self): - bm.random.seed() - input = bm.random.randn(128, 64, 32, 16) - - output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) - self.assertTrue(output.shape == (128, 64, 5, 7)) - - output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input) - self.assertTrue(output.shape == (128, 64, 2, 3)) - - output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) - self.assertTrue(output.shape == (128, 2, 3, 16)) - - output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) - self.assertTrue(output.shape == (128, 64, 2, 3)) - print() - - - def test_AdaptiveAvgPool3d_v1(self): - bm.random.seed() - input = bm.random.randn(10, 128, 64, 32) - net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], channel_axis=0, mode=bm.nonbatching_mode) - output = net(input) - self.assertTrue(output.shape == (10, 6, 5, 3)) - - - def test_AdaptiveAvgPool3d_v2(self): - bm.random.seed() - input = bm.random.randn(10, 20, 128, 64, 32) - net = 
bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], mode=bm.batching_mode) - output = net(input) - self.assertTrue(output.shape == (10, 6, 5, 3, 32)) - - - @parameterized.product( - axis=(-1, 0, 1) - ) - def test_AdaptiveMaxPool1d_v1(self, axis): - bm.random.seed() - input = bm.random.randn(32, 16) - net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis) - output = net(input) - - - @parameterized.product( - axis=(-1, 0, 1, 2) - ) - def test_AdaptiveMaxPool1d_v2(self, axis): - bm.random.seed() - input = bm.random.randn(2, 32, 16) - net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis) - output = net(input) - - - @parameterized.product( - axis=(-1, 0, 1, 2) - ) - def test_AdaptiveMaxPool2d_v1(self, axis): - bm.random.seed() - input = bm.random.randn(32, 16, 12) - net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis) - output = net(input) - - - @parameterized.product( - axis=(-1, 0, 1, 2, 3) - ) - def test_AdaptiveMaxPool2d_v2(self, axis): - bm.random.seed() - input = bm.random.randn(2, 32, 16, 12) - net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis) - # output = net(input) - - - @parameterized.product( - axis=(-1, 0, 1, 2, 3) - ) - def test_AdaptiveMaxPool3d_v1(self, axis): - bm.random.seed() - input = bm.random.randn(2, 128, 64, 32) - net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis) - output = net(input) - print() - - - @parameterized.product( - axis=(-1, 0, 1, 2, 3, 4) - ) - def test_AdaptiveMaxPool3d_v1(self, axis): - bm.random.seed() - input = bm.random.randn(2, 128, 64, 32, 16) - net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis) - output = net(input) - + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def test_maxpool(self): + bm.random.seed() + x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) + print(jnp.arange(9).reshape(3, 3)) + print(x) + print(x.shape) + shared = {'fit': False} + with bm.training_environment(): + 
net = bp.dnn.MaxPool((2, 2), 1, channel_axis=-1) + y = net(shared, x) + print("out shape: ", y.shape) + expected_y = jnp.array([[4., 5.], + [7., 8.]]).reshape((1, 2, 2, 1)) + np.testing.assert_allclose(y, expected_y) + + def test_maxpool2(self): + bm.random.seed() + x = bm.random.rand(10, 20, 20, 4) + with bm.training_environment(): + net = bp.dnn.MaxPool((2, 2), (2, 2), channel_axis=-1) + y = net(x) + print("out shape: ", y.shape) + + def test_minpool(self): + bm.random.seed() + x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32) + shared = {'fit': False} + with bm.training_environment(): + net = bp.dnn.MinPool((2, 2), 1, channel_axis=-1) + y = net(shared, x) + print("out shape: ", y.shape) + expected_y = jnp.array([ + [0., 1.], + [3., 4.], + ]).reshape((1, 2, 2, 1)) + np.testing.assert_allclose(y, expected_y) + + def test_avgpool(self): + bm.random.seed() + x = jnp.full((1, 3, 3, 1), 2.) + with bm.training_environment(): + net = bp.dnn.AvgPool((2, 2), 1, channel_axis=-1) + y = net(x) + print("out shape: ", y.shape) + np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.)) + + def test_MaxPool2d_v1(self): + bm.random.seed() + arr = bm.random.rand(16, 32, 32, 8) + + out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1)(arr) + self.assertTrue(out.shape == (16, 16, 16, 8)) + + out = bp.dnn.MaxPool2d(2, 2, channel_axis=None)(arr) + self.assertTrue(out.shape == (16, 32, 16, 4)) + + out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr) + self.assertTrue(out.shape == (16, 32, 17, 5)) + + out = bp.dnn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) + self.assertTrue(out.shape == (16, 32, 18, 5)) + + out = bp.dnn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) + self.assertTrue(out.shape == (16, 17, 17, 8)) + + out = bp.dnn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) + self.assertTrue(out.shape == (16, 17, 32, 5)) + + def test_AvgPool2d_v1(self): + bm.random.seed() + arr = bm.random.rand(16, 32, 32, 8) + + out = bp.dnn.AvgPool2d(2, 2, 
channel_axis=-1)(arr) + self.assertTrue(out.shape == (16, 16, 16, 8)) + + out = bp.dnn.AvgPool2d(2, 2, channel_axis=None)(arr) + self.assertTrue(out.shape == (16, 32, 16, 4)) + + out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr) + self.assertTrue(out.shape == (16, 32, 17, 5)) + + out = bp.dnn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) + self.assertTrue(out.shape == (16, 32, 18, 5)) + + out = bp.dnn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) + self.assertTrue(out.shape == (16, 17, 17, 8)) + + out = bp.dnn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) + self.assertTrue(out.shape == (16, 17, 32, 5)) + + @parameterized.named_parameters( + dict(testcase_name=f'target_size={target_size}', + target_size=target_size) + for target_size in [10, 9, 8, 7, 6] + ) + def test_adaptive_pool1d(self, target_size): + bm.random.seed() + from brainpy._src.dnn.pooling import _adaptive_pool1d + + arr = bm.random.rand(100) + op = jax.numpy.mean + + out = _adaptive_pool1d(arr, target_size, op) + print(out.shape) + self.assertTrue(out.shape == (target_size,)) + + out = _adaptive_pool1d(arr, target_size, op) + print(out.shape) + self.assertTrue(out.shape == (target_size,)) + + def test_AdaptiveAvgPool2d_v1(self): + bm.random.seed() + input = bm.random.randn(64, 8, 9) + + output = bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) + self.assertTrue(output.shape == (64, 5, 7)) + + output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input) + self.assertTrue(output.shape == (64, 2, 3)) + + output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) + self.assertTrue(output.shape == (2, 3, 9)) + + output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) + self.assertTrue(output.shape == (2, 8, 3)) + + output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input) + self.assertTrue(output.shape == (64, 2, 3)) + + def test_AdaptiveAvgPool2d_v2(self): + bm.random.seed() + input = bm.random.randn(128, 64, 32, 16) + + output = 
bp.dnn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) + self.assertTrue(output.shape == (128, 64, 5, 7)) + + output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input) + self.assertTrue(output.shape == (128, 64, 2, 3)) + + output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) + self.assertTrue(output.shape == (128, 2, 3, 16)) + + output = bp.dnn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) + self.assertTrue(output.shape == (128, 64, 2, 3)) + print() + + def test_AdaptiveAvgPool3d_v1(self): + bm.random.seed() + input = bm.random.randn(10, 128, 64, 32) + net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], channel_axis=0, mode=bm.nonbatching_mode) + output = net(input) + self.assertTrue(output.shape == (10, 6, 5, 3)) + + def test_AdaptiveAvgPool3d_v2(self): + bm.random.seed() + input = bm.random.randn(10, 20, 128, 64, 32) + net = bp.dnn.AdaptiveAvgPool3d(target_shape=[6, 5, 3], mode=bm.batching_mode) + output = net(input) + self.assertTrue(output.shape == (10, 6, 5, 3, 32)) + + @parameterized.product( + axis=(-1, 0, 1) + ) + def test_AdaptiveMaxPool1d_v1(self, axis): + bm.random.seed() + input = bm.random.randn(32, 16) + net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis) + output = net(input) + + @parameterized.product( + axis=(-1, 0, 1, 2) + ) + def test_AdaptiveMaxPool1d_v2(self, axis): + bm.random.seed() + input = bm.random.randn(2, 32, 16) + net = bp.dnn.AdaptiveMaxPool1d(target_shape=4, channel_axis=axis) + output = net(input) + + @parameterized.product( + axis=(-1, 0, 1, 2) + ) + def test_AdaptiveMaxPool2d_v1(self, axis): + bm.random.seed() + input = bm.random.randn(32, 16, 12) + net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis) + output = net(input) + + @parameterized.product( + axis=(-1, 0, 1, 2, 3) + ) + def test_AdaptiveMaxPool2d_v2(self, axis): + bm.random.seed() + input = bm.random.randn(2, 32, 16, 12) + net = bp.dnn.AdaptiveAvgPool2d(target_shape=[5, 4], channel_axis=axis) + # output = 
net(input) + + @parameterized.product( + axis=(-1, 0, 1, 2, 3) + ) + def test_AdaptiveMaxPool3d_v1(self, axis): + bm.random.seed() + input = bm.random.randn(2, 128, 64, 32) + net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis) + output = net(input) + print() + + @parameterized.product( + axis=(-1, 0, 1, 2, 3, 4) + ) + def test_AdaptiveMaxPool3d_v1(self, axis): + bm.random.seed() + input = bm.random.randn(2, 128, 64, 32, 16) + net = bp.dnn.AdaptiveMaxPool3d(target_shape=[6, 5, 4], channel_axis=axis) + output = net(input) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dyn/_docs.py b/brainpy/_src/dyn/_docs.py index d528d4266..cd0c56baf 100644 --- a/brainpy/_src/dyn/_docs.py +++ b/brainpy/_src/dyn/_docs.py @@ -37,10 +37,8 @@ V_initializer: ArrayType, callable. The initializer of membrane potential. '''.strip() - ltc_doc = 'with liquid time-constant' - dual_exp_syn_doc = r''' **Model Descriptions** @@ -88,7 +86,6 @@ ''' - alpha_syn_doc = r''' **Model Descriptions** @@ -116,7 +113,6 @@ ''' - exp_syn_doc = r''' **Model Descriptions** @@ -148,7 +144,6 @@ ''' - std_doc = r''' This model filters the synaptic current by the following equation: @@ -172,7 +167,6 @@ ''' - stp_doc = r''' This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. 
@@ -202,4 +196,3 @@ ''' - diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index e18ac2a82..13439f1ec 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -4,21 +4,20 @@ from brainpy._src.mixin import SupportAutoDelay, ParamDesc __all__ = [ - 'NeuDyn', 'SynDyn', 'IonChaDyn', + 'NeuDyn', 'SynDyn', 'IonChaDyn', ] class NeuDyn(Dynamic, SupportAutoDelay): - """Neuronal Dynamics.""" - pass + """Neuronal Dynamics.""" + pass class SynDyn(Dynamic, SupportAutoDelay, ParamDesc): - """Synaptic Dynamics.""" - pass + """Synaptic Dynamics.""" + pass class IonChaDyn(Dynamic): - """Ion Channel Dynamics.""" - pass - + """Ion Channel Dynamics.""" + pass diff --git a/brainpy/_src/dyn/channels/__init__.py b/brainpy/_src/dyn/channels/__init__.py index 4d43a4d2a..b8165f5f3 100644 --- a/brainpy/_src/dyn/channels/__init__.py +++ b/brainpy/_src/dyn/channels/__init__.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- from .base import * -from .potassium import * -from .sodium import * -from .hyperpolarization_activated import * from .calcium import * -from .potassium_calcium import * +from .hyperpolarization_activated import * from .leaky import * +from .potassium import * +from .potassium_calcium import * +from .sodium import * diff --git a/brainpy/_src/dyn/channels/base.py b/brainpy/_src/dyn/channels/base.py index b933930a0..12d5b4dbb 100644 --- a/brainpy/_src/dyn/channels/base.py +++ b/brainpy/_src/dyn/channels/base.py @@ -1,31 +1,31 @@ # -*- coding: utf-8 -*- from brainpy._src.dyn.base import IonChaDyn -from brainpy._src.mixin import TreeNode from brainpy._src.dyn.neurons.hh import HHTypedNeuron +from brainpy._src.mixin import TreeNode __all__ = [ - 'IonChannel', + 'IonChannel', ] class IonChannel(IonChaDyn, TreeNode): - """Base class for ion channels.""" + """Base class for ion channels.""" - '''The type of the master object.''' - master_type = HHTypedNeuron + '''The type of the master object.''' + master_type = HHTypedNeuron - def update(self, *args, 
**kwargs): - raise NotImplementedError('Must be implemented by the subclass.') + def update(self, *args, **kwargs): + raise NotImplementedError('Must be implemented by the subclass.') - def current(self, *args, **kwargs): - raise NotImplementedError('Must be implemented by the subclass.') + def current(self, *args, **kwargs): + raise NotImplementedError('Must be implemented by the subclass.') - def reset_state(self, *args, **kwargs): - raise NotImplementedError('Must be implemented by the subclass.') + def reset_state(self, *args, **kwargs): + raise NotImplementedError('Must be implemented by the subclass.') - def clear_input(self): - pass + def clear_input(self): + pass - def __repr__(self): - return f'{self.name}(size={self.size})' + def __repr__(self): + return f'{self.name}(size={self.size})' diff --git a/brainpy/_src/dyn/channels/calcium.py b/brainpy/_src/dyn/channels/calcium.py index 988952182..1e7508a70 100644 --- a/brainpy/_src/dyn/channels/calcium.py +++ b/brainpy/_src/dyn/channels/calcium.py @@ -17,778 +17,778 @@ from .base import IonChannel __all__ = [ - 'CalciumChannel', + 'CalciumChannel', - 'ICaN_IS2008', - 'ICaT_HM1992', - 'ICaT_HP1992', - 'ICaHT_HM1992', - 'ICaL_IS2008', + 'ICaN_IS2008', + 'ICaT_HM1992', + 'ICaT_HP1992', + 'ICaHT_HM1992', + 'ICaL_IS2008', ] class CalciumChannel(IonChannel): - """Base class for Calcium ion channels.""" + """Base class for Calcium ion channels.""" - master_type = Calcium - '''The type of the master object.''' + master_type = Calcium + '''The type of the master object.''' - def update(self, V, C, E): - raise NotImplementedError + def update(self, V, C, E): + raise NotImplementedError - def current(self, V, C, E): - raise NotImplementedError + def current(self, V, C, E): + raise NotImplementedError - def reset(self, V, C, E, batch_size: int = None): - self.reset_state(V, C, E, batch_size) + def reset(self, V, C, E, batch_size: int = None): + self.reset_state(V, C, E, batch_size) - def reset_state(self, V, C, E, 
batch_size: int = None): - raise NotImplementedError('Must be implemented by the subclass.') + def reset_state(self, V, C, E, batch_size: int = None): + raise NotImplementedError('Must be implemented by the subclass.') class _ICa_p2q_ss(CalciumChannel): - r"""The calcium current model of :math:`p^2q` current which described with steady-state format. - - The dynamics of this generalized calcium current model is given by: - - .. math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters:: - - size: int, tuple of int - The size of the simulation target. - keep_size: bool - Keep size or flatten the size? - method: str - The numerical method - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - phi_p : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. 
- - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - phi_p: Union[float, ArrayType, Initializer, Callable] = 3., - phi_q: Union[float, ArrayType, Initializer, Callable] = 3., - g_max: Union[float, ArrayType, Initializer, Callable] = 2., - method: str = 'exp_auto', - mode: Optional[bm.Mode] = None, - name: Optional[str] = None - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode, ) - - # parameters - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # functions - self.integral = odeint(JointEq([self.dp, self.dq]), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - - def update(self, V, C, E): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) - - def current(self, V, C, E): - return self.g_max * self.p * self.p * self.q * (E - V) - - def reset_state(self, V, C, E, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The calcium current model of :math:`p^2q` current which described with steady-state format. + + The dynamics of this generalized calcium current model is given by: + + .. 
math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters:: + + size: int, tuple of int + The size of the simulation target. + keep_size: bool + Keep size or flatten the size? + method: str + The numerical method + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + phi_p : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + phi_p: Union[float, ArrayType, Initializer, Callable] = 3., + phi_q: Union[float, ArrayType, Initializer, Callable] = 3., + g_max: Union[float, ArrayType, Initializer, Callable] = 2., + method: str = 'exp_auto', + mode: Optional[bm.Mode] = None, + name: Optional[str] = None + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode, ) + + # parameters + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # functions + self.integral = odeint(JointEq([self.dp, self.dq]), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) + + def update(self, V, C, E): + self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V, C, 
E): + return self.g_max * self.p * self.p * self.q * (E - V) + + def reset_state(self, V, C, E, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class _ICa_p2q_markov(CalciumChannel): - r"""The calcium current model of :math:`p^2q` current which described with first-order Markov chain. - - The dynamics of this generalized calcium current model is given by: - - .. math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= \phi_p (\alpha_p(V)(1-p) - \beta_p(V)p) \\ - {dq \over dt} &= \phi_q (\alpha_q(V)(1-q) - \beta_q(V)q) \\ - - where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters:: - - size: int, tuple of int - The size of the simulation target. - keep_size: bool - Keep size or flatten the size? - method: str - The numerical method - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - phi_p : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. 
- - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - phi_p: Union[float, ArrayType, Initializer, Callable] = 3., - phi_q: Union[float, ArrayType, Initializer, Callable] = 3., - g_max: Union[float, ArrayType, Initializer, Callable] = 2., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # functions - self.integral = odeint(JointEq([self.dp, self.dq]), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_alpha(V) * (1 - p) - self.f_p_beta(V) * p) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q) - - def update(self, V, C, E): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) - - def current(self, V, C, E): - return self.g_max * self.p * self.p * self.q * (E - V) - - def reset_state(self, V, C, E, batch_size=None): - alpha, beta = self.f_p_alpha(V), self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - alpha, beta = self.f_q_alpha(V), self.f_q_beta(V) - self.q.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError - - def f_q_alpha(self, V): - raise NotImplementedError - - def f_q_beta(self, V): - raise NotImplementedError + r"""The calcium current model of :math:`p^2q` current which described with first-order Markov chain. 
+ + The dynamics of this generalized calcium current model is given by: + + .. math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= \phi_p (\alpha_p(V)(1-p) - \beta_p(V)p) \\ + {dq \over dt} &= \phi_q (\alpha_q(V)(1-q) - \beta_q(V)q) \\ + + where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters:: + + size: int, tuple of int + The size of the simulation target. + keep_size: bool + Keep size or flatten the size? + method: str + The numerical method + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + phi_p : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + phi_p: Union[float, ArrayType, Initializer, Callable] = 3., + phi_q: Union[float, ArrayType, Initializer, Callable] = 3., + g_max: Union[float, ArrayType, Initializer, Callable] = 2., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # functions + self.integral = odeint(JointEq([self.dp, self.dq]), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_alpha(V) * (1 - p) - self.f_p_beta(V) * p) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_alpha(V) * (1 - q) - self.f_q_beta(V) * q) + + def update(self, V, C, E): + self.p.value, self.q.value = 
self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V, C, E): + return self.g_max * self.p * self.p * self.q * (E - V) + + def reset_state(self, V, C, E, batch_size=None): + alpha, beta = self.f_p_alpha(V), self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + alpha, beta = self.f_q_alpha(V), self.f_q_beta(V) + self.q.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError + + def f_q_alpha(self, V): + raise NotImplementedError + + def f_q_beta(self, V): + raise NotImplementedError class ICaN_IS2008(CalciumChannel): - r"""The calcium-activated non-selective cation channel model - proposed by (Inoue & Strowbridge, 2008) [2]_. - - The dynamics of the calcium-activated non-selective cation channel model [1]_ [2]_ is given by: - - .. math:: - - \begin{aligned} - I_{CAN} &=g_{\mathrm{max}} M\left([Ca^{2+}]_{i}\right) p \left(V-E\right)\\ - &M\left([Ca^{2+}]_{i}\right) ={[Ca^{2+}]_{i} \over 0.2+[Ca^{2+}]_{i}} \\ - &{dp \over dt} = {\phi \cdot (p_{\infty}-p)\over \tau_p} \\ - &p_{\infty} = {1.0 \over 1 + \exp(-(V + 43) / 5.2)} \\ - &\tau_{p} = {2.7 \over \exp(-(V + 55) / 15) + \exp((V + 55) / 15)} + 1.6 - \end{aligned} - - where :math:`\phi` is the temperature factor. - - Parameters:: - - g_max : float - The maximal conductance density (:math:`mS/cm^2`). - E : float - The reversal potential (mV). - phi : float - The temperature factor. - - References:: - - .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated - thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. - .. [2] Inoue T, Strowbridge BW (2008) Transient activity induces a long-lasting - increase in the excitability of olfactory bulb interneurons. - J Neurophysiol 99: 187–199. 
- """ - - '''The type of the master object.''' - master_type = CalciumDyna - - def __init__( - self, - size: Shape, - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = 10., - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi = parameter(phi, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.derivative, method=method) - - def derivative(self, p, t, V): - phi_p = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2)) - p_inf = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) / 15.)) + 1.6 - return self.phi * (phi_p - p) / p_inf - - def update(self, V, C, E): - self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) - - def current(self, V, C, E): - M = C / (C + 0.2) - g = self.g_max * M * self.p - return g * (self.E - V) - - def reset_state(self, V, C, E, batch_size=None): - self.p.value = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2)) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size + r"""The calcium-activated non-selective cation channel model + proposed by (Inoue & Strowbridge, 2008) [2]_. + + The dynamics of the calcium-activated non-selective cation channel model [1]_ [2]_ is given by: + + .. 
math:: + + \begin{aligned} + I_{CAN} &=g_{\mathrm{max}} M\left([Ca^{2+}]_{i}\right) p \left(V-E\right)\\ + &M\left([Ca^{2+}]_{i}\right) ={[Ca^{2+}]_{i} \over 0.2+[Ca^{2+}]_{i}} \\ + &{dp \over dt} = {\phi \cdot (p_{\infty}-p)\over \tau_p} \\ + &p_{\infty} = {1.0 \over 1 + \exp(-(V + 43) / 5.2)} \\ + &\tau_{p} = {2.7 \over \exp(-(V + 55) / 15) + \exp((V + 55) / 15)} + 1.6 + \end{aligned} + + where :math:`\phi` is the temperature factor. + + Parameters:: + + g_max : float + The maximal conductance density (:math:`mS/cm^2`). + E : float + The reversal potential (mV). + phi : float + The temperature factor. + + References:: + + .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated + thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. + .. [2] Inoue T, Strowbridge BW (2008) Transient activity induces a long-lasting + increase in the excitability of olfactory bulb interneurons. + J Neurophysiol 99: 187–199. + """ + + '''The type of the master object.''' + master_type = CalciumDyna + + def __init__( + self, + size: Shape, + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = 10., + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.derivative, method=method) + + def derivative(self, p, t, V): + phi_p = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2)) + p_inf = 2.7 / (bm.exp(-(V + 55.) / 15.) + bm.exp((V + 55.) 
/ 15.)) + 1.6 + return self.phi * (phi_p - p) / p_inf + + def update(self, V, C, E): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + + def current(self, V, C, E): + M = C / (C + 0.2) + g = self.g_max * M * self.p + return g * (self.E - V) + + def reset_state(self, V, C, E, batch_size=None): + self.p.value = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2)) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size class ICaT_HM1992(_ICa_p2q_ss): - r"""The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. - - The dynamics of the low-threshold T-type calcium current model [1]_ is given by: - - .. math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - &p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]} \\ - &\tau_{p} = 0.612 + {1 \over \exp [-(V+132.-V_{sh}) / 16.7]+\exp [(V+16.8-V_{sh}) / 18.2]} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - &q_{\infty} = {1 \over 1+\exp [(V+83-V_{sh}) / 4]} \\ - & \begin{array}{l} \tau_{q} = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ - \tau_{q} = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} - - where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters:: - - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. 
- phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in - rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. - - See Also:: - - ICa_p2q_form - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[float, ArrayType] = 36., - T_base_p: Union[float, ArrayType] = 3.55, - T_base_q: Union[float, ArrayType] = 3., - g_max: Union[float, ArrayType, Initializer, Callable] = 2., - V_sh: Union[float, ArrayType, Initializer, Callable] = -3., - phi_p: Union[float, ArrayType, Initializer, Callable] = None, - phi_q: Union[float, ArrayType, Initializer, Callable] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p - phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) - self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1 + bm.exp(-(V + 59. - self.V_sh) / 6.2)) - - def f_p_tau(self, V): - return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + - bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.0)) - - def f_q_tau(self, V): - return bm.where(V >= (-80. + self.V_sh), - bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28., - bm.exp((V + 467. 
- self.V_sh) / 66.6)) + r"""The low-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. + + The dynamics of the low-threshold T-type calcium current model [1]_ is given by: + + .. math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + &p_{\infty} = {1 \over 1+\exp [-(V+59-V_{sh}) / 6.2]} \\ + &\tau_{p} = 0.612 + {1 \over \exp [-(V+132.-V_{sh}) / 16.7]+\exp [(V+16.8-V_{sh}) / 18.2]} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + &q_{\infty} = {1 \over 1+\exp [(V+83-V_{sh}) / 4]} \\ + & \begin{array}{l} \tau_{q} = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ + \tau_{q} = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} + + where :math:`\phi_p = 3.55^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters:: + + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in + rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. 
+ + See Also:: + + ICa_p2q_form + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[float, ArrayType] = 36., + T_base_p: Union[float, ArrayType] = 3.55, + T_base_q: Union[float, ArrayType] = 3., + g_max: Union[float, ArrayType, Initializer, Callable] = 2., + V_sh: Union[float, ArrayType, Initializer, Callable] = -3., + phi_p: Union[float, ArrayType, Initializer, Callable] = None, + phi_q: Union[float, ArrayType, Initializer, Callable] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p + phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1 + bm.exp(-(V + 59. - self.V_sh) / 6.2)) + + def f_p_tau(self, V): + return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + + bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.0)) + + def f_q_tau(self, V): + return bm.where(V >= (-80. + self.V_sh), + bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28., + bm.exp((V + 467. - self.V_sh) / 66.6)) class ICaT_HP1992(_ICa_p2q_ss): - r"""The low-threshold T-type calcium current model for thalamic - reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_. - - The dynamics of the low-threshold T-type calcium current model in thalamic - reticular nucleus neurons [1]_ is given by: - - .. 
math:: - - I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - &p_{\infty} = {1 \over 1+\exp [-(V+52-V_{sh}) / 7.4]} \\ - &\tau_{p} = 3+{1 \over \exp [(V+27-V_{sh}) / 10]+\exp [-(V+102-V_{sh}) / 15]} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - &q_{\infty} = {1 \over 1+\exp [(V+80-V_{sh}) / 5]} \\ - & \tau_q = 85+ {1 \over \exp [(V+48-V_{sh}) / 4]+\exp [-(V+407-V_{sh}) / 50]} - - where :math:`\phi_p = 5^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters:: - - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [1] Huguenard JR, Prince DA (1992) A novel T-type current underlies - prolonged Ca2+- dependent burst firing in GABAergic neurons of rat - thalamic reticular nucleus. J Neurosci 12: 3804–3817. 
- - See Also:: - - ICa_p2q_form - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[float, ArrayType] = 36., - T_base_p: Union[float, ArrayType] = 5., - T_base_q: Union[float, ArrayType] = 3., - g_max: Union[float, ArrayType, Initializer, Callable] = 1.75, - V_sh: Union[float, ArrayType, Initializer, Callable] = -3., - phi_p: Union[float, ArrayType, Initializer, Callable] = None, - phi_q: Union[float, ArrayType, Initializer, Callable] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p - phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) - self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4)) - - def f_p_tau(self, V): - return 3. + 1. / (bm.exp((V + 27. - self.V_sh) / 10.) + - bm.exp(-(V + 102. - self.V_sh) / 15.)) - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V + 80. - self.V_sh) / 5.)) - - def f_q_tau(self, V): - return 85. + 1. / (bm.exp((V + 48. - self.V_sh) / 4.) + - bm.exp(-(V + 407. - self.V_sh) / 50.)) + r"""The low-threshold T-type calcium current model for thalamic + reticular nucleus proposed by (Huguenard & Prince, 1992) [1]_. + + The dynamics of the low-threshold T-type calcium current model in thalamic + reticular nucleus neurons [1]_ is given by: + + .. 
math:: + + I_{CaT} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + &p_{\infty} = {1 \over 1+\exp [-(V+52-V_{sh}) / 7.4]} \\ + &\tau_{p} = 3+{1 \over \exp [(V+27-V_{sh}) / 10]+\exp [-(V+102-V_{sh}) / 15]} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + &q_{\infty} = {1 \over 1+\exp [(V+80-V_{sh}) / 5]} \\ + & \tau_q = 85+ {1 \over \exp [(V+48-V_{sh}) / 4]+\exp [-(V+407-V_{sh}) / 50]} + + where :math:`\phi_p = 5^{\frac{T-24}{10}}` and :math:`\phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters:: + + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [1] Huguenard JR, Prince DA (1992) A novel T-type current underlies + prolonged Ca2+- dependent burst firing in GABAergic neurons of rat + thalamic reticular nucleus. J Neurosci 12: 3804–3817. 
+ + See Also:: + + ICa_p2q_form + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[float, ArrayType] = 36., + T_base_p: Union[float, ArrayType] = 5., + T_base_q: Union[float, ArrayType] = 3., + g_max: Union[float, ArrayType, Initializer, Callable] = 1.75, + V_sh: Union[float, ArrayType, Initializer, Callable] = -3., + phi_p: Union[float, ArrayType, Initializer, Callable] = None, + phi_q: Union[float, ArrayType, Initializer, Callable] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + phi_p = T_base_p ** ((T - 24) / 10) if phi_p is None else phi_p + phi_q = T_base_q ** ((T - 24) / 10) if phi_q is None else phi_q + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4)) + + def f_p_tau(self, V): + return 3. + 1. / (bm.exp((V + 27. - self.V_sh) / 10.) + + bm.exp(-(V + 102. - self.V_sh) / 15.)) + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V + 80. - self.V_sh) / 5.)) + + def f_q_tau(self, V): + return 85. + 1. / (bm.exp((V + 48. - self.V_sh) / 4.) + + bm.exp(-(V + 407. - self.V_sh) / 50.)) class ICaHT_HM1992(_ICa_p2q_ss): - r"""The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. - - The high-threshold T-type calcium current model is adopted from [1]_. - Its dynamics is given by - - .. 
math:: - - \begin{aligned} - I_{\mathrm{Ca/HT}} &= g_{\mathrm{max}} p^2 q (V-E_{Ca}) - \\ - {dp \over dt} &= {\phi_{p} \cdot (p_{\infty} - p) \over \tau_{p}} \\ - &\tau_{p} =\frac{1}{\exp \left(\frac{V+132-V_{sh}}{-16.7}\right)+\exp \left(\frac{V+16.8-V_{sh}}{18.2}\right)}+0.612 \\ - & p_{\infty} = {1 \over 1+exp[-(V+59-V_{sh}) / 6.2]} - \\ - {dq \over dt} &= {\phi_{q} \cdot (q_{\infty} - h) \over \tau_{q}} \\ - & \begin{array}{l} \tau_q = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ - \tau_q = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} \\ - &q_{\infty} = {1 \over 1+exp[(V+83 -V_{shift})/4]} - \end{aligned} - - where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters:: - - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - g_max : float, ArrayType, Initializer, Callable - The maximum conductance. - V_sh : float, ArrayType, Initializer, Callable - The membrane potential shift. - - References:: - - .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in - rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. 
- - See Also:: - - ICa_p2q_form - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[float, ArrayType] = 36., - T_base_p: Union[float, ArrayType] = 3.55, - T_base_q: Union[float, ArrayType] = 3., - g_max: Union[float, ArrayType, Initializer, Callable] = 2., - V_sh: Union[float, ArrayType, Initializer, Callable] = 25., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi_p=T_base_p ** ((T - 24) / 10), - phi_q=T_base_q ** ((T - 24) / 10), - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) - self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq([self.dp, self.dq]), method=method) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2)) - - def f_p_tau(self, V): - return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + - bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.)) - - def f_q_tau(self, V): - return bm.where(V >= (-80. + self.V_sh), - bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28., - bm.exp((V + 467. - self.V_sh) / 66.6)) + r"""The high-threshold T-type calcium current model proposed by (Huguenard & McCormick, 1992) [1]_. + + The high-threshold T-type calcium current model is adopted from [1]_. + Its dynamics is given by + + .. 
math:: + + \begin{aligned} + I_{\mathrm{Ca/HT}} &= g_{\mathrm{max}} p^2 q (V-E_{Ca}) + \\ + {dp \over dt} &= {\phi_{p} \cdot (p_{\infty} - p) \over \tau_{p}} \\ + &\tau_{p} =\frac{1}{\exp \left(\frac{V+132-V_{sh}}{-16.7}\right)+\exp \left(\frac{V+16.8-V_{sh}}{18.2}\right)}+0.612 \\ + & p_{\infty} = {1 \over 1+exp[-(V+59-V_{sh}) / 6.2]} + \\ + {dq \over dt} &= {\phi_{q} \cdot (q_{\infty} - h) \over \tau_{q}} \\ + & \begin{array}{l} \tau_q = \exp \left(\frac{V+467-V_{sh}}{66.6}\right) \quad V< (-80 +V_{sh})\, mV \\ + \tau_q = \exp \left(\frac{V+22-V_{sh}}{-10.5}\right)+28 \quad V \geq (-80 + V_{sh})\, mV \end{array} \\ + &q_{\infty} = {1 \over 1+exp[(V+83 -V_{shift})/4]} + \end{aligned} + + where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters:: + + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + g_max : float, ArrayType, Initializer, Callable + The maximum conductance. + V_sh : float, ArrayType, Initializer, Callable + The membrane potential shift. + + References:: + + .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in + rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. 
+ + See Also:: + + ICa_p2q_form + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[float, ArrayType] = 36., + T_base_p: Union[float, ArrayType] = 3.55, + T_base_q: Union[float, ArrayType] = 3., + g_max: Union[float, ArrayType, Initializer, Callable] = 2., + V_sh: Union[float, ArrayType, Initializer, Callable] = 25., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi_p=T_base_p ** ((T - 24) / 10), + phi_q=T_base_q ** ((T - 24) / 10), + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq([self.dp, self.dq]), method=method) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2)) + + def f_p_tau(self, V): + return 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) + + bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.)) + + def f_q_tau(self, V): + return bm.where(V >= (-80. + self.V_sh), + bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28., + bm.exp((V + 467. - self.V_sh) / 66.6)) class ICaHT_Re1993(_ICa_p2q_markov): - r"""The high-threshold T-type calcium current model proposed by (Reuveni, et al., 1993) [1]_. - - HVA Calcium current was described for neocortical neurons by Sayer et al. (1990). - Its dynamics is given by (the rate functions are measured under 36 Celsius): - - .. 
math:: - - \begin{aligned} - I_{L} &=\bar{g}_{L} q^{2} r\left(V-E_{\mathrm{Ca}}\right) \\ - \frac{\mathrm{d} q}{\mathrm{~d} t} &= \phi_p (\alpha_{q}(V)(1-q)-\beta_{q}(V) q) \\ - \frac{\mathrm{d} r}{\mathrm{~d} t} &= \phi_q (\alpha_{r}(V)(1-r)-\beta_{r}(V) r) \\ - \alpha_{q} &=\frac{0.055(-27-V+V_{sh})}{\exp [(-27-V+V_{sh}) / 3.8]-1} \\ - \beta_{q} &=0.94 \exp [(-75-V+V_{sh}) / 17] \\ - \alpha_{r} &=0.000457 \exp [(-13-V+V_{sh}) / 50] \\ - \beta_{r} &=\frac{0.0065}{\exp [(-15-V+V_{sh}) / 28]+1}, - \end{aligned} - - Parameters:: - - size: int, tuple of int - The size of the simulation target. - keep_size: bool - Keep size or flatten the size? - method: str - The numerical method - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximum conductance. - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - T : float, ArrayType - The temperature. - T_base_p : float, ArrayType - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float, ArrayType - The brainpy_object temperature factor of :math:`q` channel. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - If `None`, :math:`\phi_p = \mathrm{T_base_p}^{\frac{T-23}{10}}`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - If `None`, :math:`\phi_q = \mathrm{T_base_q}^{\frac{T-23}{10}}`. - - References:: - - .. [1] Reuveni, I., et al. "Stepwise repolarization from Ca2+ plateaus - in neocortical pyramidal cells: evidence for nonhomogeneous - distribution of HVA Ca2+ channels in dendrites." Journal of - Neuroscience 13.11 (1993): 4609-4621. 
- - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[float, ArrayType] = 36., - T_base_p: Union[float, ArrayType] = 2.3, - T_base_q: Union[float, ArrayType] = 2.3, - phi_p: Union[float, ArrayType, Initializer, Callable] = None, - phi_q: Union[float, ArrayType, Initializer, Callable] = None, - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p - phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode) - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) - self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = -27 - V + self.V_sh - return 0.055 * temp / (bm.exp(temp / 3.8) - 1) - - def f_p_beta(self, V): - return 0.94 * bm.exp((-75. - V + self.V_sh) / 17.) - - def f_q_alpha(self, V): - return 0.000457 * bm.exp((-13. - V + self.V_sh) / 50.) - - def f_q_beta(self, V): - return 0.0065 / (bm.exp((-15. - V + self.V_sh) / 28.) + 1.) + r"""The high-threshold T-type calcium current model proposed by (Reuveni, et al., 1993) [1]_. + + HVA Calcium current was described for neocortical neurons by Sayer et al. (1990). + Its dynamics is given by (the rate functions are measured under 36 Celsius): + + .. 
math:: + + \begin{aligned} + I_{L} &=\bar{g}_{L} q^{2} r\left(V-E_{\mathrm{Ca}}\right) \\ + \frac{\mathrm{d} q}{\mathrm{~d} t} &= \phi_p (\alpha_{q}(V)(1-q)-\beta_{q}(V) q) \\ + \frac{\mathrm{d} r}{\mathrm{~d} t} &= \phi_q (\alpha_{r}(V)(1-r)-\beta_{r}(V) r) \\ + \alpha_{q} &=\frac{0.055(-27-V+V_{sh})}{\exp [(-27-V+V_{sh}) / 3.8]-1} \\ + \beta_{q} &=0.94 \exp [(-75-V+V_{sh}) / 17] \\ + \alpha_{r} &=0.000457 \exp [(-13-V+V_{sh}) / 50] \\ + \beta_{r} &=\frac{0.0065}{\exp [(-15-V+V_{sh}) / 28]+1}, + \end{aligned} + + Parameters:: + + size: int, tuple of int + The size of the simulation target. + keep_size: bool + Keep size or flatten the size? + method: str + The numerical method + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximum conductance. + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + T : float, ArrayType + The temperature. + T_base_p : float, ArrayType + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float, ArrayType + The brainpy_object temperature factor of :math:`q` channel. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + If `None`, :math:`\phi_p = \mathrm{T_base_p}^{\frac{T-23}{10}}`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + If `None`, :math:`\phi_q = \mathrm{T_base_q}^{\frac{T-23}{10}}`. + + References:: + + .. [1] Reuveni, I., et al. "Stepwise repolarization from Ca2+ plateaus + in neocortical pyramidal cells: evidence for nonhomogeneous + distribution of HVA Ca2+ channels in dendrites." Journal of + Neuroscience 13.11 (1993): 4609-4621. 
+ + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[float, ArrayType] = 36., + T_base_p: Union[float, ArrayType] = 2.3, + T_base_q: Union[float, ArrayType] = 2.3, + phi_p: Union[float, ArrayType, Initializer, Callable] = None, + phi_q: Union[float, ArrayType, Initializer, Callable] = None, + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + phi_p = T_base_p ** ((T - 23.) / 10.) if phi_p is None else phi_p + phi_q = T_base_q ** ((T - 23.) / 10.) if phi_q is None else phi_q + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + mode=mode) + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = -27 - V + self.V_sh + return 0.055 * temp / (bm.exp(temp / 3.8) - 1) + + def f_p_beta(self, V): + return 0.94 * bm.exp((-75. - V + self.V_sh) / 17.) + + def f_q_alpha(self, V): + return 0.000457 * bm.exp((-13. - V + self.V_sh) / 50.) + + def f_q_beta(self, V): + return 0.0065 / (bm.exp((-15. - V + self.V_sh) / 28.) + 1.) class ICaL_IS2008(_ICa_p2q_ss): - r"""The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_. - - The L-type calcium channel model is adopted from (Inoue, et, al., 2008) [1]_. - Its dynamics is given by: - - .. 
math:: - - I_{CaL} &= g_{max} p^2 q(V-E_{Ca}) \\ - {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ - & p_{\infty} = {1 \over 1+\exp [-(V+10-V_{sh}) / 4.]} \\ - & \tau_{p} = 0.4+{0.7 \over \exp [(V+5-V_{sh}) / 15]+\exp [-(V+5-V_{sh}) / 15]} \\ - {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ - & q_{\infty} = {1 \over 1+\exp [(V+25-V_{sh}) / 2]} \\ - & \tau_q = 300 + {100 \over \exp [(V+40-V_{sh}) / 9.5]+\exp [-(V+40-V_{sh}) / 9.5]} - - where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` - are temperature-dependent factors (:math:`T` is the temperature in Celsius), - :math:`E_{Ca}` is the reversal potential of Calcium channel. - - Parameters:: - - T : float - The temperature. - T_base_p : float - The brainpy_object temperature factor of :math:`p` channel. - T_base_q : float - The brainpy_object temperature factor of :math:`q` channel. - g_max : float - The maximum conductance. - V_sh : float - The membrane potential shift. - - References:: - - .. [1] Inoue, Tsuyoshi, and Ben W. Strowbridge. "Transient activity induces a long-lasting - increase in the excitability of olfactory bulb interneurons." Journal of - neurophysiology 99, no. 1 (2008): 187-199. 
- - See Also:: - - ICa_p2q_form - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[float, ArrayType, Initializer, Callable] = 36., - T_base_p: Union[float, ArrayType, Initializer, Callable] = 3.55, - T_base_q: Union[float, ArrayType, Initializer, Callable] = 3., - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi_p=T_base_p ** ((T - 24) / 10), - phi_q=T_base_q ** ((T - 24) / 10), - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) - self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.)) - - def f_p_tau(self, V): - return 0.4 + .7 / (bm.exp(-(V + 5. - self.V_sh) / 15.) + - bm.exp((V + 5. - self.V_sh) / 15.)) - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V + 25. - self.V_sh) / 2.)) - - def f_q_tau(self, V): - return 300. + 100. / (bm.exp((V + 40 - self.V_sh) / 9.5) + - bm.exp(-(V + 40 - self.V_sh) / 9.5)) + r"""The L-type calcium channel model proposed by (Inoue & Strowbridge, 2008) [1]_. + + The L-type calcium channel model is adopted from (Inoue, et, al., 2008) [1]_. + Its dynamics is given by: + + .. 
math:: + + I_{CaL} &= g_{max} p^2 q(V-E_{Ca}) \\ + {dp \over dt} &= {\phi_p \cdot (p_{\infty}-p)\over \tau_p} \\ + & p_{\infty} = {1 \over 1+\exp [-(V+10-V_{sh}) / 4.]} \\ + & \tau_{p} = 0.4+{0.7 \over \exp [(V+5-V_{sh}) / 15]+\exp [-(V+5-V_{sh}) / 15]} \\ + {dq \over dt} &= {\phi_q \cdot (q_{\infty}-q) \over \tau_q} \\ + & q_{\infty} = {1 \over 1+\exp [(V+25-V_{sh}) / 2]} \\ + & \tau_q = 300 + {100 \over \exp [(V+40-V_{sh}) / 9.5]+\exp [-(V+40-V_{sh}) / 9.5]} + + where :math:`phi_p = 3.55^{\frac{T-24}{10}}` and :math:`phi_q = 3^{\frac{T-24}{10}}` + are temperature-dependent factors (:math:`T` is the temperature in Celsius), + :math:`E_{Ca}` is the reversal potential of Calcium channel. + + Parameters:: + + T : float + The temperature. + T_base_p : float + The brainpy_object temperature factor of :math:`p` channel. + T_base_q : float + The brainpy_object temperature factor of :math:`q` channel. + g_max : float + The maximum conductance. + V_sh : float + The membrane potential shift. + + References:: + + .. [1] Inoue, Tsuyoshi, and Ben W. Strowbridge. "Transient activity induces a long-lasting + increase in the excitability of olfactory bulb interneurons." Journal of + neurophysiology 99, no. 1 (2008): 187-199. 
+ + See Also:: + + ICa_p2q_form + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[float, ArrayType, Initializer, Callable] = 36., + T_base_p: Union[float, ArrayType, Initializer, Callable] = 3.55, + T_base_q: Union[float, ArrayType, Initializer, Callable] = 3., + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi_p=T_base_p ** ((T - 24) / 10), + phi_q=T_base_q ** ((T - 24) / 10), + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base_p = parameter(T_base_p, self.varshape, allow_none=False) + self.T_base_q = parameter(T_base_q, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.)) + + def f_p_tau(self, V): + return 0.4 + .7 / (bm.exp(-(V + 5. - self.V_sh) / 15.) + + bm.exp((V + 5. - self.V_sh) / 15.)) + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V + 25. - self.V_sh) / 2.)) + + def f_q_tau(self, V): + return 300. + 100. / (bm.exp((V + 40 - self.V_sh) / 9.5) + + bm.exp(-(V + 40 - self.V_sh) / 9.5)) diff --git a/brainpy/_src/dyn/channels/hyperpolarization_activated.py b/brainpy/_src/dyn/channels/hyperpolarization_activated.py index 07b6c75b2..27b5be101 100644 --- a/brainpy/_src/dyn/channels/hyperpolarization_activated.py +++ b/brainpy/_src/dyn/channels/hyperpolarization_activated.py @@ -16,235 +16,235 @@ from .base import IonChannel __all__ = [ - 'Ih_HM1992', - 'Ih_De1996', + 'Ih_HM1992', + 'Ih_De1996', ] class Ih_HM1992(IonChannel): - r"""The hyperpolarization-activated cation current model propsoed by (Huguenard & McCormick, 1992) [1]_. 
+ r"""The hyperpolarization-activated cation current model propsoed by (Huguenard & McCormick, 1992) [1]_. - The hyperpolarization-activated cation current model is adopted from - (Huguenard, et, al., 1992) [1]_. Its dynamics is given by: + The hyperpolarization-activated cation current model is adopted from + (Huguenard, et, al., 1992) [1]_. Its dynamics is given by: - .. math:: + .. math:: - \begin{aligned} - I_h &= g_{\mathrm{max}} p \\ - \frac{dp}{dt} &= \phi \frac{p_{\infty} - p}{\tau_p} \\ - p_{\infty} &=\frac{1}{1+\exp ((V+75) / 5.5)} \\ - \tau_{p} &=\frac{1}{\exp (-0.086 V-14.59)+\exp (0.0701 V-1.87)} - \end{aligned} + \begin{aligned} + I_h &= g_{\mathrm{max}} p \\ + \frac{dp}{dt} &= \phi \frac{p_{\infty} - p}{\tau_p} \\ + p_{\infty} &=\frac{1}{1+\exp ((V+75) / 5.5)} \\ + \tau_{p} &=\frac{1}{\exp (-0.086 V-14.59)+\exp (0.0701 V-1.87)} + \end{aligned} - where :math:`\phi=1` is a temperature-dependent factor. + where :math:`\phi=1` is a temperature-dependent factor. - Parameters:: + Parameters:: - g_max : float - The maximal conductance density (:math:`mS/cm^2`). - E : float - The reversal potential (mV). - phi : float - The temperature-dependent factor. + g_max : float + The maximal conductance density (:math:`mS/cm^2`). + E : float + The reversal potential (mV). + phi : float + The temperature-dependent factor. - References:: + References:: - .. [1] Huguenard, John R., and David A. McCormick. "Simulation of the currents - involved in rhythmic oscillations in thalamic relay neurons." Journal - of neurophysiology 68, no. 4 (1992): 1373-1383. + .. [1] Huguenard, John R., and David A. McCormick. "Simulation of the currents + involved in rhythmic oscillations in thalamic relay neurons." Journal + of neurophysiology 68, no. 4 (1992): 1373-1383. 
- """ + """ - master_type = HHTypedNeuron + master_type = HHTypedNeuron - def __init__( - self, - size: Shape, - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - E: Union[float, ArrayType, Initializer, Callable] = 43., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) + def __init__( + self, + size: Shape, + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + E: Union[float, ArrayType, Initializer, Callable] = 43., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) - # parameters - self.phi = parameter(phi, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.E = parameter(E, self.varshape, allow_none=False) + # parameters + self.phi = parameter(phi, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) - # variable - self.p = variable(bm.zeros, self.mode, self.varshape) + # variable + self.p = variable(bm.zeros, self.mode, self.varshape) - # function - self.integral = odeint(self.derivative, method=method) + # function + self.integral = odeint(self.derivative, method=method) - def derivative(self, p, t, V): - return self.phi * (self.f_p_inf(V) - p) / self.f_p_tau(V) + def derivative(self, p, t, V): + return self.phi * (self.f_p_inf(V) - p) / self.f_p_tau(V) - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) 
+ if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size - def update(self, V): - self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + def update(self, V): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) - def current(self, V): - return self.g_max * self.p * (self.E - V) + def current(self, V): + return self.g_max * self.p * (self.E - V) - def f_p_inf(self, V): - return 1. / (1. + bm.exp((V + 75.) / 5.5)) + def f_p_inf(self, V): + return 1. / (1. + bm.exp((V + 75.) / 5.5)) - def f_p_tau(self, V): - return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87)) + def f_p_tau(self, V): + return 1. / (bm.exp(-0.086 * V - 14.59) + bm.exp(0.0701 * V - 1.87)) class Ih_De1996(IonChannel): - r"""The hyperpolarization-activated cation current model propsoed by (Destexhe, et al., 1996) [1]_. - - The full kinetic schema was - - .. math:: - - \begin{gathered} - C \underset{\beta(V)}{\stackrel{\alpha(V)}{\rightleftarrows}} O \\ - P_{0}+2 \mathrm{Ca}^{2+} \underset{k_{2}}{\stackrel{k_{1}}{\rightleftarrows}} P_{1} \\ - O+P_{1} \underset{k_{4}}{\rightleftarrows} O_{\mathrm{L}} - \end{gathered} - - where the first reaction represents the voltage-dependent transitions of :math:`I_h` channels - between closed (C) and open (O) forms, with :math:`\alpha` and :math:`\beta` as transition rates. - The second reaction represents the biding of intracellular :math:`\mathrm{Ca^{2+}}` ions to a - regulating factor (:math:`P_0` for unbound and :math:`P_1` for bound) with four binding sites for - calcium and rates of :math:`k_1 = 2.5e^7\, mM^{-4} \, ms^{-1}` and :math:`k_2=4e-4 \, ms^{-1}` - (half-activation of 0.002 mM :math:`Ca^{2+}`). The calcium-bound form :math:`P_1` associates - with the open form of the channel, leading to a locked open form :math:`O_L`, with rates of - :math:`k_3=0.1 \, ms^{-1}` and :math:`k_4 = 0.001 \, ms^{-1}`. - - The current is the proportional to the relative concentration of open channels - - .. 
math:: - - I_h = g_h (O+g_{inc}O_L) (V - E_h) - - with a maximal conductance of :math:`\bar{g}_{\mathrm{h}}=0.02 \mathrm{mS} / \mathrm{cm}^{2}` - and a reversal potential of :math:`E_{\mathrm{h}}=-40 \mathrm{mV}`. Because of the factor - :math:`g_{\text {inc }}=2`, the conductance of the calcium-bound open state of - :math:`I_{\mathrm{h}}` channels is twice that of the unbound open state. This produces an - augmentation of conductance after the binding of :math:`\mathrm{Ca}^{2+}`, as observed in - sino-atrial cells (Hagiwara and Irisawa 1989). - - The rates of :math:`\alpha` and :math:`\beta` are: - - .. math:: - - & \alpha = m_{\infty} / \tau_m \\ - & \beta = (1-m_{\infty}) / \tau_m \\ - & m_{\infty} = 1/(1+\exp((V+75-V_{sh})/5.5)) \\ - & \tau_m = (5.3 + 267/(\exp((V+71.5-V_{sh})/14.2) + \exp(-(V+89-V_{sh})/11.6))) - - and the temperature regulating factor :math:`\phi=2^{(T-24)/10}`. - - References:: - - .. [1] Destexhe, Alain, et al. "Ionic mechanisms underlying synchronized - oscillations and propagating waves in a model of ferret thalamic - slices." Journal of neurophysiology 76.3 (1996): 2049-2070. 
- """ - - master_type = Calcium - - def __init__( - self, - size: Shape, - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -40., - k2: Union[float, ArrayType, Initializer, Callable] = 4e-4, - k4: Union[float, ArrayType, Initializer, Callable] = 1e-3, - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - g_max: Union[float, ArrayType, Initializer, Callable] = 0.02, - g_inc: Union[float, ArrayType, Initializer, Callable] = 2., - Ca_half: Union[float, ArrayType, Initializer, Callable] = 2e-3, - T: Union[float, ArrayType] = 36., - T_base: Union[float, ArrayType] = 3., - phi: Union[float, ArrayType, Initializer, Callable] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base = parameter(T_base, self.varshape, allow_none=False) - if phi is None: - self.phi = self.T_base ** ((self.T - 24.) 
/ 10) - else: - self.phi = parameter(phi, self.varshape, allow_none=False) - self.E = parameter(E, self.varshape, allow_none=False) - self.k2 = parameter(k2, self.varshape, allow_none=False) - self.Ca_half = parameter(Ca_half, self.varshape, allow_none=False) - self.k1 = self.k2 / self.Ca_half ** 4 - self.k4 = parameter(k4, self.varshape, allow_none=False) - self.k3 = self.k4 / 0.01 - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.g_inc = parameter(g_inc, self.varshape, allow_none=False) - - # variable - self.O = variable(bm.zeros, self.mode, self.varshape) - self.OL = variable(bm.zeros, self.mode, self.varshape) - self.P1 = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq(self.dO, self.dOL, self.dP1), method=method) - - def dO(self, O, t, OL, V): - inf = self.f_inf(V) - tau = self.f_tau(V) - alpha = inf / tau - beta = (1 - inf) / tau - return alpha * (1 - O - OL) - beta * O - - def dOL(self, OL, t, O, P1): - return self.k3 * P1 * O - self.k4 * OL - - def dP1(self, P1, t, C_Ca): - return self.k1 * C_Ca ** 4 * (1 - P1) - self.k2 * P1 - - def update(self, V, C_Ca, E_Ca): - self.O.value, self.OL.value, self.P1.value = self.integral(self.O.value, self.OL.value, self.P1.value, - share['t'], V=V, C_Ca=C_Ca, dt=share['dt']) - - def current(self, V, C_Ca, E_Ca): - return self.g_max * (self.O + self.g_inc * self.OL) * (self.E - V) - - def reset_state(self, V, C_Ca, E_Ca, batch_size=None): - varshape = self.varshape if (batch_size is None) else ((batch_size,) + self.varshape) - self.P1.value = bm.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape) - inf = self.f_inf(V) - tau = self.f_tau(V) - alpha = inf / tau - beta = (1 - inf) / tau - self.O.value = alpha / (alpha + alpha * self.k3 * self.P1 / self.k4 + beta) - self.OL.value = self.k3 * self.P1 * self.O / self.k4 - if isinstance(batch_size, int): - assert 
self.P1.shape[0] == batch_size - assert self.O.shape[0] == batch_size - assert self.OL.shape[0] == batch_size - - def f_inf(self, V): - return 1 / (1 + bm.exp((V + 75 - self.V_sh) / 5.5)) - - def f_tau(self, V): - return (20. + 1000 / (bm.exp((V + 71.5 - self.V_sh) / 14.2) + - bm.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi + r"""The hyperpolarization-activated cation current model propsoed by (Destexhe, et al., 1996) [1]_. + + The full kinetic schema was + + .. math:: + + \begin{gathered} + C \underset{\beta(V)}{\stackrel{\alpha(V)}{\rightleftarrows}} O \\ + P_{0}+2 \mathrm{Ca}^{2+} \underset{k_{2}}{\stackrel{k_{1}}{\rightleftarrows}} P_{1} \\ + O+P_{1} \underset{k_{4}}{\rightleftarrows} O_{\mathrm{L}} + \end{gathered} + + where the first reaction represents the voltage-dependent transitions of :math:`I_h` channels + between closed (C) and open (O) forms, with :math:`\alpha` and :math:`\beta` as transition rates. + The second reaction represents the biding of intracellular :math:`\mathrm{Ca^{2+}}` ions to a + regulating factor (:math:`P_0` for unbound and :math:`P_1` for bound) with four binding sites for + calcium and rates of :math:`k_1 = 2.5e^7\, mM^{-4} \, ms^{-1}` and :math:`k_2=4e-4 \, ms^{-1}` + (half-activation of 0.002 mM :math:`Ca^{2+}`). The calcium-bound form :math:`P_1` associates + with the open form of the channel, leading to a locked open form :math:`O_L`, with rates of + :math:`k_3=0.1 \, ms^{-1}` and :math:`k_4 = 0.001 \, ms^{-1}`. + + The current is the proportional to the relative concentration of open channels + + .. math:: + + I_h = g_h (O+g_{inc}O_L) (V - E_h) + + with a maximal conductance of :math:`\bar{g}_{\mathrm{h}}=0.02 \mathrm{mS} / \mathrm{cm}^{2}` + and a reversal potential of :math:`E_{\mathrm{h}}=-40 \mathrm{mV}`. Because of the factor + :math:`g_{\text {inc }}=2`, the conductance of the calcium-bound open state of + :math:`I_{\mathrm{h}}` channels is twice that of the unbound open state. 
This produces an + augmentation of conductance after the binding of :math:`\mathrm{Ca}^{2+}`, as observed in + sino-atrial cells (Hagiwara and Irisawa 1989). + + The rates of :math:`\alpha` and :math:`\beta` are: + + .. math:: + + & \alpha = m_{\infty} / \tau_m \\ + & \beta = (1-m_{\infty}) / \tau_m \\ + & m_{\infty} = 1/(1+\exp((V+75-V_{sh})/5.5)) \\ + & \tau_m = (5.3 + 267/(\exp((V+71.5-V_{sh})/14.2) + \exp(-(V+89-V_{sh})/11.6))) + + and the temperature regulating factor :math:`\phi=2^{(T-24)/10}`. + + References:: + + .. [1] Destexhe, Alain, et al. "Ionic mechanisms underlying synchronized + oscillations and propagating waves in a model of ferret thalamic + slices." Journal of neurophysiology 76.3 (1996): 2049-2070. + """ + + master_type = Calcium + + def __init__( + self, + size: Shape, + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -40., + k2: Union[float, ArrayType, Initializer, Callable] = 4e-4, + k4: Union[float, ArrayType, Initializer, Callable] = 1e-3, + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + g_max: Union[float, ArrayType, Initializer, Callable] = 0.02, + g_inc: Union[float, ArrayType, Initializer, Callable] = 2., + Ca_half: Union[float, ArrayType, Initializer, Callable] = 2e-3, + T: Union[float, ArrayType] = 36., + T_base: Union[float, ArrayType] = 3., + phi: Union[float, ArrayType, Initializer, Callable] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base = parameter(T_base, self.varshape, allow_none=False) + if phi is None: + self.phi = self.T_base ** ((self.T - 24.) 
/ 10) + else: + self.phi = parameter(phi, self.varshape, allow_none=False) + self.E = parameter(E, self.varshape, allow_none=False) + self.k2 = parameter(k2, self.varshape, allow_none=False) + self.Ca_half = parameter(Ca_half, self.varshape, allow_none=False) + self.k1 = self.k2 / self.Ca_half ** 4 + self.k4 = parameter(k4, self.varshape, allow_none=False) + self.k3 = self.k4 / 0.01 + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.g_inc = parameter(g_inc, self.varshape, allow_none=False) + + # variable + self.O = variable(bm.zeros, self.mode, self.varshape) + self.OL = variable(bm.zeros, self.mode, self.varshape) + self.P1 = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq(self.dO, self.dOL, self.dP1), method=method) + + def dO(self, O, t, OL, V): + inf = self.f_inf(V) + tau = self.f_tau(V) + alpha = inf / tau + beta = (1 - inf) / tau + return alpha * (1 - O - OL) - beta * O + + def dOL(self, OL, t, O, P1): + return self.k3 * P1 * O - self.k4 * OL + + def dP1(self, P1, t, C_Ca): + return self.k1 * C_Ca ** 4 * (1 - P1) - self.k2 * P1 + + def update(self, V, C_Ca, E_Ca): + self.O.value, self.OL.value, self.P1.value = self.integral(self.O.value, self.OL.value, self.P1.value, + share['t'], V=V, C_Ca=C_Ca, dt=share['dt']) + + def current(self, V, C_Ca, E_Ca): + return self.g_max * (self.O + self.g_inc * self.OL) * (self.E - V) + + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): + varshape = self.varshape if (batch_size is None) else ((batch_size,) + self.varshape) + self.P1.value = bm.broadcast_to(self.k1 * C_Ca ** 4 / (self.k1 * C_Ca ** 4 + self.k2), varshape) + inf = self.f_inf(V) + tau = self.f_tau(V) + alpha = inf / tau + beta = (1 - inf) / tau + self.O.value = alpha / (alpha + alpha * self.k3 * self.P1 / self.k4 + beta) + self.OL.value = self.k3 * self.P1 * self.O / self.k4 + if isinstance(batch_size, int): + assert 
self.P1.shape[0] == batch_size + assert self.O.shape[0] == batch_size + assert self.OL.shape[0] == batch_size + + def f_inf(self, V): + return 1 / (1 + bm.exp((V + 75 - self.V_sh) / 5.5)) + + def f_tau(self, V): + return (20. + 1000 / (bm.exp((V + 71.5 - self.V_sh) / 14.2) + + bm.exp(-(V + 89 - self.V_sh) / 11.6))) / self.phi diff --git a/brainpy/_src/dyn/channels/leaky.py b/brainpy/_src/dyn/channels/leaky.py index 04ce7de5e..c75d43ceb 100644 --- a/brainpy/_src/dyn/channels/leaky.py +++ b/brainpy/_src/dyn/channels/leaky.py @@ -14,55 +14,55 @@ from .base import IonChannel __all__ = [ - 'LeakyChannel', - 'IL', + 'LeakyChannel', + 'IL', ] class LeakyChannel(IonChannel): - """Base class for leaky channel dynamics.""" + """Base class for leaky channel dynamics.""" - master_type = HHTypedNeuron + master_type = HHTypedNeuron - def reset_state(self, V, batch_size=None): - pass + def reset_state(self, V, batch_size=None): + pass class IL(LeakyChannel): - """The leakage channel current. - - Parameters:: - - g_max : float - The leakage conductance. - E : float - The reversal potential. - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.1, - E: Union[int, float, ArrayType, Initializer, Callable] = -70., - method: str = None, - name: str = None, - mode: bm.Mode = None, - ): - super(IL, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.method = method - - def reset_state(self, V, batch_size=None): - pass - - def update(self, V): - pass - - def current(self, V): - return self.g_max * (self.E - V) + """The leakage channel current. + + Parameters:: + + g_max : float + The leakage conductance. + E : float + The reversal potential. 
+ """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.1, + E: Union[int, float, ArrayType, Initializer, Callable] = -70., + method: str = None, + name: str = None, + mode: bm.Mode = None, + ): + super(IL, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.method = method + + def reset_state(self, V, batch_size=None): + pass + + def update(self, V): + pass + + def current(self, V): + return self.g_max * (self.E - V) diff --git a/brainpy/_src/dyn/channels/potassium.py b/brainpy/_src/dyn/channels/potassium.py index 9fef7fce8..d29ae2933 100644 --- a/brainpy/_src/dyn/channels/potassium.py +++ b/brainpy/_src/dyn/channels/potassium.py @@ -17,2022 +17,2022 @@ from .base import IonChannel __all__ = [ - 'PotassiumChannel', - 'IKDR_Ba2002v2', - 'IK_TM1991v2', - 'IK_HH1952v2', - 'IKA1_HM1992v2', - 'IKA2_HM1992v2', - 'IKK2A_HM1992v2', - 'IKK2B_HM1992v2', - 'IKNI_Ya1989v2', - 'IK_Leak', + 'PotassiumChannel', + 'IKDR_Ba2002v2', + 'IK_TM1991v2', + 'IK_HH1952v2', + 'IKA1_HM1992v2', + 'IKA2_HM1992v2', + 'IKK2A_HM1992v2', + 'IKK2B_HM1992v2', + 'IKNI_Ya1989v2', + 'IK_Leak', ] class PotassiumChannel(IonChannel): - """Base class for sodium channel dynamics.""" + """Base class for sodium channel dynamics.""" - master_type = Potassium + master_type = Potassium - def update(self, V, C, E): - raise NotImplementedError + def update(self, V, C, E): + raise NotImplementedError - def current(self, V, C, E): - raise NotImplementedError + def current(self, V, C, E): + raise NotImplementedError - def reset(self, V, C, E, batch_size: int = None): - self.reset_state(V, C, E, batch_size) + def reset(self, V, C, E, batch_size: int = None): + self.reset_state(V, C, E, batch_size) - def reset_state(self, V, C, E, batch_size: int = None): - raise 
NotImplementedError('Must be implemented by the subclass.') + def reset_state(self, V, C, E, batch_size: int = None): + raise NotImplementedError('Must be implemented by the subclass.') class _IK_p4_markov_v2(PotassiumChannel): - r"""The delayed rectifier potassium channel of :math:`p^4` - current which described with first-order Markov chain. - - This general potassium current model should have the form of - - .. math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor. - - Parameters:: - - size: int, sequence of int - The object size. - keep_size: bool - Whether we use `size` to initialize the variable. Otherwise, variable shape - will be initialized as `num`. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi : float, ArrayType, Initializer, Callable - The temperature-dependent factor. - method: str - The numerical integration method. - name: str - The object name. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi = parameter(phi, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.derivative, method=method) - - def derivative(self, p, t, V): - return self.phi * (self.f_p_alpha(V) * (1. 
- p) - self.f_p_beta(V) * p) - - def update(self, V, C, E): - self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) - - def current(self, V, C, E): - return self.g_max * self.p ** 4 * (E - V) - - def reset_state(self, V, C, E, batch_size=None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError + r"""The delayed rectifier potassium channel of :math:`p^4` + current which described with first-order Markov chain. + + This general potassium current model should have the form of + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor. + + Parameters:: + + size: int, sequence of int + The object size. + keep_size: bool + Whether we use `size` to initialize the variable. Otherwise, variable shape + will be initialized as `num`. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi : float, ArrayType, Initializer, Callable + The temperature-dependent factor. + method: str + The numerical integration method. + name: str + The object name. 
+ + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.derivative, method=method) + + def derivative(self, p, t, V): + return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) + + def update(self, V, C, E): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + + def current(self, V, C, E): + return self.g_max * self.p ** 4 * (E - V) + + def reset_state(self, V, C, E, batch_size=None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError class IKDR_Ba2002v2(_IK_p4_markov_v2): - r"""The delayed rectifier potassium channel current. - - The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. - It's dynamics is given by: - - .. math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ - \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor, which is given by - :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). 
- - Parameters:: - - size: int, sequence of int - The object size. - keep_size: bool - Whether we use `size` to initialize the variable. Otherwise, variable shape - will be initialized as `num`. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - T_base : float, ArrayType - The brainpy_object of temperature factor. - T : float, ArrayType, Initializer, Callable - The temperature (Celsius, :math:`^{\circ}C`). - V_sh : float, ArrayType, Initializer, Callable - The shift of the membrane potential to spike. - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations - and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = -50., - T_base: Union[float, ArrayType] = 3., - T: Union[float, ArrayType] = 36., - phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None, - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - phi = T_base ** ((T - 36) / 10) if phi is None else phi - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi=phi, - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base = parameter(T_base, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - tmp = V - self.V_sh - 15. - return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) - - def f_p_beta(self, V): - return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) + r"""The delayed rectifier potassium channel current. 
+ + The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. + It's dynamics is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ + \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor, which is given by + :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). + + Parameters:: + + size: int, sequence of int + The object size. + keep_size: bool + Whether we use `size` to initialize the variable. Otherwise, variable shape + will be initialized as `num`. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + T_base : float, ArrayType + The brainpy_object of temperature factor. + T : float, ArrayType, Initializer, Callable + The temperature (Celsius, :math:`^{\circ}C`). + V_sh : float, ArrayType, Initializer, Callable + The shift of the membrane potential to spike. + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations + and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. 
+ + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = -50., + T_base: Union[float, ArrayType] = 3., + T: Union[float, ArrayType] = 36., + phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None, + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + phi = T_base ** ((T - 36) / 10) if phi is None else phi + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi=phi, + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base = parameter(T_base, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + tmp = V - self.V_sh - 15. + return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + + def f_p_beta(self, V): + return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) class IK_TM1991v2(_IK_p4_markov_v2): - r"""The potassium channel described by (Traub and Miles, 1991) [1]_. - - The dynamics of this channel is given by: - - .. math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ - \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) - \end{aligned} - - where :math:`V_{sh}` is the membrane shift (default -63 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Traub, Roger D., and Richard Miles. 
Neuronal networks of the hippocampus. - Vol. 777. Cambridge University Press, 1991. - - See Also:: - - INa_TM1991 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - c = 15 - V + self.V_sh - return 0.032 * c / (bm.exp(c / 5) - 1.) - - def f_p_beta(self, V): - return 0.5 * bm.exp((10 - V + self.V_sh) / 40) + r"""The potassium channel described by (Traub and Miles, 1991) [1]_. + + The dynamics of this channel is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ + \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -63 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. + Vol. 777. Cambridge University Press, 1991. 
+ + See Also:: + + INa_TM1991 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + c = 15 - V + self.V_sh + return 0.032 * c / (bm.exp(c / 5) - 1.) + + def f_p_beta(self, V): + return 0.5 * bm.exp((10 - V + self.V_sh) / 40) class IK_HH1952v2(_IK_p4_markov_v2): - r"""The potassium channel described by Hodgkin–Huxley model [1]_. - - The dynamics of this channel is given by: - - .. math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ - \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) - \end{aligned} - - where :math:`V_{sh}` is the membrane shift (default -45 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of - membrane current and its application to conduction and excitation in - nerve." The Journal of physiology 117.4 (1952): 500. 
- - See Also:: - - INa_HH1952 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = V - self.V_sh + 10 - return 0.01 * temp / (1 - bm.exp(-temp / 10)) - - def f_p_beta(self, V): - return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) + r"""The potassium channel described by Hodgkin–Huxley model [1]_. + + The dynamics of this channel is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ + \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -45 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of + membrane current and its application to conduction and excitation in + nerve." The Journal of physiology 117.4 (1952): 500. 
+ + See Also:: + + INa_HH1952 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = V - self.V_sh + 10 + return 0.01 * temp / (1 - bm.exp(-temp / 10)) + + def f_p_beta(self, V): + return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) class _IKA_p4q_ss_v2(PotassiumChannel): - r"""The rapidly inactivating Potassium channel of :math:`p^4q` - current which described with steady-state format. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. 
"Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq(self.dp, self.dq), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - - def update(self, V, C, E): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) - - def current(self, V, C, E): - return self.g_max * self.p ** 4 * self.q * (E - V) - - def reset_state(self, V, C, E, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def 
f_q_tau(self, V): - raise NotImplementedError + r"""The rapidly inactivating Potassium channel of :math:`p^4q` + current which described with steady-state format. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq(self.dp, self.dq), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) + + def update(self, V, C, E): + self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V, C, E): + return self.g_max * self.p ** 4 * self.q * (E - V) + + def reset_state(self, V, C, E, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKA1_HM1992v2(_IKA_p4q_ss_v2): - r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. 
math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - See Also:: - - IKA2_HM1992 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 30., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + - bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) - - def f_q_tau(self, V): - return bm.where(V < -63 + self.V_sh, - 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + - bm.exp(-(V - self.V_sh + 238.) / 37.5)), - 19.) + r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) 
+ \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + See Also:: + + IKA2_HM1992 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 30., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) 
/ 8.5)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + + bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) + + def f_q_tau(self, V): + return bm.where(V < -63 + self.V_sh, + 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + + bm.exp(-(V - self.V_sh + 238.) / 37.5)), + 19.) class IKA2_HM1992v2(_IKA_p4q_ss_v2): - r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. 
- - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - See Also:: - - IKA1_HM1992 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 20., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi_q=phi_q, - phi_p=phi_p, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + - bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) - - def f_q_tau(self, V): - return bm.where(V < -63 + self.V_sh, - 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + - bm.exp(-(V - self.V_sh + 238.) / 37.5)), - 19.) + r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. 
math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + See Also:: + + IKA1_HM1992 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 20., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi_q=phi_q, + phi_p=phi_p, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + + bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) + + def f_q_tau(self, V): + return bm.where(V < -63 + self.V_sh, + 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + + bm.exp(-(V - self.V_sh + 238.) / 37.5)), + 19.) class _IKK2_pq_ss_v2(PotassiumChannel): - r"""The slowly inactivating Potassium channel of :math:`pq` - current which described with steady-state format. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). 
- phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq(self.dp, self.dq), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - - def update(self, V, C, E): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) - - def current(self, V, C, E): - return self.g_max * self.p * self.q * (E - V) - - def reset_state(self, V, C, E, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if 
isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The slowly inactivating Potassium channel of :math:`pq` + current which described with steady-state format. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq(self.dp, self.dq), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) + + def update(self, V, C, E): + self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V, C, E): + return self.g_max * self.p * self.q * (E - V) + + def reset_state(self, V, C, E, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKK2A_HM1992v2(_IKK2_pq_ss_v2): - r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. - - The dynamics of the model is given as [2]_ [3]_. - - .. 
math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + - bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) - - def f_q_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + - bm.exp(-(V - self.V_sh + 130.) / 7.1)) + r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. 
+ g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + + bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) + + def f_q_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + + bm.exp(-(V - self.V_sh + 130.) 
/ 7.1)) class IKK2B_HM1992v2(_IKK2_pq_ss_v2): - r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + - \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ - \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + - bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) - - def f_q_tau(self, V): - return bm.where(V < -70 + self.V_sh, - 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + - bm.exp(-(V - self.V_sh + 130.) / 7.1)), - 8.9) + r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + + \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ + \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). 
+ + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + + bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 58.) 
/ 10.6)) + + def f_q_tau(self, V): + return bm.where(V < -70 + self.V_sh, + 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + + bm.exp(-(V - self.V_sh + 130.) / 7.1)), + 8.9) class IKNI_Ya1989v2(PotassiumChannel): - r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. - - This slow potassium current can effectively account for spike-frequency adaptation. - - .. math:: - - \begin{aligned} - &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ - &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ - &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ - &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} - \end{aligned} - - where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and - :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - tau_max: float, ArrayType, Callable, Initializer - The :math:`tau_{\max}` parameter. - - References:: - - .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. 
- - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[float, ArrayType, Initializer, Callable] = 0.004, - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3, - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.tau_max = parameter(tau_max, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.dp, method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def update(self, V, C, E): - self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) - - def current(self, V, C, E): - return self.g_max * self.p * (E - V) - - def reset_state(self, V, C, E, batch_size=None): - self.p.value = self.f_p_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.)) - - def f_p_tau(self, V): - temp = V - self.V_sh + 35. - return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.)) - - -class _IK_p4_markov(PotassiumChannel): - r"""The delayed rectifier potassium channel of :math:`p^4` - current which described with first-order Markov chain. + r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. 
- This general potassium current model should have the form of + This slow potassium current can effectively account for spike-frequency adaptation. - .. math:: + .. math:: \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) + &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ + &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ + &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ + &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} \end{aligned} - where :math:`\phi` is a temperature-dependent factor. - - Parameters:: - - size: int, sequence of int - The object size. - keep_size: bool - Whether we use `size` to initialize the variable. Otherwise, variable shape - will be initialized as `num`. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi : float, ArrayType, Initializer, Callable - The temperature-dependent factor. - method: str - The numerical integration method. - name: str - The object name. 
- - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi = parameter(phi, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.derivative, method=method) - - def derivative(self, p, t, V): - return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) - - def update(self, V): - self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) - - def current(self, V): - return self.g_max * self.p ** 4 * (self.E - V) - - def reset_state(self, V, batch_size=None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError - - -class IKDR_Ba2002(_IK_p4_markov): - r"""The delayed rectifier potassium channel current. + where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and + :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). 
+ V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + tau_max: float, ArrayType, Callable, Initializer + The :math:`tau_{\max}` parameter. + + References:: + + .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[float, ArrayType, Initializer, Callable] = 0.004, + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3, + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.tau_max = parameter(tau_max, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.dp, method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def update(self, V, C, E): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + + def current(self, V, C, E): + return self.g_max * self.p * (E - V) + + def reset_state(self, V, C, E, batch_size=None): + self.p.value = self.f_p_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) 
/ 10.)) + + def f_p_tau(self, V): + temp = V - self.V_sh + 35. + return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.)) - The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. - It's dynamics is given by: - .. math:: +class _IK_p4_markov(PotassiumChannel): + r"""The delayed rectifier potassium channel of :math:`p^4` + current which described with first-order Markov chain. + + This general potassium current model should have the form of + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor. + + Parameters:: + + size: int, sequence of int + The object size. + keep_size: bool + Whether we use `size` to initialize the variable. Otherwise, variable shape + will be initialized as `num`. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi : float, ArrayType, Initializer, Callable + The temperature-dependent factor. + method: str + The numerical integration method. + name: str + The object name. 
+ + """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.derivative, method=method) + + def derivative(self, p, t, V): + return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) + + def update(self, V): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p ** 4 * (self.E - V) + + def reset_state(self, V, batch_size=None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ - \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) - \end{aligned} - where :math:`\phi` is a temperature-dependent factor, which is given by - :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - - Parameters:: - - size: int, sequence of int - The object size. - keep_size: bool - Whether we use `size` to initialize the variable. 
Otherwise, variable shape - will be initialized as `num`. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - T_base : float, ArrayType - The brainpy_object of temperature factor. - T : float, ArrayType, Initializer, Callable - The temperature (Celsius, :math:`^{\circ}C`). - V_sh : float, ArrayType, Initializer, Callable - The shift of the membrane potential to spike. - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations - and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = -50., - T_base: Union[float, ArrayType] = 3., - T: Union[float, ArrayType] = 36., - phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None, - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - phi = T_base ** ((T - 36) / 10) if phi is None else phi - super(IKDR_Ba2002, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi=phi, - E=E, - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base = parameter(T_base, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - tmp = V - self.V_sh - 15. - return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) - - def f_p_beta(self, V): - return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) +class IKDR_Ba2002(_IK_p4_markov): + r"""The delayed rectifier potassium channel current. 
+ + The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. + It's dynamics is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ + \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor, which is given by + :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). + + Parameters:: + + size: int, sequence of int + The object size. + keep_size: bool + Whether we use `size` to initialize the variable. Otherwise, variable shape + will be initialized as `num`. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + T_base : float, ArrayType + The brainpy_object of temperature factor. + T : float, ArrayType, Initializer, Callable + The temperature (Celsius, :math:`^{\circ}C`). + V_sh : float, ArrayType, Initializer, Callable + The shift of the membrane potential to spike. + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations + and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. 
+ + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = -50., + T_base: Union[float, ArrayType] = 3., + T: Union[float, ArrayType] = 36., + phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None, + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + phi = T_base ** ((T - 36) / 10) if phi is None else phi + super(IKDR_Ba2002, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi=phi, + E=E, + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base = parameter(T_base, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + tmp = V - self.V_sh - 15. + return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + + def f_p_beta(self, V): + return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) class IK_TM1991(_IK_p4_markov): - r"""The potassium channel described by (Traub and Miles, 1991) [1]_. - - The dynamics of this channel is given by: - - .. math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ - \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) - \end{aligned} - - where :math:`V_{sh}` is the membrane shift (default -63 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - method: str - The numerical integration method. - name: str - The object name. 
- - References:: - - .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. - Vol. 777. Cambridge University Press, 1991. - - See Also:: - - INa_TM1991 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IK_TM1991, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - E=E, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - c = 15 - V + self.V_sh - return 0.032 * c / (bm.exp(c / 5) - 1.) - - def f_p_beta(self, V): - return 0.5 * bm.exp((10 - V + self.V_sh) / 40) + r"""The potassium channel described by (Traub and Miles, 1991) [1]_. + + The dynamics of this channel is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ + \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -63 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. + Vol. 777. Cambridge University Press, 1991. 
+ + See Also:: + + INa_TM1991 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IK_TM1991, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + E=E, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + c = 15 - V + self.V_sh + return 0.032 * c / (bm.exp(c / 5) - 1.) + + def f_p_beta(self, V): + return 0.5 * bm.exp((10 - V + self.V_sh) / 40) class IK_HH1952(_IK_p4_markov): - r"""The potassium channel described by Hodgkin–Huxley model [1]_. - - The dynamics of this channel is given by: - - .. math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ - \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) - \end{aligned} - - where :math:`V_{sh}` is the membrane shift (default -45 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of - membrane current and its application to conduction and excitation in - nerve." The Journal of physiology 117.4 (1952): 500. 
- - See Also:: - - INa_HH1952 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IK_HH1952, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - E=E, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = V - self.V_sh + 10 - return 0.01 * temp / (1 - bm.exp(-temp / 10)) - - def f_p_beta(self, V): - return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) + r"""The potassium channel described by Hodgkin–Huxley model [1]_. + + The dynamics of this channel is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ + \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -45 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of + membrane current and its application to conduction and excitation in + nerve." The Journal of physiology 117.4 (1952): 500. 
+ + See Also:: + + INa_HH1952 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IK_HH1952, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + E=E, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = V - self.V_sh + 10 + return 0.01 * temp / (1 - bm.exp(-temp / 10)) + + def f_p_beta(self, V): + return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) class _IKA_p4q_ss(PotassiumChannel): - r"""The rapidly inactivating Potassium channel of :math:`p^4q` - current which described with steady-state format. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. 
[2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq(self.dp, self.dq), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - - def update(self, V): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) - - def current(self, V): - return self.g_max * self.p ** 4 * self.q * (self.E - V) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == 
batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The rapidly inactivating Potassium channel of :math:`p^4q` + current which described with steady-state format. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq(self.dp, self.dq), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) + + def update(self, V): + self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p ** 4 * self.q * (self.E - V) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKA1_HM1992(_IKA_p4q_ss): - r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. 
- - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - See Also:: - - IKA2_HM1992 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 30., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKA1_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - E=E, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + - bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) - - def f_q_tau(self, V): - return bm.where(V < -63 + self.V_sh, - 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + - bm.exp(-(V - self.V_sh + 238.) / 37.5)), - 19.) + r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) 
+ \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + See Also:: + + IKA2_HM1992 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 30., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKA1_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + E=E, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + + bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) + + def f_q_tau(self, V): + return bm.where(V < -63 + self.V_sh, + 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + + bm.exp(-(V - self.V_sh + 238.) / 37.5)), + 19.) class IKA2_HM1992(_IKA_p4q_ss): - r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) 
+ \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - See Also:: - - IKA1_HM1992 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 20., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKA2_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - E=E, - g_max=g_max, - phi_q=phi_q, - phi_p=phi_p, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + - bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) - - def f_q_tau(self, V): - return bm.where(V < -63 + self.V_sh, - 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + - bm.exp(-(V - self.V_sh + 238.) / 37.5)), - 19.) + r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) 
+ \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + See Also:: + + IKA1_HM1992 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 20., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKA2_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + E=E, + g_max=g_max, + phi_q=phi_q, + phi_p=phi_p, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + + bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) + + def f_q_tau(self, V): + return bm.where(V < -63 + self.V_sh, + 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + + bm.exp(-(V - self.V_sh + 238.) / 37.5)), + 19.) class _IKK2_pq_ss(PotassiumChannel): - r"""The slowly inactivating Potassium channel of :math:`pq` - current which described with steady-state format. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). 
- E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq(self.dp, self.dq), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - - def update(self, V): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], 
V, share['dt']) - - def current(self, V): - return self.g_max * self.p * self.q * (self.E - V) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The slowly inactivating Potassium channel of :math:`pq` + current which described with steady-state format. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq(self.dp, self.dq), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) + + def update(self, V): + self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p * self.q * (self.E - V) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKK2A_HM1992(_IKK2_pq_ss): - r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. 
- - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKK2A_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - E=E, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + - bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) - - def f_q_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + - bm.exp(-(V - self.V_sh + 130.) / 7.1)) + r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. 
+ method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKK2A_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + E=E, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + + bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 58.) 
/ 10.6)) + + def f_q_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + + bm.exp(-(V - self.V_sh + 130.) / 7.1)) class IKK2B_HM1992(_IKK2_pq_ss): - r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + - \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ - \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. 
"Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKK2B_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - E=E, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + - bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) - - def f_q_tau(self, V): - return bm.where(V < -70 + self.V_sh, - 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + - bm.exp(-(V - self.V_sh + 130.) / 7.1)), - 8.9) + r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) 
+ + \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ + \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKK2B_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + E=E, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + + bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) + + def f_q_tau(self, V): + return bm.where(V < -70 + self.V_sh, + 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + + bm.exp(-(V - self.V_sh + 130.) / 7.1)), + 8.9) class IKNI_Ya1989(PotassiumChannel): - r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. - - This slow potassium current can effectively account for spike-frequency adaptation. - - .. math:: - - \begin{aligned} - &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ - &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ - &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ - &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} - \end{aligned} - - where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and - :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. 
- name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - tau_max: float, ArrayType, Callable, Initializer - The :math:`tau_{\max}` parameter. - - References:: - - .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. - - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 0.004, - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3, - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKNI_Ya1989, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.tau_max = parameter(tau_max, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.dp, method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def update(self, V): - self.p.value = self.integral(self.p.value, share['t'], V, 
share['dt']) - - def current(self, V): - return self.g_max * self.p * (self.E - V) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.)) - - def f_p_tau(self, V): - temp = V - self.V_sh + 35. - return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.)) + r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. + + This slow potassium current can effectively account for spike-frequency adaptation. + + .. math:: + + \begin{aligned} + &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ + &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ + &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ + &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} + \end{aligned} + + where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and + :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + tau_max: float, ArrayType, Callable, Initializer + The :math:`tau_{\max}` parameter. + + References:: + + .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. 
+ + """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 0.004, + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3, + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKNI_Ya1989, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.tau_max = parameter(tau_max, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.dp, method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def update(self, V): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p * (self.E - V) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.)) + + def f_p_tau(self, V): + temp = V - self.V_sh + 35. + return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.)) class IK_Leak(PotassiumChannel): - """The potassium leak channel current. 
- - Parameters:: - - g_max : float - The potassium leakage conductance which is modulated by both - acetylcholine and norepinephrine. - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.005, - method: str = None, - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size=size, - keep_size=keep_size, - method=method, - name=name, - mode=mode) - self.g_max = self.init_param(g_max, self.varshape) - - def reset_state(self, V, C, E, batch_size: int = None): - pass - - def update(self, V, C, E): - pass - - def current(self, V, C, E): - return self.g_max * (E - V) + """The potassium leak channel current. + + Parameters:: + + g_max : float + The potassium leakage conductance which is modulated by both + acetylcholine and norepinephrine. + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.005, + method: str = None, + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size=size, + keep_size=keep_size, + method=method, + name=name, + mode=mode) + self.g_max = self.init_param(g_max, self.varshape) + + def reset_state(self, V, C, E, batch_size: int = None): + pass + + def update(self, V, C, E): + pass + + def current(self, V, C, E): + return self.g_max * (E - V) diff --git a/brainpy/_src/dyn/channels/potassium_calcium.py b/brainpy/_src/dyn/channels/potassium_calcium.py index 9681183e9..44b23e069 100644 --- a/brainpy/_src/dyn/channels/potassium_calcium.py +++ b/brainpy/_src/dyn/channels/potassium_calcium.py @@ -19,109 +19,109 @@ from .potassium import PotassiumChannel __all__ = [ - 'IAHP_De1994v2', + 'IAHP_De1994v2', ] class KCaChannel(PotassiumChannel, CalciumChannel): - pass + pass class IAHP_De1994v2(KCaChannel): - r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_. 
- - Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in - vitro recordings (Avanzini et al. 1989) show the presence of a - marked after-hyper-polarization (AHP) after each burst of the RE - cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+ - current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a - modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al. - 1989) that requires the binding of :math:`nCa^{2+}` to open the channel - - .. math:: - - (\text { closed })+n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\stackrel{\alpha}{\rightleftharpoons}(\text { open }) - - where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and - :math:`\beta` are rate constants. The ionic current is then given by - - .. math:: - - \begin{aligned} - I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\ - {dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\ - p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\ - \tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)} - \end{aligned} - - where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance, - :math:`[Ca^{2+}]_i` is the intracellular Calcium concentration. - The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and - :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells - recorded in vivo and in vitro. - - Parameters:: - - g_max : float - The maximal conductance density (:math:`mS/cm^2`). - - References:: - - .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated - thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. 
- - """ - - '''The type of the master object.''' - master_type = JointType[Calcium, Potassium] - - def __init__( - self, - size: Shape, - keep_size: bool = False, - n: Union[float, ArrayType, Initializer, Callable] = 2, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - alpha: Union[float, ArrayType, Initializer, Callable] = 48., - beta: Union[float, ArrayType, Initializer, Callable] = 0.09, - phi: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.n = parameter(n, self.varshape, allow_none=False) - self.alpha = parameter(alpha, self.varshape, allow_none=False) - self.beta = parameter(beta, self.varshape, allow_none=False) - self.phi = parameter(phi, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.dp, method=method) - - def dp(self, p, t, C_Ca): - C2 = self.alpha * bm.power(C_Ca, self.n) - C3 = C2 + self.beta - return self.phi * (C2 / C3 - p) * C3 - - def update(self, V, Ca_info, K_info): - self.p.value = self.integral(self.p.value, share['t'], C_Ca=Ca_info['C'], dt=share['dt']) - - def current(self, V, Ca_info, K_info): - return self.g_max * self.p * self.p * (K_info['E'] - V) - - def reset_state(self, V, Ca_info, K_info, batch_size=None): - C2 = self.alpha * bm.power(Ca_info['C'], self.n) - C3 = C2 + self.beta - if batch_size is None: - self.p.value = bm.broadcast_to(C2 / C3, self.varshape) - else: - self.p.value = bm.broadcast_to(C2 / C3, (batch_size,) + self.varshape) - assert self.p.shape[0] == batch_size + r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_. + + Both in vivo (Contreras et al. 1993; Mulle et al. 
1986) and in + vitro recordings (Avanzini et al. 1989) show the presence of a + marked after-hyper-polarization (AHP) after each burst of the RE + cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+ + current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a + modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al. + 1989) that requires the binding of :math:`nCa^{2+}` to open the channel + + .. math:: + + (\text { closed })+n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\stackrel{\alpha}{\rightleftharpoons}(\text { open }) + + where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and + :math:`\beta` are rate constants. The ionic current is then given by + + .. math:: + + \begin{aligned} + I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\ + {dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\ + p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\ + \tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)} + \end{aligned} + + where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance, + :math:`[Ca^{2+}]_i` is the intracellular Calcium concentration. + The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and + :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells + recorded in vivo and in vitro. + + Parameters:: + + g_max : float + The maximal conductance density (:math:`mS/cm^2`). + + References:: + + .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated + thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. 
+ + """ + + '''The type of the master object.''' + master_type = JointType[Calcium, Potassium] + + def __init__( + self, + size: Shape, + keep_size: bool = False, + n: Union[float, ArrayType, Initializer, Callable] = 2, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + alpha: Union[float, ArrayType, Initializer, Callable] = 48., + beta: Union[float, ArrayType, Initializer, Callable] = 0.09, + phi: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.n = parameter(n, self.varshape, allow_none=False) + self.alpha = parameter(alpha, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.dp, method=method) + + def dp(self, p, t, C_Ca): + C2 = self.alpha * bm.power(C_Ca, self.n) + C3 = C2 + self.beta + return self.phi * (C2 / C3 - p) * C3 + + def update(self, V, Ca_info, K_info): + self.p.value = self.integral(self.p.value, share['t'], C_Ca=Ca_info['C'], dt=share['dt']) + + def current(self, V, Ca_info, K_info): + return self.g_max * self.p * self.p * (K_info['E'] - V) + + def reset_state(self, V, Ca_info, K_info, batch_size=None): + C2 = self.alpha * bm.power(Ca_info['C'], self.n) + C3 = C2 + self.beta + if batch_size is None: + self.p.value = bm.broadcast_to(C2 / C3, self.varshape) + else: + self.p.value = bm.broadcast_to(C2 / C3, (batch_size,) + self.varshape) + assert self.p.shape[0] == batch_size diff --git a/brainpy/_src/dyn/channels/potassium_calcium_compatible.py b/brainpy/_src/dyn/channels/potassium_calcium_compatible.py index 80e463a2e..314433114 100644 --- 
a/brainpy/_src/dyn/channels/potassium_calcium_compatible.py +++ b/brainpy/_src/dyn/channels/potassium_calcium_compatible.py @@ -17,118 +17,117 @@ from .base import IonChannel __all__ = [ - 'IAHP_De1994', + 'IAHP_De1994', ] class IAHP_De1994(IonChannel): - r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_. - - Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in - vitro recordings (Avanzini et al. 1989) show the presence of a - marked after-hyper-polarization (AHP) after each burst of the RE - cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+ - current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a - modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al. - 1989) that requires the binding of :math:`nCa^{2+}` to open the channel - - .. math:: - - (\text { closed })+n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\stackrel{\alpha}{\rightleftharpoons}(\text { open }) - - where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and - :math:`\beta` are rate constants. The ionic current is then given by - - .. math:: - - \begin{aligned} - I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\ - {dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\ - p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\ - \tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)} - \end{aligned} - - where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance, - :math:`[Ca^{2+}]_i` is the intracellular Calcium concentration. - The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and - :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells - recorded in vivo and in vitro. - - Parameters:: - - g_max : float - The maximal conductance density (:math:`mS/cm^2`). - E : float - The reversal potential (mV). - - References:: - - .. [1] Destexhe, Alain, et al. 
"A model of spindle rhythmicity in the isolated - thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. - - """ - - '''The type of the master object.''' - master_type = Calcium - - def __init__( - self, - size: Shape, - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -95., - n: Union[float, ArrayType, Initializer, Callable] = 2, - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - alpha: Union[float, ArrayType, Initializer, Callable] = 48., - beta: Union[float, ArrayType, Initializer, Callable] = 0.09, - phi: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.n = parameter(n, self.varshape, allow_none=False) - self.alpha = parameter(alpha, self.varshape, allow_none=False) - self.beta = parameter(beta, self.varshape, allow_none=False) - self.phi = parameter(phi, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.dp, method=method) - - def dp(self, p, t, C_Ca): - C2 = self.alpha * bm.power(C_Ca, self.n) - C3 = C2 + self.beta - return self.phi * (C2 / C3 - p) * C3 - - def update(self, V, C_Ca, E_Ca): - self.p.value = self.integral(self.p.value, share['t'], C_Ca=C_Ca, dt=share['dt']) - - def current(self, V, C_Ca, E_Ca): - return self.g_max * self.p * self.p * (self.E - V) - - def reset_state(self, V, C_Ca, E_Ca, batch_size=None): - C2 = self.alpha * bm.power(C_Ca, self.n) - C3 = C2 + self.beta - self.p[:] = C2 / C3 - if isinstance(batch_size, int): - batch_size = batch_size - size = (batch_size,) + self.varshape - elif isinstance(batch_size, bm.Mode): - if isinstance(batch_size, bm.BatchingMode): - size = 
(batch_size.batch_size,) + self.varshape - else: - batch_size = None - size = self.varshape - else: - size = self.varshape - self.p.value = bm.broadcast_to(C2 / C3, size) - + r"""The calcium-dependent potassium current model proposed by (Destexhe, et al., 1994) [1]_. + + Both in vivo (Contreras et al. 1993; Mulle et al. 1986) and in + vitro recordings (Avanzini et al. 1989) show the presence of a + marked after-hyper-polarization (AHP) after each burst of the RE + cell. This slow AHP is mediated by a slow :math:`Ca^{2+}`-dependent K+ + current (Bal and McCormick 1993). (Destexhe, et al., 1994) adopted a + modified version of a model of :math:`I_{KCa}` introduced previously (Yamada et al. + 1989) that requires the binding of :math:`nCa^{2+}` to open the channel + + .. math:: + + (\text { closed })+n \mathrm{Ca}_{i}^{2+} \underset{\beta}{\stackrel{\alpha}{\rightleftharpoons}(\text { open }) + + where :math:`Ca_i^{2+}` is the intracellular calcium and :math:`\alpha` and + :math:`\beta` are rate constants. The ionic current is then given by + + .. math:: + + \begin{aligned} + I_{AHP} &= g_{\mathrm{max}} p^2 (V - E_K) \\ + {dp \over dt} &= \phi {p_{\infty}(V, [Ca^{2+}]_i) - p \over \tau_p(V, [Ca^{2+}]_i)} \\ + p_{\infty} &=\frac{\alpha[Ca^{2+}]_i^n}{\left(\alpha[Ca^{2+}]_i^n + \beta\right)} \\ + \tau_p &=\frac{1}{\left(\alpha[Ca^{2+}]_i +\beta\right)} + \end{aligned} + + where :math:`E` is the reversal potential, :math:`g_{max}` is the maximum conductance, + :math:`[Ca^{2+}]_i` is the intracellular Calcium concentration. + The values :math:`n=2, \alpha=48 \mathrm{~ms}^{-1} \mathrm{mM}^{-2}` and + :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells + recorded in vivo and in vitro. + + Parameters:: + + g_max : float + The maximal conductance density (:math:`mS/cm^2`). + E : float + The reversal potential (mV). + + References:: + + .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated + thalamic reticular nucleus." 
Journal of neurophysiology 72.2 (1994): 803-818. + + """ + + '''The type of the master object.''' + master_type = Calcium + + def __init__( + self, + size: Shape, + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -95., + n: Union[float, ArrayType, Initializer, Callable] = 2, + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + alpha: Union[float, ArrayType, Initializer, Callable] = 48., + beta: Union[float, ArrayType, Initializer, Callable] = 0.09, + phi: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.n = parameter(n, self.varshape, allow_none=False) + self.alpha = parameter(alpha, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.dp, method=method) + + def dp(self, p, t, C_Ca): + C2 = self.alpha * bm.power(C_Ca, self.n) + C3 = C2 + self.beta + return self.phi * (C2 / C3 - p) * C3 + + def update(self, V, C_Ca, E_Ca): + self.p.value = self.integral(self.p.value, share['t'], C_Ca=C_Ca, dt=share['dt']) + + def current(self, V, C_Ca, E_Ca): + return self.g_max * self.p * self.p * (self.E - V) + + def reset_state(self, V, C_Ca, E_Ca, batch_size=None): + C2 = self.alpha * bm.power(C_Ca, self.n) + C3 = C2 + self.beta + self.p[:] = C2 / C3 + if isinstance(batch_size, int): + batch_size = batch_size + size = (batch_size,) + self.varshape + elif isinstance(batch_size, bm.Mode): + if isinstance(batch_size, bm.BatchingMode): + size = (batch_size.batch_size,) + self.varshape + else: + batch_size = None + size = 
self.varshape + else: + size = self.varshape + self.p.value = bm.broadcast_to(C2 / C3, size) diff --git a/brainpy/_src/dyn/channels/potassium_compatible.py b/brainpy/_src/dyn/channels/potassium_compatible.py index c1ad5763b..19f2568bc 100644 --- a/brainpy/_src/dyn/channels/potassium_compatible.py +++ b/brainpy/_src/dyn/channels/potassium_compatible.py @@ -16,1046 +16,1046 @@ from brainpy.types import ArrayType __all__ = [ - 'IKDR_Ba2002', - 'IK_TM1991', - 'IK_HH1952', - 'IKA1_HM1992', - 'IKA2_HM1992', - 'IKK2A_HM1992', - 'IKK2B_HM1992', - 'IKNI_Ya1989', - 'IKL', + 'IKDR_Ba2002', + 'IK_TM1991', + 'IK_HH1952', + 'IKA1_HM1992', + 'IKA2_HM1992', + 'IKK2A_HM1992', + 'IKK2B_HM1992', + 'IKNI_Ya1989', + 'IKL', ] class _IK_p4_markov(IonChannel): - r"""The delayed rectifier potassium channel of :math:`p^4` - current which described with first-order Markov chain. - - This general potassium current model should have the form of - - .. math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor. - - Parameters:: - - size: int, sequence of int - The object size. - keep_size: bool - Whether we use `size` to initialize the variable. Otherwise, variable shape - will be initialized as `num`. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi : float, ArrayType, Initializer, Callable - The temperature-dependent factor. - method: str - The numerical integration method. - name: str - The object name. 
- - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi = parameter(phi, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.derivative, method=method) - - def derivative(self, p, t, V): - return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) - - def update(self, V): - self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) - - def current(self, V): - return self.g_max * self.p ** 4 * (self.E - V) - - def reset_state(self, V, batch_size=None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError + r"""The delayed rectifier potassium channel of :math:`p^4` + current which described with first-order Markov chain. + + This general potassium current model should have the form of + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor. + + Parameters:: + + size: int, sequence of int + The object size. + keep_size: bool + Whether we use `size` to initialize the variable. Otherwise, variable shape + will be initialized as `num`. 
+ g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi : float, ArrayType, Initializer, Callable + The temperature-dependent factor. + method: str + The numerical integration method. + name: str + The object name. + + """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.derivative, method=method) + + def derivative(self, p, t, V): + return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) + + def update(self, V): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p ** 4 * (self.E - V) + + def reset_state(self, V, batch_size=None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError class IKDR_Ba2002(_IK_p4_markov): - r"""The delayed rectifier potassium channel current. - - The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. - It's dynamics is given by: - - .. 
math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ - \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor, which is given by - :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - - Parameters:: - - size: int, sequence of int - The object size. - keep_size: bool - Whether we use `size` to initialize the variable. Otherwise, variable shape - will be initialized as `num`. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - T_base : float, ArrayType - The brainpy_object of temperature factor. - T : float, ArrayType, Initializer, Callable - The temperature (Celsius, :math:`^{\circ}C`). - V_sh : float, ArrayType, Initializer, Callable - The shift of the membrane potential to spike. - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations - and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. 
- - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = -50., - T_base: Union[float, ArrayType] = 3., - T: Union[float, ArrayType] = 36., - phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None, - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - phi = T_base ** ((T - 36) / 10) if phi is None else phi - super(IKDR_Ba2002, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - g_max=g_max, - phi=phi, - E=E, - mode=mode) - - # parameters - self.T = parameter(T, self.varshape, allow_none=False) - self.T_base = parameter(T_base, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - tmp = V - self.V_sh - 15. - return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) - - def f_p_beta(self, V): - return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) + r"""The delayed rectifier potassium channel current. + + The potassium current model is adopted from (Bazhenov, et, al. 2002) [1]_. + It's dynamics is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &=\frac{0.032\left(V-V_{sh}-15\right)}{1-\exp \left(-\left(V-V_{sh}-15\right) / 5\right)} \\ + \beta_p &= 0.5 \exp \left(-\left(V-V_{sh}-10\right) / 40\right) + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor, which is given by + :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). + + Parameters:: + + size: int, sequence of int + The object size. + keep_size: bool + Whether we use `size` to initialize the variable. Otherwise, variable shape + will be initialized as `num`. 
+ g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + T_base : float, ArrayType + The brainpy_object of temperature factor. + T : float, ArrayType, Initializer, Callable + The temperature (Celsius, :math:`^{\circ}C`). + V_sh : float, ArrayType, Initializer, Callable + The shift of the membrane potential to spike. + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations + and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = -50., + T_base: Union[float, ArrayType] = 3., + T: Union[float, ArrayType] = 36., + phi: Optional[Union[float, ArrayType, Initializer, Callable]] = None, + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + phi = T_base ** ((T - 36) / 10) if phi is None else phi + super(IKDR_Ba2002, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + g_max=g_max, + phi=phi, + E=E, + mode=mode) + + # parameters + self.T = parameter(T, self.varshape, allow_none=False) + self.T_base = parameter(T_base, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + tmp = V - self.V_sh - 15. + return 0.032 * tmp / (1. - bm.exp(-tmp / 5.)) + + def f_p_beta(self, V): + return 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.) class IK_TM1991(_IK_p4_markov): - r"""The potassium channel described by (Traub and Miles, 1991) [1]_. - - The dynamics of this channel is given by: - - .. 
math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ - \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) - \end{aligned} - - where :math:`V_{sh}` is the membrane shift (default -63 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. - Vol. 777. Cambridge University Press, 1991. - - See Also:: - - INa_TM1991 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IK_TM1991, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - E=E, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - c = 15 - V + self.V_sh - return 0.032 * c / (bm.exp(c / 5) - 1.) - - def f_p_beta(self, V): - return 0.5 * bm.exp((10 - V + self.V_sh) / 40) + r"""The potassium channel described by (Traub and Miles, 1991) [1]_. + + The dynamics of this channel is given by: + + .. 
math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= 0.032 \frac{(15 - V + V_{sh})}{(\exp((15 - V + V_{sh}) / 5) - 1.)} \\ + \beta_p &= 0.5 * \exp((10 - V + V_{sh}) / 40) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -63 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. + Vol. 777. Cambridge University Press, 1991. + + See Also:: + + INa_TM1991 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -60., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IK_TM1991, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + E=E, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + c = 15 - V + self.V_sh + return 0.032 * c / (bm.exp(c / 5) - 1.) + + def f_p_beta(self, V): + return 0.5 * bm.exp((10 - V + self.V_sh) / 40) class IK_HH1952(_IK_p4_markov): - r"""The potassium channel described by Hodgkin–Huxley model [1]_. - - The dynamics of this channel is given by: - - .. 
math:: - - \begin{aligned} - I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ - \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ - \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) - \end{aligned} - - where :math:`V_{sh}` is the membrane shift (default -45 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - method: str - The numerical integration method. - name: str - The object name. - - References:: - - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of - membrane current and its application to conduction and excitation in - nerve." The Journal of physiology 117.4 (1952): 500. - - See Also:: - - INa_HH1952 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi: Union[float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IK_HH1952, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - E=E, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = V - self.V_sh + 10 - return 0.01 * temp / (1 - bm.exp(-temp / 10)) - - def f_p_beta(self, V): - return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) + r"""The potassium channel described by Hodgkin–Huxley model [1]_. + + The dynamics of this channel is given by: + + .. 
math:: + + \begin{aligned} + I_{\mathrm{K}} &= g_{\mathrm{max}} * p^4 \\ + \frac{dp}{dt} &= \phi * (\alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &= \frac{0.01 (V -V_{sh} + 10)}{1-\exp \left(-\left(V-V_{sh}+ 10\right) / 10\right)} \\ + \beta_p &= 0.125 \exp \left(-\left(V-V_{sh}+20\right) / 80\right) + \end{aligned} + + where :math:`V_{sh}` is the membrane shift (default -45 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + method: str + The numerical integration method. + name: str + The object name. + + References:: + + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of + membrane current and its application to conduction and excitation in + nerve." The Journal of physiology 117.4 (1952): 500. + + See Also:: + + INa_HH1952 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi: Union[float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IK_HH1952, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + E=E, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = V - self.V_sh + 10 + return 0.01 * temp / (1 - bm.exp(-temp / 10)) + + def f_p_beta(self, V): + return 0.125 * bm.exp(-(V - self.V_sh + 20) / 80) class _IKA_p4q_ss(IonChannel): - r"""The rapidly inactivating Potassium channel of :math:`p^4q` - current which described with steady-state format. 
- - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq(self.dp, self.dq), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - - def update(self, V): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) - - def current(self, V): - return self.g_max * self.p ** 4 * self.q * (self.E - V) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The rapidly inactivating Potassium channel of :math:`p^4q` + current which described with steady-state format. 
+ + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [3]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq(self.dp, self.dq), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) + + def update(self, V): + self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p ** 4 * self.q * (self.E - V) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKA1_HM1992(_IKA_p4q_ss): - r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. 
- - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) + \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - See Also:: - - IKA2_HM1992 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 30., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKA1_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - E=E, - g_max=g_max, - phi_p=phi_p, - phi_q=phi_q, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + - bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) - - def f_q_tau(self, V): - return bm.where(V < -63 + self.V_sh, - 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + - bm.exp(-(V - self.V_sh + 238.) / 37.5)), - 19.) + r"""The rapidly inactivating Potassium channel (IA1) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 60)/8.5]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) 
+ \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + See Also:: + + IKA2_HM1992 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 30., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKA1_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + E=E, + g_max=g_max, + phi_p=phi_p, + phi_q=phi_q, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 60.) / 8.5)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + + bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) + + def f_q_tau(self, V): + return bm.where(V < -63 + self.V_sh, + 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + + bm.exp(-(V - self.V_sh + 238.) / 37.5)), + 19.) class IKA2_HM1992(_IKA_p4q_ss): - r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. - - This model is developed according to the average behavior of - rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. - - .. math:: - - &IA = g_{\mathrm{max}} p^4 q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) 
+ \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ - \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - See Also:: - - IKA1_HM1992 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 20., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKA2_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - E=E, - g_max=g_max, - phi_q=phi_q, - phi_p=phi_p, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + - bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) - - def f_q_tau(self, V): - return bm.where(V < -63 + self.V_sh, - 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + - bm.exp(-(V - self.V_sh + 238.) / 37.5)), - 19.) + r"""The rapidly inactivating Potassium channel (IA2) model proposed by (Huguenard & McCormick, 1992) [2]_. + + This model is developed according to the average behavior of + rapidly inactivating Potassium channel in Thalamus relay neurons [2]_ [1]_. + + .. math:: + + &IA = g_{\mathrm{max}} p^4 q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 36)/20.]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}+35.8}{19.7}\right)+ \exp \left(\frac{V -V_{sh}+79.7}{-12.7}\right)}+0.37 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 78)/6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+46)/5.) 
+ \exp((V -V_{sh}+238)/-37.5)} \quad V<(-63+V_{sh})\, mV \\ + \tau_{q} = 19 \quad V \geq (-63 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [1] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + See Also:: + + IKA1_HM1992 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 20., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKA2_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + E=E, + g_max=g_max, + phi_q=phi_q, + phi_p=phi_p, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 36.) / 20.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh + 35.8) / 19.7) + + bm.exp(-(V - self.V_sh + 79.7) / 12.7)) + 0.37 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 78.) / 6.)) + + def f_q_tau(self, V): + return bm.where(V < -63 + self.V_sh, + 1. / (bm.exp((V - self.V_sh + 46.) / 5.) + + bm.exp(-(V - self.V_sh + 238.) / 37.5)), + 19.) class _IKK2_pq_ss(IonChannel): - r"""The slowly inactivating Potassium channel of :math:`pq` - current which described with steady-state format. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). 
- E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq(self.dp, self.dq), method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def dq(self, q, t, V): - return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) - - def update(self, V): - self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], 
V, share['dt']) - - def current(self, V): - return self.g_max * self.p * self.q * (self.E - V) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - self.q.value = self.f_q_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def f_p_inf(self, V): - raise NotImplementedError - - def f_p_tau(self, V): - raise NotImplementedError - - def f_q_inf(self, V): - raise NotImplementedError - - def f_q_tau(self, V): - raise NotImplementedError + r"""The slowly inactivating Potassium channel of :math:`pq` + current which described with steady-state format. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq(self.dp, self.dq), method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def dq(self, q, t, V): + return self.phi_q * (self.f_q_inf(V) - q) / self.f_q_tau(V) + + def update(self, V): + self.p.value, self.q.value = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p * self.q * (self.E - V) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + self.q.value = self.f_q_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def f_p_inf(self, V): + raise NotImplementedError + + def f_p_tau(self, V): + raise NotImplementedError + + def f_q_inf(self, V): + raise NotImplementedError + + def f_q_tau(self, V): + raise NotImplementedError class IKK2A_HM1992(_IKK2_pq_ss): - r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. 
- - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
- - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKK2A_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - E=E, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + - bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) - - def f_q_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + - bm.exp(-(V - self.V_sh + 130.) / 7.1)) + r"""The slowly inactivating Potassium channel (IK2a) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + & \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + \exp((V -V_{sh}+130)/-7.1)} + 120 \\ + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. 
+ method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKK2A_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + E=E, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + + bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 58.) 
/ 10.6)) + + def f_q_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + + bm.exp(-(V - self.V_sh + 130.) / 7.1)) class IKK2B_HM1992(_IKK2_pq_ss): - r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. - - The dynamics of the model is given as [2]_ [3]_. - - .. math:: - - &IK2 = g_{\mathrm{max}} p q (E-V) \\ - &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ - &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ - &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ - \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ - &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ - &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ - &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) + - \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ - \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} - - where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. - name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - phi_q : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`q`. - - References:: - - .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the - currents involved in rhythmic oscillations in thalamic relay - neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. - .. [3] Huguenard, J. R., and D. A. Prince. 
"Slow inactivation of a - TEA-sensitive K current in acutely isolated rat thalamic relay - neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 10., - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKK2B_HM1992, self).__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi_p=phi_p, - phi_q=phi_q, - g_max=g_max, - E=E, - mode=mode) - - # parameters - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) - - def f_p_tau(self, V): - return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + - bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 - - def f_q_inf(self, V): - return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) - - def f_q_tau(self, V): - return bm.where(V < -70 + self.V_sh, - 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + - bm.exp(-(V - self.V_sh + 130.) / 7.1)), - 8.9) + r"""The slowly inactivating Potassium channel (IK2b) model proposed by (Huguenard & McCormick, 1992) [2]_. + + The dynamics of the model is given as [2]_ [3]_. + + .. math:: + + &IK2 = g_{\mathrm{max}} p q (E-V) \\ + &\frac{dp}{dt} = \phi_p \frac{p_{\infty} - p}{\tau_p} \\ + &p_{\infty} = \frac{1}{1+ \exp[-(V -V_{sh}+ 43)/17]} \\ + &\tau_{p}=\frac{1}{\exp \left(\frac{V -V_{sh}-81.}{25.6}\right)+ + \exp \left(\frac{V -V_{sh}+132}{-18}\right)}+9.9 \\ + &\frac{dq}{dt} = \phi_q \frac{q_{\infty} - q}{\tau_q} \\ + &q_{\infty} = \frac{1}{1+ \exp[(V -V_{sh} + 59)/10.6]} \\ + &\begin{array}{l} \tau_{q} = \frac{1}{\exp((V -V_{sh}+1329)/200.) 
+ + \exp((V -V_{sh}+130)/-7.1)} + 120 \quad V<(-70+V_{sh})\, mV \\ + \tau_{q} = 8.9 \quad V \geq (-70 + V_{sh})\, mV \end{array} + + where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + phi_q : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`q`. + + References:: + + .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the + currents involved in rhythmic oscillations in thalamic relay + neurons." Journal of neurophysiology 68.4 (1992): 1373-1383. + .. [3] Huguenard, J. R., and D. A. Prince. "Slow inactivation of a + TEA-sensitive K current in acutely isolated rat thalamic relay + neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. 
+ + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 10., + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKK2B_HM1992, self).__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi_p=phi_p, + phi_q=phi_q, + g_max=g_max, + E=E, + mode=mode) + + # parameters + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 43.) / 17.)) + + def f_p_tau(self, V): + return 1. / (bm.exp((V - self.V_sh - 81.) / 25.6) + + bm.exp(-(V - self.V_sh + 132) / 18.)) + 9.9 + + def f_q_inf(self, V): + return 1. / (1. + bm.exp((V - self.V_sh + 58.) / 10.6)) + + def f_q_tau(self, V): + return bm.where(V < -70 + self.V_sh, + 1. / (bm.exp((V - self.V_sh - 1329.) / 200.) + + bm.exp(-(V - self.V_sh + 130.) / 7.1)), + 8.9) class IKNI_Ya1989(IonChannel): - r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. - - This slow potassium current can effectively account for spike-frequency adaptation. - - .. math:: - - \begin{aligned} - &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ - &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ - &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ - &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} - \end{aligned} - - where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and - :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. - - Parameters:: - - size: int, sequence of int - The geometry size. - method: str - The numerical integration method. 
- name: str - The object name. - g_max : float, ArrayType, Initializer, Callable - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Initializer, Callable - The reversal potential (mV). - V_sh : float, ArrayType, Callable, Initializer - The membrane potential shift. - phi_p : optional, float, ArrayType, Callable, Initializer - The temperature factor for channel :math:`p`. - tau_max: float, ArrayType, Callable, Initializer - The :math:`tau_{\max}` parameter. - - References:: - - .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. - - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -90., - g_max: Union[float, ArrayType, Initializer, Callable] = 0.004, - phi_p: Union[float, ArrayType, Initializer, Callable] = 1., - phi_q: Union[float, ArrayType, Initializer, Callable] = 1., - tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3, - V_sh: Union[float, ArrayType, Initializer, Callable] = 0., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(IKNI_Ya1989, self).__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.tau_max = parameter(tau_max, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - self.phi_p = parameter(phi_p, self.varshape, allow_none=False) - self.phi_q = parameter(phi_q, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(self.dp, method=method) - - def dp(self, p, t, V): - return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) - - def update(self, V): - self.p.value = self.integral(self.p.value, share['t'], V, 
share['dt']) - - def current(self, V): - return self.g_max * self.p * (self.E - V) - - def reset_state(self, V, batch_size=None): - self.p.value = self.f_p_inf(V) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - - def f_p_inf(self, V): - return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.)) - - def f_p_tau(self, V): - temp = V - self.V_sh + 35. - return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.)) + r"""A slow non-inactivating K+ current described by Yamada et al. (1989) [1]_. + + This slow potassium current can effectively account for spike-frequency adaptation. + + .. math:: + + \begin{aligned} + &I_{M}=\bar{g}_{M} p\left(V-E_{K}\right) \\ + &\frac{\mathrm{d} p}{\mathrm{~d} t}=\left(p_{\infty}(V)-p\right) / \tau_{p}(V) \\ + &p_{\infty}(V)=\frac{1}{1+\exp [-(V-V_{sh}+35) / 10]} \\ + &\tau_{p}(V)=\frac{\tau_{\max }}{3.3 \exp [(V-V_{sh}+35) / 20]+\exp [-(V-V_{sh}+35) / 20]} + \end{aligned} + + where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and + :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. + + Parameters:: + + size: int, sequence of int + The geometry size. + method: str + The numerical integration method. + name: str + The object name. + g_max : float, ArrayType, Initializer, Callable + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Initializer, Callable + The reversal potential (mV). + V_sh : float, ArrayType, Callable, Initializer + The membrane potential shift. + phi_p : optional, float, ArrayType, Callable, Initializer + The temperature factor for channel :math:`p`. + tau_max: float, ArrayType, Callable, Initializer + The :math:`tau_{\max}` parameter. + + References:: + + .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. 
+ + """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -90., + g_max: Union[float, ArrayType, Initializer, Callable] = 0.004, + phi_p: Union[float, ArrayType, Initializer, Callable] = 1., + phi_q: Union[float, ArrayType, Initializer, Callable] = 1., + tau_max: Union[float, ArrayType, Initializer, Callable] = 4e3, + V_sh: Union[float, ArrayType, Initializer, Callable] = 0., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(IKNI_Ya1989, self).__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.tau_max = parameter(tau_max, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + self.phi_p = parameter(phi_p, self.varshape, allow_none=False) + self.phi_q = parameter(phi_q, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(self.dp, method=method) + + def dp(self, p, t, V): + return self.phi_p * (self.f_p_inf(V) - p) / self.f_p_tau(V) + + def update(self, V): + self.p.value = self.integral(self.p.value, share['t'], V, share['dt']) + + def current(self, V): + return self.g_max * self.p * (self.E - V) + + def reset_state(self, V, batch_size=None): + self.p.value = self.f_p_inf(V) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + + def f_p_inf(self, V): + return 1. / (1. + bm.exp(-(V - self.V_sh + 35.) / 10.)) + + def f_p_tau(self, V): + temp = V - self.V_sh + 35. + return self.tau_max / (3.3 * bm.exp(temp / 20.) + bm.exp(-temp / 20.)) class IKL(IonChannel): - """The potassium leak channel current. 
- - Parameters:: - - g_max : float - The potassium leakage conductance which is modulated by both - acetylcholine and norepinephrine. - E : float - The reversal potential. - """ - - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.005, - E: Union[int, float, ArrayType, Initializer, Callable] = -90., - method: str = None, - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - mode=mode) - - self.E = parameter(E, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - self.method = method - - def reset_state(self, V, batch_size=None): - pass - - def update(self, V): - pass - - def current(self, V): - return self.g_max * (self.E - V) + """The potassium leak channel current. + + Parameters:: + + g_max : float + The potassium leakage conductance which is modulated by both + acetylcholine and norepinephrine. + E : float + The reversal potential. 
+ """ + + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + g_max: Union[int, float, ArrayType, Initializer, Callable] = 0.005, + E: Union[int, float, ArrayType, Initializer, Callable] = -90., + method: str = None, + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + mode=mode) + + self.E = parameter(E, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + self.method = method + + def reset_state(self, V, batch_size=None): + pass + + def update(self, V): + pass + + def current(self, V): + return self.g_max * (self.E - V) diff --git a/brainpy/_src/dyn/channels/sodium.py b/brainpy/_src/dyn/channels/sodium.py index 255227202..166948032 100644 --- a/brainpy/_src/dyn/channels/sodium.py +++ b/brainpy/_src/dyn/channels/sodium.py @@ -16,365 +16,365 @@ from .base import IonChannel __all__ = [ - 'SodiumChannel', - 'INa_Ba2002v2', - 'INa_TM1991v2', - 'INa_HH1952v2', + 'SodiumChannel', + 'INa_Ba2002v2', + 'INa_TM1991v2', + 'INa_HH1952v2', ] class SodiumChannel(IonChannel): - """Base class for sodium channel dynamics.""" + """Base class for sodium channel dynamics.""" - master_type = Sodium + master_type = Sodium - def update(self, V, C, E): - raise NotImplementedError + def update(self, V, C, E): + raise NotImplementedError - def current(self, V, C, E): - raise NotImplementedError + def current(self, V, C, E): + raise NotImplementedError - def reset(self, V, C, E, batch_size: int = None): - self.reset_state(V, C, E, batch_size) + def reset(self, V, C, E, batch_size: int = None): + self.reset_state(V, C, E, batch_size) - def reset_state(self, V, C, E, batch_size: int = None): - raise NotImplementedError('Must be implemented by the subclass.') + def reset_state(self, V, C, E, batch_size: int = None): + raise NotImplementedError('Must be implemented by the subclass.') class _INa_p3q_markov_v2(SodiumChannel): - 
r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain. - - The general model can be used to model the dynamics with: - - .. math:: - - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ - \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ - \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor. - - Parameters:: - - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - phi : float, ArrayType, Callable, Initializer - The temperature-dependent factor. - method: str - The numerical method - name: str - The name of the object. - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., - phi: Union[int, float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.phi = parameter(phi, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq([self.dp, self.dq]), method=method) - - def reset_state(self, V, C, E, batch_size=None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - alpha = self.f_q_alpha(V) - beta = self.f_q_beta(V) - self.q.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def dp(self, p, t, V): - return self.phi * (self.f_p_alpha(V) * (1. 
- p) - self.f_p_beta(V) * p) - - def dq(self, q, t, V): - return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q) - - def update(self, V, C, E): - p, q = self.integral(self.p, self.q, share['t'], V, share['dt']) - self.p.value, self.q.value = p, q - - def current(self, V, C, E): - return self.g_max * self.p ** 3 * self.q * (E - V) - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError - - def f_q_alpha(self, V): - raise NotImplementedError - - def f_q_beta(self, V): - raise NotImplementedError + r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain. + + The general model can be used to model the dynamics with: + + .. math:: + + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ + \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ + \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor. + + Parameters:: + + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + phi : float, ArrayType, Callable, Initializer + The temperature-dependent factor. + method: str + The numerical method + name: str + The name of the object. 
+ + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., + phi: Union[int, float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.phi = parameter(phi, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq([self.dp, self.dq]), method=method) + + def reset_state(self, V, C, E, batch_size=None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + alpha = self.f_q_alpha(V) + beta = self.f_q_beta(V) + self.q.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def dp(self, p, t, V): + return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) + + def dq(self, q, t, V): + return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q) + + def update(self, V, C, E): + p, q = self.integral(self.p, self.q, share['t'], V, share['dt']) + self.p.value, self.q.value = p, q + + def current(self, V, C, E): + return self.g_max * self.p ** 3 * self.q * (E - V) + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError + + def f_q_alpha(self, V): + raise NotImplementedError + + def f_q_beta(self, V): + raise NotImplementedError class INa_Ba2002v2(_INa_p3q_markov_v2): - r"""The sodium current model. - - The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_. - It's dynamics is given by: - - .. 
math:: - - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ - \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\ - \beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\ - \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ - \alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\ - \beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)} - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor, which is given by - :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - - Parameters:: - - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - T : float, ArrayType - The temperature (Celsius, :math:`^{\circ}C`). - V_sh : float, ArrayType, Callable, Initializer - The shift of the membrane potential to spike. - - References:: - - .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations - and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. - - See Also:: - - INa_TM1991 - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[int, float, ArrayType] = 36., - g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -50., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=3 ** ((T - 36) / 10), - g_max=g_max, - mode=mode) - self.T = parameter(T, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = V - self.V_sh - 13. - return 0.32 * temp / (1. 
- bm.exp(-temp / 4.)) - - def f_p_beta(self, V): - temp = V - self.V_sh - 40. - return -0.28 * temp / (1. - bm.exp(temp / 5.)) - - def f_q_alpha(self, V): - return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.) - - def f_q_beta(self, V): - return 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.)) + r"""The sodium current model. + + The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_. + It's dynamics is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ + \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\ + \beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\ + \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ + \alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\ + \beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)} + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor, which is given by + :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). + + Parameters:: + + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + T : float, ArrayType + The temperature (Celsius, :math:`^{\circ}C`). + V_sh : float, ArrayType, Callable, Initializer + The shift of the membrane potential to spike. + + References:: + + .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations + and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. 
+ + See Also:: + + INa_TM1991 + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[int, float, ArrayType] = 36., + g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -50., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=3 ** ((T - 36) / 10), + g_max=g_max, + mode=mode) + self.T = parameter(T, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = V - self.V_sh - 13. + return 0.32 * temp / (1. - bm.exp(-temp / 4.)) + + def f_p_beta(self, V): + temp = V - self.V_sh - 40. + return -0.28 * temp / (1. - bm.exp(temp / 5.)) + + def f_q_alpha(self, V): + return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.) + + def f_q_beta(self, V): + return 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.)) class INa_TM1991v2(_INa_p3q_markov_v2): - r"""The sodium current model described by (Traub and Miles, 1991) [1]_. - - The dynamics of this sodium current model is given by: - - .. math:: - - \begin{split} - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ - \frac {dm} {dt} &= \phi(\alpha_m (1-x) - \beta_m) \\ - &\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\ - &\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\ - \frac {dh} {dt} &= \phi(\alpha_h (1-x) - \beta_h) \\ - &\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\ - &\beta_h(V) = 4. / (1 + \exp(-(V - V_{sh} - 40) / 5)) \\ - \end{aligned} - \end{split} - - where :math:`V_{sh}` is the membrane shift (default -63 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, tuple of int - The size of the simulation target. - keep_size: bool - Keep size or flatten the size? 
- method: str - The numerical method - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer - The membrane shift. - - References:: - - .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. - Vol. 777. Cambridge University Press, 1991. - - See Also:: - - INa_Ba2002 - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., - phi: Union[int, float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -63., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = 13 - V + self.V_sh - return 0.32 * temp / (bm.exp(temp / 4) - 1.) - - def f_p_beta(self, V): - temp = V - self.V_sh - 40 - return 0.28 * temp / (bm.exp(temp / 5) - 1) - - def f_q_alpha(self, V): - return 0.128 * bm.exp((17 - V + self.V_sh) / 18) - - def f_q_beta(self, V): - return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5)) + r"""The sodium current model described by (Traub and Miles, 1991) [1]_. + + The dynamics of this sodium current model is given by: + + .. math:: + + \begin{split} + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ + \frac {dm} {dt} &= \phi(\alpha_m (1-x) - \beta_m) \\ + &\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\ + &\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\ + \frac {dh} {dt} &= \phi(\alpha_h (1-x) - \beta_h) \\ + &\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\ + &\beta_h(V) = 4. 
/ (1 + \exp(-(V - V_{sh} - 40) / 5)) \\ + \end{aligned} + \end{split} + + where :math:`V_{sh}` is the membrane shift (default -63 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, tuple of int + The size of the simulation target. + keep_size: bool + Keep size or flatten the size? + method: str + The numerical method + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + V_sh: float, ArrayType, Callable, Initializer + The membrane shift. + + References:: + + .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. + Vol. 777. Cambridge University Press, 1991. + + See Also:: + + INa_Ba2002 + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., + phi: Union[int, float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -63., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = 13 - V + self.V_sh + return 0.32 * temp / (bm.exp(temp / 4) - 1.) + + def f_p_beta(self, V): + temp = V - self.V_sh - 40 + return 0.28 * temp / (bm.exp(temp / 5) - 1) + + def f_q_alpha(self, V): + return 0.128 * bm.exp((17 - V + self.V_sh) / 18) + + def f_q_beta(self, V): + return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5)) class INa_HH1952v2(_INa_p3q_markov_v2): - r"""The sodium current model described by Hodgkin–Huxley model [1]_. - - The dynamics of this sodium current model is given by: - - .. 
math:: - - \begin{split} - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ - \frac {dm} {dt} &= \phi (\alpha_m (1-x) - \beta_m) \\ - &\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\ - &\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\ - \frac {dh} {dt} &= \phi (\alpha_h (1-x) - \beta_h) \\ - &\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\ - &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\ - \end{aligned} - \end{split} - - where :math:`V_{sh}` is the membrane shift (default -45 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, tuple of int - The size of the simulation target. - keep_size: bool - Keep size or flatten the size? - method: str - The numerical method - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer - The membrane shift. - - References:: - - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of - membrane current and its application to conduction and excitation in - nerve." The Journal of physiology 117.4 (1952): 500. 
- - See Also:: - - IK_HH1952 - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., - phi: Union[int, float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=phi, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = V - self.V_sh - 5 - return 0.1 * temp / (1 - bm.exp(-temp / 10)) - - def f_p_beta(self, V): - return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18) - - def f_q_alpha(self, V): - return 0.07 * bm.exp(-(V - self.V_sh + 20) / 20.) - - def f_q_beta(self, V): - return 1 / (1 + bm.exp(-(V - self.V_sh - 10) / 10)) + r"""The sodium current model described by Hodgkin–Huxley model [1]_. + + The dynamics of this sodium current model is given by: + + .. math:: + + \begin{split} + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ + \frac {dm} {dt} &= \phi (\alpha_m (1-x) - \beta_m) \\ + &\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\ + &\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\ + \frac {dh} {dt} &= \phi (\alpha_h (1-x) - \beta_h) \\ + &\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\ + &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\ + \end{aligned} + \end{split} + + where :math:`V_{sh}` is the membrane shift (default -45 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, tuple of int + The size of the simulation target. + keep_size: bool + Keep size or flatten the size? + method: str + The numerical method + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). 
+ E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + V_sh: float, ArrayType, Callable, Initializer + The membrane shift. + + References:: + + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of + membrane current and its application to conduction and excitation in + nerve." The Journal of physiology 117.4 (1952): 500. + + See Also:: + + IK_HH1952 + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., + phi: Union[int, float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=phi, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = V - self.V_sh - 5 + return 0.1 * temp / (1 - bm.exp(-temp / 10)) + + def f_p_beta(self, V): + return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18) + + def f_q_alpha(self, V): + return 0.07 * bm.exp(-(V - self.V_sh + 20) / 20.) + + def f_q_beta(self, V): + return 1 / (1 + bm.exp(-(V - self.V_sh - 10) / 10)) diff --git a/brainpy/_src/dyn/channels/sodium_compatible.py b/brainpy/_src/dyn/channels/sodium_compatible.py index a5238d443..b05003e80 100644 --- a/brainpy/_src/dyn/channels/sodium_compatible.py +++ b/brainpy/_src/dyn/channels/sodium_compatible.py @@ -16,355 +16,355 @@ from .base import IonChannel __all__ = [ - 'INa_Ba2002', - 'INa_TM1991', - 'INa_HH1952', + 'INa_Ba2002', + 'INa_TM1991', + 'INa_HH1952', ] class _INa_p3q_markov(IonChannel): - r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain. - - The general model can be used to model the dynamics with: - - .. 
math:: - - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ - \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ - \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor. - - Parameters:: - - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - phi : float, ArrayType, Callable, Initializer - The temperature-dependent factor. - method: str - The numerical method - name: str - The name of the object. - - """ - master_type = HHTypedNeuron - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[int, float, ArrayType, Initializer, Callable] = None, - g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., - phi: Union[int, float, ArrayType, Initializer, Callable] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) - - # parameters - self.E = parameter(E, self.varshape, allow_none=False) - self.phi = parameter(phi, self.varshape, allow_none=False) - self.g_max = parameter(g_max, self.varshape, allow_none=False) - - # variables - self.p = variable(bm.zeros, self.mode, self.varshape) - self.q = variable(bm.zeros, self.mode, self.varshape) - - # function - self.integral = odeint(JointEq([self.dp, self.dq]), method=method) - - def reset_state(self, V, batch_size=None): - alpha = self.f_p_alpha(V) - beta = self.f_p_beta(V) - self.p.value = alpha / (alpha + beta) - alpha = self.f_q_alpha(V) - beta = self.f_q_beta(V) - self.q.value = alpha / (alpha + beta) - if isinstance(batch_size, int): - assert self.p.shape[0] == batch_size - assert self.q.shape[0] == batch_size - - def dp(self, p, t, V): - return self.phi * (self.f_p_alpha(V) * (1. 
- p) - self.f_p_beta(V) * p) - - def dq(self, q, t, V): - return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q) - - def update(self, V): - p, q = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) - self.p.value, self.q.value = p, q - - def current(self, V): - return self.g_max * self.p ** 3 * self.q * (self.E - V) - - def f_p_alpha(self, V): - raise NotImplementedError - - def f_p_beta(self, V): - raise NotImplementedError - - def f_q_alpha(self, V): - raise NotImplementedError - - def f_q_beta(self, V): - raise NotImplementedError + r"""The sodium current model of :math:`p^3q` current which described with first-order Markov chain. + + The general model can be used to model the dynamics with: + + .. math:: + + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ + \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ + \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor. + + Parameters:: + + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + phi : float, ArrayType, Callable, Initializer + The temperature-dependent factor. + method: str + The numerical method + name: str + The name of the object. 
+ + """ + master_type = HHTypedNeuron + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[int, float, ArrayType, Initializer, Callable] = None, + g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., + phi: Union[int, float, ArrayType, Initializer, Callable] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) + + # parameters + self.E = parameter(E, self.varshape, allow_none=False) + self.phi = parameter(phi, self.varshape, allow_none=False) + self.g_max = parameter(g_max, self.varshape, allow_none=False) + + # variables + self.p = variable(bm.zeros, self.mode, self.varshape) + self.q = variable(bm.zeros, self.mode, self.varshape) + + # function + self.integral = odeint(JointEq([self.dp, self.dq]), method=method) + + def reset_state(self, V, batch_size=None): + alpha = self.f_p_alpha(V) + beta = self.f_p_beta(V) + self.p.value = alpha / (alpha + beta) + alpha = self.f_q_alpha(V) + beta = self.f_q_beta(V) + self.q.value = alpha / (alpha + beta) + if isinstance(batch_size, int): + assert self.p.shape[0] == batch_size + assert self.q.shape[0] == batch_size + + def dp(self, p, t, V): + return self.phi * (self.f_p_alpha(V) * (1. - p) - self.f_p_beta(V) * p) + + def dq(self, q, t, V): + return self.phi * (self.f_q_alpha(V) * (1. - q) - self.f_q_beta(V) * q) + + def update(self, V): + p, q = self.integral(self.p.value, self.q.value, share['t'], V, share['dt']) + self.p.value, self.q.value = p, q + + def current(self, V): + return self.g_max * self.p ** 3 * self.q * (self.E - V) + + def f_p_alpha(self, V): + raise NotImplementedError + + def f_p_beta(self, V): + raise NotImplementedError + + def f_q_alpha(self, V): + raise NotImplementedError + + def f_q_beta(self, V): + raise NotImplementedError class INa_Ba2002(_INa_p3q_markov): - r"""The sodium current model. 
- - The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_. - It's dynamics is given by: - - .. math:: - - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ - \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ - \alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\ - \beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\ - \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ - \alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\ - \beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)} - \end{aligned} - - where :math:`\phi` is a temperature-dependent factor, which is given by - :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - - Parameters:: - - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - T : float, ArrayType - The temperature (Celsius, :math:`^{\circ}C`). - V_sh : float, ArrayType, Callable, Initializer - The shift of the membrane potential to spike. - - References:: - - .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations - and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. 
- - See Also:: - - INa_TM1991 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - T: Union[int, float, ArrayType] = 36., - E: Union[int, float, ArrayType, Initializer, Callable] = 50., - g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -50., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - phi=3 ** ((T - 36) / 10), - g_max=g_max, - E=E, - mode=mode) - self.T = parameter(T, self.varshape, allow_none=False) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = V - self.V_sh - 13. - return 0.32 * temp / (1. - bm.exp(-temp / 4.)) - - def f_p_beta(self, V): - temp = V - self.V_sh - 40. - return -0.28 * temp / (1. - bm.exp(temp / 5.)) - - def f_q_alpha(self, V): - return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.) - - def f_q_beta(self, V): - return 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.)) + r"""The sodium current model. + + The sodium current model is adopted from (Bazhenov, et, al. 2002) [1]_. + It's dynamics is given by: + + .. math:: + + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} * p^3 * q \\ + \frac{dp}{dt} &= \phi ( \alpha_p (1-p) - \beta_p p) \\ + \alpha_{p} &=\frac{0.32\left(V-V_{sh}-13\right)}{1-\exp \left(-\left(V-V_{sh}-13\right) / 4\right)} \\ + \beta_{p} &=\frac{-0.28\left(V-V_{sh}-40\right)}{1-\exp \left(\left(V-V_{sh}-40\right) / 5\right)} \\ + \frac{dq}{dt} & = \phi ( \alpha_q (1-h) - \beta_q h) \\ + \alpha_q &=0.128 \exp \left(-\left(V-V_{sh}-17\right) / 18\right) \\ + \beta_q &= \frac{4}{1+\exp \left(-\left(V-V_{sh}-40\right) / 5\right)} + \end{aligned} + + where :math:`\phi` is a temperature-dependent factor, which is given by + :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). 
+ + Parameters:: + + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + T : float, ArrayType + The temperature (Celsius, :math:`^{\circ}C`). + V_sh : float, ArrayType, Callable, Initializer + The shift of the membrane potential to spike. + + References:: + + .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations + and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. + + See Also:: + + INa_TM1991 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + T: Union[int, float, ArrayType] = 36., + E: Union[int, float, ArrayType, Initializer, Callable] = 50., + g_max: Union[int, float, ArrayType, Initializer, Callable] = 90., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -50., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + phi=3 ** ((T - 36) / 10), + g_max=g_max, + E=E, + mode=mode) + self.T = parameter(T, self.varshape, allow_none=False) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = V - self.V_sh - 13. + return 0.32 * temp / (1. - bm.exp(-temp / 4.)) + + def f_p_beta(self, V): + temp = V - self.V_sh - 40. + return -0.28 * temp / (1. - bm.exp(temp / 5.)) + + def f_q_alpha(self, V): + return 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.) + + def f_q_beta(self, V): + return 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.)) class INa_TM1991(_INa_p3q_markov): - r"""The sodium current model described by (Traub and Miles, 1991) [1]_. - - The dynamics of this sodium current model is given by: - - .. 
math:: - - \begin{split} - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ - \frac {dm} {dt} &= \phi(\alpha_m (1-x) - \beta_m) \\ - &\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\ - &\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\ - \frac {dh} {dt} &= \phi(\alpha_h (1-x) - \beta_h) \\ - &\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\ - &\beta_h(V) = 4. / (1 + \exp(-(V - V_{sh} - 40) / 5)) \\ - \end{aligned} - \end{split} - - where :math:`V_{sh}` is the membrane shift (default -63 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, tuple of int - The size of the simulation target. - keep_size: bool - Keep size or flatten the size? - method: str - The numerical method - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer - The membrane shift. - - References:: - - .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. - Vol. 777. Cambridge University Press, 1991. 
- - See Also:: - - INa_Ba2002 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[int, float, ArrayType, Initializer, Callable] = 50., - g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., - phi: Union[int, float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -63., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - E=E, - phi=phi, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = 13 - V + self.V_sh - return 0.32 * temp / (bm.exp(temp / 4) - 1.) - - def f_p_beta(self, V): - temp = V - self.V_sh - 40 - return 0.28 * temp / (bm.exp(temp / 5) - 1) - - def f_q_alpha(self, V): - return 0.128 * bm.exp((17 - V + self.V_sh) / 18) - - def f_q_beta(self, V): - return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5)) + r"""The sodium current model described by (Traub and Miles, 1991) [1]_. + + The dynamics of this sodium current model is given by: + + .. math:: + + \begin{split} + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ + \frac {dm} {dt} &= \phi(\alpha_m (1-x) - \beta_m) \\ + &\alpha_m(V) = 0.32 \frac{(13 - V + V_{sh})}{\exp((13 - V +V_{sh}) / 4) - 1.} \\ + &\beta_m(V) = 0.28 \frac{(V - V_{sh} - 40)}{(\exp((V - V_{sh} - 40) / 5) - 1)} \\ + \frac {dh} {dt} &= \phi(\alpha_h (1-x) - \beta_h) \\ + &\alpha_h(V) = 0.128 * \exp((17 - V + V_{sh}) / 18) \\ + &\beta_h(V) = 4. / (1 + \exp(-(V - V_{sh} - 40) / 5)) \\ + \end{aligned} + \end{split} + + where :math:`V_{sh}` is the membrane shift (default -63 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, tuple of int + The size of the simulation target. + keep_size: bool + Keep size or flatten the size? 
+ method: str + The numerical method + name: str + The name of the object. + g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + V_sh: float, ArrayType, Callable, Initializer + The membrane shift. + + References:: + + .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. + Vol. 777. Cambridge University Press, 1991. + + See Also:: + + INa_Ba2002 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[int, float, ArrayType, Initializer, Callable] = 50., + g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., + phi: Union[int, float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -63., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + E=E, + phi=phi, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = 13 - V + self.V_sh + return 0.32 * temp / (bm.exp(temp / 4) - 1.) + + def f_p_beta(self, V): + temp = V - self.V_sh - 40 + return 0.28 * temp / (bm.exp(temp / 5) - 1) + + def f_q_alpha(self, V): + return 0.128 * bm.exp((17 - V + self.V_sh) / 18) + + def f_q_beta(self, V): + return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5)) class INa_HH1952(_INa_p3q_markov): - r"""The sodium current model described by Hodgkin–Huxley model [1]_. - - The dynamics of this sodium current model is given by: - - .. 
math:: - - \begin{split} - \begin{aligned} - I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ - \frac {dm} {dt} &= \phi (\alpha_m (1-x) - \beta_m) \\ - &\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\ - &\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\ - \frac {dh} {dt} &= \phi (\alpha_h (1-x) - \beta_h) \\ - &\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\ - &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\ - \end{aligned} - \end{split} - - where :math:`V_{sh}` is the membrane shift (default -45 mV), and - :math:`\phi` is the temperature-dependent factor (default 1.). - - Parameters:: - - size: int, tuple of int - The size of the simulation target. - keep_size: bool - Keep size or flatten the size? - method: str - The numerical method - name: str - The name of the object. - g_max : float, ArrayType, Callable, Initializer - The maximal conductance density (:math:`mS/cm^2`). - E : float, ArrayType, Callable, Initializer - The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer - The membrane shift. - - References:: - - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of - membrane current and its application to conduction and excitation in - nerve." The Journal of physiology 117.4 (1952): 500. 
- - See Also:: - - IK_HH1952 - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - E: Union[int, float, ArrayType, Initializer, Callable] = 50., - g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., - phi: Union[int, float, ArrayType, Initializer, Callable] = 1., - V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(size, - keep_size=keep_size, - name=name, - method=method, - E=E, - phi=phi, - g_max=g_max, - mode=mode) - self.V_sh = parameter(V_sh, self.varshape, allow_none=False) - - def f_p_alpha(self, V): - temp = V - self.V_sh - 5 - return 0.1 * temp / (1 - bm.exp(-temp / 10)) - - def f_p_beta(self, V): - return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18) - - def f_q_alpha(self, V): - return 0.07 * bm.exp(-(V - self.V_sh + 20) / 20.) - - def f_q_beta(self, V): - return 1 / (1 + bm.exp(-(V - self.V_sh - 10) / 10)) + r"""The sodium current model described by Hodgkin–Huxley model [1]_. + + The dynamics of this sodium current model is given by: + + .. math:: + + \begin{split} + \begin{aligned} + I_{\mathrm{Na}} &= g_{\mathrm{max}} m^3 h \\ + \frac {dm} {dt} &= \phi (\alpha_m (1-x) - \beta_m) \\ + &\alpha_m(V) = \frac {0.1(V-V_{sh}-5)}{1-\exp(\frac{-(V -V_{sh} -5)} {10})} \\ + &\beta_m(V) = 4.0 \exp(\frac{-(V -V_{sh}+ 20)} {18}) \\ + \frac {dh} {dt} &= \phi (\alpha_h (1-x) - \beta_h) \\ + &\alpha_h(V) = 0.07 \exp(\frac{-(V-V_{sh}+20)}{20}) \\ + &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V -V_{sh}-10)} {10})} \\ + \end{aligned} + \end{split} + + where :math:`V_{sh}` is the membrane shift (default -45 mV), and + :math:`\phi` is the temperature-dependent factor (default 1.). + + Parameters:: + + size: int, tuple of int + The size of the simulation target. + keep_size: bool + Keep size or flatten the size? + method: str + The numerical method + name: str + The name of the object. 
+ g_max : float, ArrayType, Callable, Initializer + The maximal conductance density (:math:`mS/cm^2`). + E : float, ArrayType, Callable, Initializer + The reversal potential (mV). + V_sh: float, ArrayType, Callable, Initializer + The membrane shift. + + References:: + + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of + membrane current and its application to conduction and excitation in + nerve." The Journal of physiology 117.4 (1952): 500. + + See Also:: + + IK_HH1952 + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + E: Union[int, float, ArrayType, Initializer, Callable] = 50., + g_max: Union[int, float, ArrayType, Initializer, Callable] = 120., + phi: Union[int, float, ArrayType, Initializer, Callable] = 1., + V_sh: Union[int, float, ArrayType, Initializer, Callable] = -45., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(size, + keep_size=keep_size, + name=name, + method=method, + E=E, + phi=phi, + g_max=g_max, + mode=mode) + self.V_sh = parameter(V_sh, self.varshape, allow_none=False) + + def f_p_alpha(self, V): + temp = V - self.V_sh - 5 + return 0.1 * temp / (1 - bm.exp(-temp / 10)) + + def f_p_beta(self, V): + return 4.0 * bm.exp(-(V - self.V_sh + 20) / 18) + + def f_q_alpha(self, V): + return 0.07 * bm.exp(-(V - self.V_sh + 20) / 20.) 
+ + def f_q_beta(self, V): + return 1 / (1 + bm.exp(-(V - self.V_sh - 10) / 10)) diff --git a/brainpy/_src/dyn/channels/tests/test_Ca.py b/brainpy/_src/dyn/channels/tests/test_Ca.py index 0b7593f7b..3383b5966 100644 --- a/brainpy/_src/dyn/channels/tests/test_Ca.py +++ b/brainpy/_src/dyn/channels/tests/test_Ca.py @@ -1,144 +1,145 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized class Test_Ca(parameterized.TestCase): - def test_Ca(self): - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.Ca1 = bp.dyn.CalciumFixed(size) - self.Ca2 = bp.dyn.CalciumDetailed(size) - self.Ca3 = bp.dyn.CalciumFirstOrder(size) - - bm.random.seed(1234) - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', 'Ca2.C', 'Ca3.C'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca2.C'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca3.C'].shape, (100, 1)) - - def test_ICaN_IS2008(self): - bm.random.seed(1234) - - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.Ca = bp.dyn.CalciumDetailed(size, - ICa=bp.dyn.ICaN_IS2008(size), - ) - - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', 'Ca.ICa.p'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) - - def test_ICaT_HM1992(self): - bm.random.seed(1234) - - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.Ca = bp.dyn.CalciumDetailed(size, - ICa=bp.dyn.ICaT_HM1992(size), - ) - - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', - 'Ca.ICa.p', - ], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) - - def test_ICaT_HP1992(self): - bm.random.seed(1234) - - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.Ca = bp.dyn.CalciumDetailed(size, - ICa=bp.dyn.ICaT_HP1992(size), - ) - - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', - 'Ca.ICa.p', - ], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) - - def test_ICaHT_HM1992(self): - bm.random.seed(1234) - - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.Ca = bp.dyn.CalciumDetailed(size, - ICa=bp.dyn.ICaHT_HM1992(size), - ) - - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', - 'Ca.ICa.p', - ], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) - - def test_ICaHT_Re1993(self): - bm.random.seed(1234) - - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.Ca = bp.dyn.CalciumDetailed(size, - ICa=bp.dyn.ICaHT_Re1993(size), - ) - - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', - 'Ca.ICa.p', - ], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) - - def test_ICaL_IS2008(self): - bm.random.seed(1234) - - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.Ca = bp.dyn.CalciumDetailed(size, - ICa=bp.dyn.ICaL_IS2008(size), - ) - - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', - 'Ca.ICa.p', - ], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) + def test_Ca(self): + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.Ca1 = bp.dyn.CalciumFixed(size) + self.Ca2 = bp.dyn.CalciumDetailed(size) + self.Ca3 = bp.dyn.CalciumFirstOrder(size) + + bm.random.seed(1234) + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', 'Ca2.C', 'Ca3.C'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca2.C'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca3.C'].shape, (100, 1)) + + def test_ICaN_IS2008(self): + bm.random.seed(1234) + + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size, + ICa=bp.dyn.ICaN_IS2008(size), + ) + + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', 'Ca.ICa.p'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) + + def test_ICaT_HM1992(self): + bm.random.seed(1234) + + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size, + ICa=bp.dyn.ICaT_HM1992(size), + ) + + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', + 'Ca.ICa.p', + ], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) + + def test_ICaT_HP1992(self): + bm.random.seed(1234) + + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size, + ICa=bp.dyn.ICaT_HP1992(size), + ) + + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', + 'Ca.ICa.p', + ], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) + + def test_ICaHT_HM1992(self): + bm.random.seed(1234) + + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size, + ICa=bp.dyn.ICaHT_HM1992(size), + ) + + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', + 'Ca.ICa.p', + ], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) + + def test_ICaHT_Re1993(self): + bm.random.seed(1234) + + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size, + ICa=bp.dyn.ICaHT_Re1993(size), + ) + + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', + 'Ca.ICa.p', + ], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) + + def test_ICaL_IS2008(self): + bm.random.seed(1234) + + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.Ca = bp.dyn.CalciumDetailed(size, + ICa=bp.dyn.ICaL_IS2008(size), + ) + + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', + 'Ca.ICa.p', + ], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.ICa.p'].shape, (100, 1)) diff --git a/brainpy/_src/dyn/channels/tests/test_IH.py b/brainpy/_src/dyn/channels/tests/test_IH.py index 5860a9cdd..101a678c3 100644 --- a/brainpy/_src/dyn/channels/tests/test_IH.py +++ b/brainpy/_src/dyn/channels/tests/test_IH.py @@ -1,25 +1,27 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized class Test_IH(parameterized.TestCase): - bm.random.seed(1234) - def test_IH(self): - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size) - self.IH = bp.dyn.Ih_HM1992(size) - self.Ca = bp.dyn.CalciumDetailed(size, IH=bp.dyn.Ih_De1996(size)) + bm.random.seed(1234) + + def test_IH(self): + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size) + self.IH = bp.dyn.Ih_HM1992(size) + self.Ca = bp.dyn.CalciumDetailed(size, IH=bp.dyn.Ih_De1996(size)) - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', 'IH.p', 'Ca.IH.O'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IH.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.IH.O'].shape, (100, 1)) \ No newline at end of file + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', 'IH.p', 'Ca.IH.O'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IH.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.IH.O'].shape, (100, 1)) diff --git a/brainpy/_src/dyn/channels/tests/test_K.py b/brainpy/_src/dyn/channels/tests/test_K.py index 2bdd63bde..41207c76f 100644 --- a/brainpy/_src/dyn/channels/tests/test_K.py +++ b/brainpy/_src/dyn/channels/tests/test_K.py @@ -1,36 +1,40 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized + class Test_K(parameterized.TestCase): - bm.random.seed(1234) - def test_K(self): - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) - self.IK_1 = bp.dyn.IKDR_Ba2002(size) - self.IK_2 = bp.dyn.IK_TM1991(size) - self.IK_3 = bp.dyn.IK_HH1952(size) - self.IK_4 = bp.dyn.IKA1_HM1992(size) - self.IK_5 = bp.dyn.IKA2_HM1992(size) - self.IK_6 = bp.dyn.IKK2A_HM1992(size) - self.IK_7 = bp.dyn.IKK2B_HM1992(size) - self.IK_8 = bp.dyn.IKNI_Ya1989(size) + bm.random.seed(1234) + + def test_K(self): + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) + self.IK_1 = bp.dyn.IKDR_Ba2002(size) + self.IK_2 = bp.dyn.IK_TM1991(size) + self.IK_3 = bp.dyn.IK_HH1952(size) + self.IK_4 = bp.dyn.IKA1_HM1992(size) + self.IK_5 = bp.dyn.IKA2_HM1992(size) + self.IK_6 = bp.dyn.IKK2A_HM1992(size) + self.IK_7 = bp.dyn.IKK2B_HM1992(size) + self.IK_8 = bp.dyn.IKNI_Ya1989(size) - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', 'IK_1.p', 'IK_2.p', 'IK_3.p', 'IK_4.p', 'IK_5.p', 'IK_6.p', 'IK_7.p', 'IK_8.p'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_1.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_2.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_3.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_4.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_5.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_6.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_7.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['IK_8.p'].shape, (100, 1)) \ No newline at end of file + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', 'IK_1.p', 'IK_2.p', 'IK_3.p', 'IK_4.p', 'IK_5.p', 'IK_6.p', 'IK_7.p', + 'IK_8.p'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_1.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_2.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_3.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_4.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_5.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_6.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_7.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['IK_8.p'].shape, (100, 1)) diff --git a/brainpy/_src/dyn/channels/tests/test_KCa.py b/brainpy/_src/dyn/channels/tests/test_KCa.py index ad52c0871..957139815 100644 --- a/brainpy/_src/dyn/channels/tests/test_KCa.py +++ b/brainpy/_src/dyn/channels/tests/test_KCa.py @@ -1,24 +1,25 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized class Test_KCa(parameterized.TestCase): - bm.random.seed(1234) + bm.random.seed(1234) - def test_KCa(self): - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) - self.Ca = 
bp.dyn.CalciumDetailed(size, KCa=bp.dyn.IAHP_De1994(size)) + def test_KCa(self): + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) + self.Ca = bp.dyn.CalciumDetailed(size, KCa=bp.dyn.IAHP_De1994(size)) - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', 'Ca.KCa.p'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Ca.KCa.p'].shape, (100, 1)) + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', 'Ca.KCa.p'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Ca.KCa.p'].shape, (100, 1)) diff --git a/brainpy/_src/dyn/channels/tests/test_Na.py b/brainpy/_src/dyn/channels/tests/test_Na.py index 58002e3f0..04f737823 100644 --- a/brainpy/_src/dyn/channels/tests/test_Na.py +++ b/brainpy/_src/dyn/channels/tests/test_Na.py @@ -1,31 +1,32 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized class Test_Na(parameterized.TestCase): - bm.random.seed(1234) + bm.random.seed(1234) - def test_Na(self): - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) - self.INa_1 = bp.dyn.INa_HH1952(size, E=50., g_max=120.) - self.INa_2 = bp.dyn.INa_TM1991(size) - self.INa_3 = bp.dyn.INa_Ba2002(size) + def test_Na(self): + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) + self.INa_1 = bp.dyn.INa_HH1952(size, E=50., g_max=120.) 
+ self.INa_2 = bp.dyn.INa_TM1991(size) + self.INa_3 = bp.dyn.INa_Ba2002(size) - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V', 'INa_1.p', 'INa_1.q', 'INa_2.p', 'INa_2.q', 'INa_3.p', 'INa_3.q'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['INa_1.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['INa_1.q'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['INa_2.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['INa_2.q'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['INa_3.p'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['INa_3.q'].shape, (100, 1)) + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V', 'INa_1.p', 'INa_1.q', 'INa_2.p', 'INa_2.q', 'INa_3.p', 'INa_3.q'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['INa_1.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['INa_1.q'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['INa_2.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['INa_2.q'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['INa_3.p'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['INa_3.q'].shape, (100, 1)) diff --git a/brainpy/_src/dyn/channels/tests/test_leaky.py b/brainpy/_src/dyn/channels/tests/test_leaky.py index 9535cefde..d82957a55 100644 --- a/brainpy/_src/dyn/channels/tests/test_leaky.py +++ b/brainpy/_src/dyn/channels/tests/test_leaky.py @@ -1,24 +1,25 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized class Test_Leaky(parameterized.TestCase): - bm.random.seed(1234) + bm.random.seed(1234) - def test_leaky(self): - class Neuron(bp.dyn.CondNeuGroup): - def __init__(self, size): - super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) - self.leaky1 = bp.dyn.IL(size) 
- self.leaky2 = bp.dyn.IKL(size) + def test_leaky(self): + class Neuron(bp.dyn.CondNeuGroup): + def __init__(self, size): + super(Neuron, self).__init__(size, V_initializer=bp.init.Uniform(-70, -50.)) + self.leaky1 = bp.dyn.IL(size) + self.leaky2 = bp.dyn.IKL(size) - model = Neuron(1) - runner = bp.DSRunner(model, - monitors=['V'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + model = Neuron(1) + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) diff --git a/brainpy/_src/dyn/ions/__init__.py b/brainpy/_src/dyn/ions/__init__.py index ee840a720..f71653a19 100644 --- a/brainpy/_src/dyn/ions/__init__.py +++ b/brainpy/_src/dyn/ions/__init__.py @@ -1,4 +1,3 @@ - from .base import * from .calcium import * from .potassium import * diff --git a/brainpy/_src/dyn/ions/base.py b/brainpy/_src/dyn/ions/base.py index 145c1ded0..eeaf581c3 100644 --- a/brainpy/_src/dyn/ions/base.py +++ b/brainpy/_src/dyn/ions/base.py @@ -2,215 +2,217 @@ from typing import Union, Optional, Dict, Sequence, Callable +from brainstate.mixin import _JointGenericAlias + import brainpy.math as bm from brainpy._src.dyn.base import IonChaDyn from brainpy._src.dyn.neurons.hh import HHTypedNeuron -from brainpy._src.mixin import Container, TreeNode, _JointGenericAlias +from brainpy._src.mixin import Container, TreeNode from brainpy.types import Shape __all__ = [ - 'MixIons', - 'mix_ions', - 'Ion', + 'MixIons', + 'mix_ions', + 'Ion', ] class MixIons(IonChaDyn, Container, TreeNode): - """Mixing Ions. - - Args: - ions: Instances of ions. This option defines the master types of all children objects. - channels: Instance of channels. - """ - master_type = HHTypedNeuron - - def __init__(self, *ions, **channels): - # TODO: check "ions" should be independent from each other - assert isinstance(ions, (tuple, list)), f'{self.__class__.__name__} requires at least two ions. 
' - assert len(ions) >= 2, f'{self.__class__.__name__} requires at least two ions. ' - assert all([isinstance(cls, Ion) for cls in ions]), f'Must be a sequence of Ion. But got {ions}.' - super().__init__(size=ions[0].size, keep_size=ions[0].keep_size, sharding=ions[0].sharding) - - # Attribute of "Container" - self.children = bm.node_dict() - - self.ions: Sequence['Ion'] = tuple(ions) - self._ion_classes = tuple([type(ion) for ion in self.ions]) - for k, v in channels.items(): - self.add_elem(k=v) - - def update(self, V): - nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) - self.check_hierarchies(self._ion_classes, *nodes) - for node in nodes: - infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__]) - node.update(V, *infos) - - def current(self, V): - """Generate ion channel current. - - Args: - V: The membrane potential. - - Returns: - Current. - """ - nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) - self.check_hierarchies(self._ion_classes, *nodes) - - if len(nodes) == 0: - return 0. - else: - current = 0. - for node in nodes: - infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__]) - current = current + node.current(V, *infos) - return current - - def reset_state(self, V, batch_size=None): - nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) - self.check_hierarchies(self._ion_classes, *nodes) - for node in nodes: - infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__]) - node.reset_state(V, *infos, batch_size) - - def check_hierarchy(self, roots, leaf): - # 'master_type' should be a brainpy.mixin.JointType - self._check_master_type(leaf) - for cls in leaf.master_type.__args__: - if not any([issubclass(root, cls) for root in roots]): - raise TypeError(f'Type does not match. 
{leaf} requires a master with type ' - f'of {leaf.master_type}, but the master type now is {roots}.') - - def add_elem(self, *elems, **elements): - """Add new elements. + """Mixing Ions. Args: - elements: children objects. + ions: Instances of ions. This option defines the master types of all children objects. + channels: Instance of channels. """ - self.check_hierarchies(self._ion_classes, *elems, **elements) - self.children.update(self.format_elements(IonChaDyn, *elems, **elements)) - for elem in tuple(elems) + tuple(elements.values()): - for ion_root in elem.master_type.__args__: - ion = self._get_imp(ion_root) - ion.add_external_current(elem.name, self._get_ion_fun(ion, elem)) - - def _get_ion_fun(self, ion, node): - def fun(V, *args): - infos = tuple([(ion.pack_info(*args) - if isinstance(ion, root) else - self._get_imp(root).pack_info()) - for root in node.master_type.__args__]) - return node.current(V, *infos) - return fun - - def _get_imp(self, cls): - for ion in self.ions: - if isinstance(ion, cls): - return ion - else: - raise ValueError(f'No instance of {cls} is found.') - - def _check_master_type(self, leaf): - if not isinstance(leaf.master_type, _JointGenericAlias): - raise TypeError(f'{self.__class__.__name__} requires leaf nodes that have the master_type of ' - f'"brainpy.mixin.JointType". However, we got {leaf.master_type}') + master_type = HHTypedNeuron + + def __init__(self, *ions, **channels): + # TODO: check "ions" should be independent from each other + assert isinstance(ions, (tuple, list)), f'{self.__class__.__name__} requires at least two ions. ' + assert len(ions) >= 2, f'{self.__class__.__name__} requires at least two ions. ' + assert all([isinstance(cls, Ion) for cls in ions]), f'Must be a sequence of Ion. But got {ions}.' 
+ super().__init__(size=ions[0].size, keep_size=ions[0].keep_size, sharding=ions[0].sharding) + + # Attribute of "Container" + self.children = bm.node_dict() + + self.ions: Sequence['Ion'] = tuple(ions) + self._ion_classes = tuple([type(ion) for ion in self.ions]) + for k, v in channels.items(): + self.add_elem(k=v) + + def update(self, V): + nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) + self.check_hierarchies(self._ion_classes, *nodes) + for node in nodes: + infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__]) + node.update(V, *infos) + + def current(self, V): + """Generate ion channel current. + + Args: + V: The membrane potential. + + Returns: + Current. + """ + nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) + self.check_hierarchies(self._ion_classes, *nodes) + + if len(nodes) == 0: + return 0. + else: + current = 0. + for node in nodes: + infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__]) + current = current + node.current(V, *infos) + return current + + def reset_state(self, V, batch_size=None): + nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) + self.check_hierarchies(self._ion_classes, *nodes) + for node in nodes: + infos = tuple([self._get_imp(root).pack_info() for root in node.master_type.__args__]) + node.reset_state(V, *infos, batch_size) + + def check_hierarchy(self, roots, leaf): + # 'master_type' should be a brainpy.mixin.JointType + self._check_master_type(leaf) + for cls in leaf.master_type.__args__: + if not any([issubclass(root, cls) for root in roots]): + raise TypeError(f'Type does not match. {leaf} requires a master with type ' + f'of {leaf.master_type}, but the master type now is {roots}.') + + def add_elem(self, *elems, **elements): + """Add new elements. + + Args: + elements: children objects. 
+ """ + self.check_hierarchies(self._ion_classes, *elems, **elements) + self.children.update(self.format_elements(IonChaDyn, *elems, **elements)) + for elem in tuple(elems) + tuple(elements.values()): + for ion_root in elem.master_type.__args__: + ion = self._get_imp(ion_root) + ion.add_external_current(elem.name, self._get_ion_fun(ion, elem)) + + def _get_ion_fun(self, ion, node): + def fun(V, *args): + infos = tuple([(ion.pack_info(*args) + if isinstance(ion, root) else + self._get_imp(root).pack_info()) + for root in node.master_type.__args__]) + return node.current(V, *infos) + + return fun + + def _get_imp(self, cls): + for ion in self.ions: + if isinstance(ion, cls): + return ion + else: + raise ValueError(f'No instance of {cls} is found.') + + def _check_master_type(self, leaf): + if not isinstance(leaf.master_type, _JointGenericAlias): + raise TypeError(f'{self.__class__.__name__} requires leaf nodes that have the master_type of ' + f'"brainpy.mixin.JointType". However, we got {leaf.master_type}') def mix_ions(*ions) -> MixIons: - """Create mixed ions. + """Create mixed ions. - Args: - ions: Ion instances. + Args: + ions: Ion instances. - Returns: - Instance of MixIons. - """ - for ion in ions: - assert isinstance(ion, Ion), f'Must be instance of {Ion.__name__}. But got {type(ion)}' - assert len(ions) > 0, '' - return MixIons(*ions) + Returns: + Instance of MixIons. + """ + for ion in ions: + assert isinstance(ion, Ion), f'Must be instance of {Ion.__name__}. But got {type(ion)}' + assert len(ions) > 0, '' + return MixIons(*ions) class Ion(IonChaDyn, Container, TreeNode): - """The brainpy_object calcium dynamics. - - Args: - size: The size of the simulation target. - method: The numerical integration method. - name: The name of the object. - channels: The calcium dependent channels. 
- """ - - '''The type of the master object.''' - master_type = HHTypedNeuron - - """Reversal potential.""" - E: Union[float, bm.Variable, bm.Array] - - """Calcium concentration.""" - C: Union[float, bm.Variable, bm.Array] - - def __init__( - self, - size: Shape, - keep_size: bool = False, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - **channels - ): - super().__init__(size, keep_size=keep_size, mode=mode, method=method, name=name) - - # Attribute of "Container" - self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels)) - self.external: Dict[str, Callable] = dict() # not found by `.nodes()` or `.vars()` - - def update(self, V): - for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): - node.update(V, self.C, self.E) - - def current(self, V, C=None, E=None, external: bool = False): - """Generate ion channel current. + """The brainpy_object calcium dynamics. Args: - V: The membrane potential. - C: The given ion concentration. - E: The given reversal potential. - external: Include the external current. - - Returns: - Current. + size: The size of the simulation target. + method: The numerical integration method. + name: The name of the object. + channels: The calcium dependent channels. """ - C = self.C if (C is None) else C - E = self.E if (E is None) else E - nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) - self.check_hierarchies(type(self), *nodes) - - current = 0. 
- if len(nodes) > 0: - for node in nodes: - current = current + node.current(V, C, E) - if external: - for key, node in self.external.items(): - current = current + node(V, C, E) - return current - - def reset_state(self, V, batch_size=None): - nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) - self.check_hierarchies(type(self), *nodes) - for node in nodes: - node.reset_state(V, self.C, self.E, batch_size) - - def pack_info(self, C=None, E=None) -> Dict: - if C is None: - C = self.C - if E is None: - E = self.E - return dict(C=C, E=E) - - def add_external_current(self, key: str, fun: Callable): - if key in self.external: - raise ValueError - self.external[key] = fun + '''The type of the master object.''' + master_type = HHTypedNeuron + + """Reversal potential.""" + E: Union[float, bm.Variable, bm.Array] + + """Calcium concentration.""" + C: Union[float, bm.Variable, bm.Array] + + def __init__( + self, + size: Shape, + keep_size: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + **channels + ): + super().__init__(size, keep_size=keep_size, mode=mode, method=method, name=name) + + # Attribute of "Container" + self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels)) + self.external: Dict[str, Callable] = dict() # not found by `.nodes()` or `.vars()` + + def update(self, V): + for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): + node.update(V, self.C, self.E) + + def current(self, V, C=None, E=None, external: bool = False): + """Generate ion channel current. + + Args: + V: The membrane potential. + C: The given ion concentration. + E: The given reversal potential. + external: Include the external current. + + Returns: + Current. 
+ """ + C = self.C if (C is None) else C + E = self.E if (E is None) else E + nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) + self.check_hierarchies(type(self), *nodes) + + current = 0. + if len(nodes) > 0: + for node in nodes: + current = current + node.current(V, C, E) + if external: + for key, node in self.external.items(): + current = current + node(V, C, E) + return current + + def reset_state(self, V, batch_size=None): + nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) + self.check_hierarchies(type(self), *nodes) + for node in nodes: + node.reset_state(V, self.C, self.E, batch_size) + + def pack_info(self, C=None, E=None) -> Dict: + if C is None: + C = self.C + if E is None: + E = self.E + return dict(C=C, E=E) + + def add_external_current(self, key: str, fun: Callable): + if key in self.external: + raise ValueError + self.external[key] = fun diff --git a/brainpy/_src/dyn/ions/calcium.py b/brainpy/_src/dyn/ions/calcium.py index 7c2b88954..3b2c285b5 100644 --- a/brainpy/_src/dyn/ions/calcium.py +++ b/brainpy/_src/dyn/ions/calcium.py @@ -11,311 +11,311 @@ from .base import Ion __all__ = [ - 'Calcium', - 'CalciumFixed', - 'CalciumDetailed', - 'CalciumFirstOrder', + 'Calcium', + 'CalciumFixed', + 'CalciumDetailed', + 'CalciumFirstOrder', ] class Calcium(Ion): - """Base class for modeling Calcium ion.""" - pass + """Base class for modeling Calcium ion.""" + pass class CalciumFixed(Calcium): - """Fixed Calcium dynamics. - - This calcium model has no dynamics. It holds fixed reversal - potential :math:`E` and concentration :math:`C`. 
- """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = 120., - C: Union[float, ArrayType, Initializer, Callable] = 2.4e-4, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - **channels - ): - super().__init__(size, - keep_size=keep_size, - method=method, - name=name, - mode=mode, - **channels) - self.E = parameter(E, self.varshape, allow_none=False) - self.C = parameter(C, self.varshape, allow_none=False) - - def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): - C_Ca = self.C if C_Ca is None else C_Ca - E_Ca = self.E if E_Ca is None else E_Ca - for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): - node.reset_state(V, C_Ca, E_Ca, batch_size=batch_size) + """Fixed Calcium dynamics. + + This calcium model has no dynamics. It holds fixed reversal + potential :math:`E` and concentration :math:`C`. + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = 120., + C: Union[float, ArrayType, Initializer, Callable] = 2.4e-4, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + **channels + ): + super().__init__(size, + keep_size=keep_size, + method=method, + name=name, + mode=mode, + **channels) + self.E = parameter(E, self.varshape, allow_none=False) + self.C = parameter(C, self.varshape, allow_none=False) + + def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): + C_Ca = self.C if C_Ca is None else C_Ca + E_Ca = self.E if E_Ca is None else E_Ca + for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): + node.reset_state(V, C_Ca, E_Ca, batch_size=batch_size) class CalciumDyna(Calcium): - """Calcium ion flow with dynamics. - - Parameters:: - - size: int, tuple of int - The ion size. - keep_size: bool - Keep the geometry size. 
- C0: float, ArrayType, Initializer, Callable - The Calcium concentration outside of membrane. - T: float, ArrayType, Initializer, Callable - The temperature. - C_initializer: Initializer, Callable, ArrayType - The initializer for Calcium concentration. - method: str - The numerical method. - name: str - The ion name. - """ - R = 8.31441 # gas constant, J*mol-1*K-1 - F = 96.489 # the Faraday constant - - def __init__( - self, - size: Shape, - keep_size: bool = False, - C0: Union[float, ArrayType, Initializer, Callable] = 2., - T: Union[float, ArrayType, Initializer, Callable] = 36., - C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4), - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - **channels - ): - super().__init__(size, - keep_size=keep_size, - method=method, - name=name, - mode=mode, - **channels) - - # parameters - self.C0 = parameter(C0, self.varshape, allow_none=False) - self.T = parameter(T, self.varshape, allow_none=False) # temperature - self._C_initializer = C_initializer - self._constant = self.R / (2 * self.F) * (273.15 + self.T) - - # variables - self.C = variable(C_initializer, self.mode, self.varshape) # Calcium concentration - self.E = bm.Variable(self._reversal_potential(self.C), - batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) # Reversal potential - - # function - self.integral = odeint(self.derivative, method=method) - - def derivative(self, C, t, V): - raise NotImplementedError - - def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): - self.C.value = variable(self._C_initializer, batch_size, self.varshape) if (C_Ca is None) else C_Ca - self.E.value = self._reversal_potential(self.C) - for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): - node.reset(V, self.C, self.E, batch_size=batch_size) - - def update(self, V): - for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): - 
node.update(V, self.C.value, self.E.value) - self.C.value = self.integral(self.C.value, share['t'], V, share['dt']) - self.E.value = self._reversal_potential(self.C.value) - - def _reversal_potential(self, C): - return self._constant * bm.log(self.C0 / C) + """Calcium ion flow with dynamics. + + Parameters:: + + size: int, tuple of int + The ion size. + keep_size: bool + Keep the geometry size. + C0: float, ArrayType, Initializer, Callable + The Calcium concentration outside of membrane. + T: float, ArrayType, Initializer, Callable + The temperature. + C_initializer: Initializer, Callable, ArrayType + The initializer for Calcium concentration. + method: str + The numerical method. + name: str + The ion name. + """ + R = 8.31441 # gas constant, J*mol-1*K-1 + F = 96.489 # the Faraday constant + + def __init__( + self, + size: Shape, + keep_size: bool = False, + C0: Union[float, ArrayType, Initializer, Callable] = 2., + T: Union[float, ArrayType, Initializer, Callable] = 36., + C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + **channels + ): + super().__init__(size, + keep_size=keep_size, + method=method, + name=name, + mode=mode, + **channels) + + # parameters + self.C0 = parameter(C0, self.varshape, allow_none=False) + self.T = parameter(T, self.varshape, allow_none=False) # temperature + self._C_initializer = C_initializer + self._constant = self.R / (2 * self.F) * (273.15 + self.T) + + # variables + self.C = variable(C_initializer, self.mode, self.varshape) # Calcium concentration + self.E = bm.Variable(self._reversal_potential(self.C), + batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) # Reversal potential + + # function + self.integral = odeint(self.derivative, method=method) + + def derivative(self, C, t, V): + raise NotImplementedError + + def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): + self.C.value = 
variable(self._C_initializer, batch_size, self.varshape) if (C_Ca is None) else C_Ca + self.E.value = self._reversal_potential(self.C) + for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): + node.reset(V, self.C, self.E, batch_size=batch_size) + + def update(self, V): + for node in self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values(): + node.update(V, self.C.value, self.E.value) + self.C.value = self.integral(self.C.value, share['t'], V, share['dt']) + self.E.value = self._reversal_potential(self.C.value) + + def _reversal_potential(self, C): + return self._constant * bm.log(self.C0 / C) class CalciumDetailed(CalciumDyna): - r"""Dynamical Calcium model proposed. + r"""Dynamical Calcium model proposed. - **1. The dynamics of intracellular** :math:`Ca^{2+}` + **1. The dynamics of intracellular** :math:`Ca^{2+}` - The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ : + The dynamics of intracellular :math:`Ca^{2+}` were determined by two contributions [1]_ : - *(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents* + *(i) Influx of* :math:`Ca^{2+}` *due to Calcium currents* - :math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the - interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath - the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed: + :math:`Ca^{2+}` ions enter through :math:`Ca^{2+}` channels and diffuse into the + interior of the cell. Only the :math:`Ca^{2+}` concentration in a thin shell beneath + the membrane was modeled. The influx of :math:`Ca^{2+}` into such a thin shell followed: - .. math:: + .. 
math:: - [Ca]_{i}=-\frac{k}{2 F d} I_{Ca} + [Ca]_{i}=-\frac{k}{2 F d} I_{Ca} - where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant, - :math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane, - the unit conversion constant is :math:`k=0.1` for :math:`I_T` in - :math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar, - and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents. + where :math:`F=96489\, \mathrm{C\, mol^{-1}}` is the Faraday constant, + :math:`d=1\, \mathrm{\mu m}` is the depth of the shell beneath the membrane, + the unit conversion constant is :math:`k=0.1` for :math:`I_T` in + :math:`\mathrm{\mu A/cm^{2}}` and :math:`[Ca]_{i}` in millimolar, + and :math:`I_{Ca}` is the summation of all :math:`Ca^{2+}` currents. - *(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump* + *(ii) Efflux of* :math:`Ca^{2+}` *due to an active pump* - In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a - combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium - efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells. - Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme: - - .. math:: - - Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+} - - where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state, - :math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration, - and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}` - ions have a high affinity for the pump :math:`P`, whereas extrusion of - :math:`Ca^{2+}` follows a slower process (Blaustein, 1988 ). Therefore, - :math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the - Michaelis-Menten approximation can be used for describing the kinetics of the pump. 
- According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is: - - .. math:: - - \frac{[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}} - - where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}` - with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}` - is the dissociation constant, which can be interpreted here as the value of - :math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}` - then the efflux is negligible). - - **2.A simple first-order model** - - While, in (Bazhenov, et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is - described by a simple first-order model, - - .. math:: - - \frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}} - - where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d` - is the thickness of the perimembrane "shell" in which calcium is able to affect - membrane properties :math:`(1.\, \mathrm{\mu M})`, :math:`z=2` is the valence of the - :math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is - the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was - set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` . - - **3. The reversal potential** - - The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the - Nernst equation: - - .. math:: - - E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i} - - where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`, - :math:`T=309.15^{\circ} \mathrm{K}`, - :math:`F=96,489 \mathrm{C} / \mathrm{mol}`, - and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`. - - Parameters:: - - d : float - The thickness of the peri-membrane "shell". - F : float - The Faraday constant. 
(:math:`C*mmol^{-1}`) - tau : float - The time constant of the :math:`Ca ^{2+}` removal rate. (ms) - C_rest : float - The resting :math:`Ca ^{2+}` concentration. - C0 : float - The :math:`Ca ^{2+}` concentration outside of the membrane. - R : float - The gas constant. (:math:` J*mol^{-1}*K^{-1}`) - - References:: - - .. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski. - "Ionic mechanisms for intrinsic slow oscillations in thalamic - relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552. - .. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J. - Sejnowski. "Cellular and network models for intrathalamic augmenting - responses during 10-Hz stimulation." Journal of neurophysiology 79, - no. 5 (1998): 2730-2748. - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[float, ArrayType, Initializer, Callable] = 36., - d: Union[float, ArrayType, Initializer, Callable] = 1., - C_rest: Union[float, ArrayType, Initializer, Callable] = 2.4e-4, - tau: Union[float, ArrayType, Initializer, Callable] = 5., - C0: Union[float, ArrayType, Initializer, Callable] = 2., - C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4), - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - **channels - ): - super().__init__(size, - keep_size=keep_size, - method=method, - name=name, - T=T, - C0=C0, - C_initializer=C_initializer, - mode=mode, - **channels) - - # parameters - self.d = parameter(d, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.C_rest = parameter(C_rest, self.varshape, allow_none=False) - - def derivative(self, C, t, V): - ICa = self.current(V, C, self.E, external=True) - drive = bm.maximum(- ICa / (2 * self.F * self.d), 0.) 
- return drive + (self.C_rest - C) / self.tau + In a thin shell beneath the membrane, :math:`Ca^{2+}` retrieval usually consists of a + combination of several processes, such as binding to :math:`Ca^{2+}` buffers, calcium + efflux due to :math:`Ca^{2+}` ATPase pump activity and diffusion to neighboring shells. + Only the :math:`Ca^{2+}` pump was modeled here. We adopted the following kinetic scheme: + + .. math:: + + Ca _{i}^{2+}+ P \overset{c_1}{\underset{c_2}{\rightleftharpoons}} CaP \xrightarrow{c_3} P+ Ca _{0}^{2+} + + where P represents the :math:`Ca^{2+}` pump, CaP is an intermediate state, + :math:`Ca _{ o }^{2+}` is the extracellular :math:`Ca^{2+}` concentration, + and :math:`c_{1}, c_{2}` and :math:`c_{3}` are rate constants. :math:`Ca^{2+}` + ions have a high affinity for the pump :math:`P`, whereas extrusion of + :math:`Ca^{2+}` follows a slower process (Blaustein, 1988 ). Therefore, + :math:`c_{3}` is low compared to :math:`c_{1}` and :math:`c_{2}` and the + Michaelis-Menten approximation can be used for describing the kinetics of the pump. + According to such a scheme, the kinetic equation for the :math:`Ca^{2+}` pump is: + + .. math:: + + \frac{[Ca^{2+}]_{i}}{dt}=-\frac{K_{T}[Ca]_{i}}{[Ca]_{i}+K_{d}} + + where :math:`K_{T}=10^{-4}\, \mathrm{mM\, ms^{-1}}` is the product of :math:`c_{3}` + with the total concentration of :math:`P` and :math:`K_{d}=c_{2} / c_{1}=10^{-4}\, \mathrm{mM}` + is the dissociation constant, which can be interpreted here as the value of + :math:`[Ca]_{i}` at which the pump is half activated (if :math:`[Ca]_{i} \ll K_{d}` + then the efflux is negligible). + + **2.A simple first-order model** + + While, in (Bazhenov, et al., 1998) [2]_, the :math:`Ca^{2+}` dynamics is + described by a simple first-order model, + + .. 
math:: + + \frac{d\left[Ca^{2+}\right]_{i}}{d t}=-\frac{I_{Ca}}{z F d}+\frac{\left[Ca^{2+}\right]_{rest}-\left[C a^{2+}\right]_{i}}{\tau_{Ca}} + + where :math:`I_{Ca}` is the summation of all :math:`Ca ^{2+}` currents, :math:`d` + is the thickness of the perimembrane "shell" in which calcium is able to affect + membrane properties :math:`(1.\, \mathrm{\mu M})`, :math:`z=2` is the valence of the + :math:`Ca ^{2+}` ion, :math:`F` is the Faraday constant, and :math:`\tau_{C a}` is + the :math:`Ca ^{2+}` removal rate. The resting :math:`Ca ^{2+}` concentration was + set to be :math:`\left[ Ca ^{2+}\right]_{\text {rest}}=.05\, \mathrm{\mu M}` . + + **3. The reversal potential** + + The reversal potential of calcium :math:`Ca ^{2+}` is calculated according to the + Nernst equation: + + .. math:: + + E = k'{RT \over 2F} log{[Ca^{2+}]_0 \over [Ca^{2+}]_i} + + where :math:`R=8.31441 \, \mathrm{J} /(\mathrm{mol}^{\circ} \mathrm{K})`, + :math:`T=309.15^{\circ} \mathrm{K}`, + :math:`F=96,489 \mathrm{C} / \mathrm{mol}`, + and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`. + + Parameters:: + + d : float + The thickness of the peri-membrane "shell". + F : float + The Faraday constant. (:math:`C*mmol^{-1}`) + tau : float + The time constant of the :math:`Ca ^{2+}` removal rate. (ms) + C_rest : float + The resting :math:`Ca ^{2+}` concentration. + C0 : float + The :math:`Ca ^{2+}` concentration outside of the membrane. + R : float + The gas constant. (:math:` J*mol^{-1}*K^{-1}`) + + References:: + + .. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski. + "Ionic mechanisms for intrinsic slow oscillations in thalamic + relay neurons." Biophysical journal 65, no. 4 (1993): 1538-1552. + .. [2] Bazhenov, Maxim, Igor Timofeev, Mircea Steriade, and Terrence J. + Sejnowski. "Cellular and network models for intrathalamic augmenting + responses during 10-Hz stimulation." Journal of neurophysiology 79, + no. 5 (1998): 2730-2748. 
+ + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[float, ArrayType, Initializer, Callable] = 36., + d: Union[float, ArrayType, Initializer, Callable] = 1., + C_rest: Union[float, ArrayType, Initializer, Callable] = 2.4e-4, + tau: Union[float, ArrayType, Initializer, Callable] = 5., + C0: Union[float, ArrayType, Initializer, Callable] = 2., + C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + **channels + ): + super().__init__(size, + keep_size=keep_size, + method=method, + name=name, + T=T, + C0=C0, + C_initializer=C_initializer, + mode=mode, + **channels) + + # parameters + self.d = parameter(d, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.C_rest = parameter(C_rest, self.varshape, allow_none=False) + + def derivative(self, C, t, V): + ICa = self.current(V, C, self.E, external=True) + drive = bm.maximum(- ICa / (2 * self.F * self.d), 0.) + return drive + (self.C_rest - C) / self.tau class CalciumFirstOrder(CalciumDyna): - r"""The first-order calcium concentration model. - - .. 
math:: - - Ca' = -\alpha I_{Ca} + -\beta Ca - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - T: Union[float, ArrayType, Initializer, Callable] = 36., - alpha: Union[float, ArrayType, Initializer, Callable] = 0.13, - beta: Union[float, ArrayType, Initializer, Callable] = 0.075, - C0: Union[float, ArrayType, Initializer, Callable] = 2., - C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4), - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - **channels - ): - super().__init__(size, - keep_size=keep_size, - method=method, - name=name, - T=T, - C0=C0, - C_initializer=C_initializer, - mode=mode, - **channels) - - # parameters - self.alpha = parameter(alpha, self.varshape, allow_none=False) - self.beta = parameter(beta, self.varshape, allow_none=False) - - def derivative(self, C, t, V): - ICa = self.current(V, C, self.E, external=True) - drive = bm.maximum(- self.alpha * ICa, 0.) - return drive - self.beta * C + r"""The first-order calcium concentration model. + + .. 
math:: + + Ca' = -\alpha I_{Ca} + -\beta Ca + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + T: Union[float, ArrayType, Initializer, Callable] = 36., + alpha: Union[float, ArrayType, Initializer, Callable] = 0.13, + beta: Union[float, ArrayType, Initializer, Callable] = 0.075, + C0: Union[float, ArrayType, Initializer, Callable] = 2., + C_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.4e-4), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + **channels + ): + super().__init__(size, + keep_size=keep_size, + method=method, + name=name, + T=T, + C0=C0, + C_initializer=C_initializer, + mode=mode, + **channels) + + # parameters + self.alpha = parameter(alpha, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + + def derivative(self, C, t, V): + ICa = self.current(V, C, self.E, external=True) + drive = bm.maximum(- self.alpha * ICa, 0.) + return drive - self.beta * C diff --git a/brainpy/_src/dyn/ions/potassium.py b/brainpy/_src/dyn/ions/potassium.py index 2f944ad8d..c5d6474e2 100644 --- a/brainpy/_src/dyn/ions/potassium.py +++ b/brainpy/_src/dyn/ions/potassium.py @@ -7,47 +7,47 @@ from .base import Ion __all__ = [ - 'Potassium', - 'PotassiumFixed', + 'Potassium', + 'PotassiumFixed', ] class Potassium(Ion): - """Base class for modeling Potassium ion.""" - pass + """Base class for modeling Potassium ion.""" + pass class PotassiumFixed(Potassium): - """Fixed Sodium dynamics. - - This calcium model has no dynamics. It holds fixed reversal - potential :math:`E` and concentration :math:`C`. 
- """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = -950., - C: Union[float, ArrayType, Initializer, Callable] = 0.0400811, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - **channels - ): - super().__init__(size, - keep_size=keep_size, - method=method, - name=name, - mode=mode, - **channels) - self.E = self.init_param(E, self.varshape) - self.C = self.init_param(C, self.varshape) - - def reset_state(self, V, C=None, E=None, batch_size=None): - C = self.C if C is None else C - E = self.E if E is None else E - nodes = self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values() - self.check_hierarchies(type(self), *tuple(nodes)) - for node in nodes: - node.reset_state(V, C, E, batch_size) + """Fixed Sodium dynamics. + + This calcium model has no dynamics. It holds fixed reversal + potential :math:`E` and concentration :math:`C`. + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = -950., + C: Union[float, ArrayType, Initializer, Callable] = 0.0400811, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + **channels + ): + super().__init__(size, + keep_size=keep_size, + method=method, + name=name, + mode=mode, + **channels) + self.E = self.init_param(E, self.varshape) + self.C = self.init_param(C, self.varshape) + + def reset_state(self, V, C=None, E=None, batch_size=None): + C = self.C if C is None else C + E = self.E if E is None else E + nodes = self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values() + self.check_hierarchies(type(self), *tuple(nodes)) + for node in nodes: + node.reset_state(V, C, E, batch_size) diff --git a/brainpy/_src/dyn/ions/sodium.py b/brainpy/_src/dyn/ions/sodium.py index e08dea778..481ac779c 100644 --- a/brainpy/_src/dyn/ions/sodium.py +++ b/brainpy/_src/dyn/ions/sodium.py 
@@ -7,47 +7,47 @@ from .base import Ion __all__ = [ - 'Sodium', - 'SodiumFixed', + 'Sodium', + 'SodiumFixed', ] class Sodium(Ion): - """Base class for modeling Sodium ion.""" - pass + """Base class for modeling Sodium ion.""" + pass class SodiumFixed(Sodium): - """Fixed Sodium dynamics. - - This calcium model has no dynamics. It holds fixed reversal - potential :math:`E` and concentration :math:`C`. - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - E: Union[float, ArrayType, Initializer, Callable] = 50., - C: Union[float, ArrayType, Initializer, Callable] = 0.0400811, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - **channels - ): - super().__init__(size, - keep_size=keep_size, - method=method, - name=name, - mode=mode, - **channels) - self.E = parameter(E, self.varshape, allow_none=False) - self.C = parameter(C, self.varshape, allow_none=False) - - def reset_state(self, V, C=None, E=None, batch_size=None): - C = self.C if C is None else C - E = self.E if E is None else E - nodes = self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values() - self.check_hierarchies(type(self), *tuple(nodes)) - for node in nodes: - node.reset_state(V, C, E, batch_size) + """Fixed Sodium dynamics. + + This calcium model has no dynamics. It holds fixed reversal + potential :math:`E` and concentration :math:`C`. 
+ """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + E: Union[float, ArrayType, Initializer, Callable] = 50., + C: Union[float, ArrayType, Initializer, Callable] = 0.0400811, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + **channels + ): + super().__init__(size, + keep_size=keep_size, + method=method, + name=name, + mode=mode, + **channels) + self.E = parameter(E, self.varshape, allow_none=False) + self.C = parameter(C, self.varshape, allow_none=False) + + def reset_state(self, V, C=None, E=None, batch_size=None): + C = self.C if C is None else C + E = self.E if E is None else E + nodes = self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values() + self.check_hierarchies(type(self), *tuple(nodes)) + for node in nodes: + node.reset_state(V, C, E, batch_size) diff --git a/brainpy/_src/dyn/ions/tests/test_MixIons.py b/brainpy/_src/dyn/ions/tests/test_MixIons.py index e196ca4d4..31ae6fe1f 100644 --- a/brainpy/_src/dyn/ions/tests/test_MixIons.py +++ b/brainpy/_src/dyn/ions/tests/test_MixIons.py @@ -1,98 +1,94 @@ +import unittest + import brainpy as bp import brainpy.math as bm -import unittest - class TestMixIons(unittest.TestCase): - def test_init(self): - class HH(bp.dyn.CondNeuGroup): - def __init__(self, size): - super().__init__(size) + def test_init(self): + class HH(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) - self.k = bp.dyn.PotassiumFixed(size) - self.ca = bp.dyn.CalciumFirstOrder(size) + self.k = bp.dyn.PotassiumFixed(size) + self.ca = bp.dyn.CalciumFirstOrder(size) - self.kca = bp.dyn.mix_ions(self.k, self.ca) - self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) + self.kca = bp.dyn.mix_ions(self.k, self.ca) + self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) - bm.random.seed() - HH(10) + bm.random.seed() + HH(10) - def test_init2(self): - class HH(bp.dyn.CondNeuGroup): - def __init__(self, size): - super().__init__(size) + def 
test_init2(self): + class HH(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) - self.k = bp.dyn.PotassiumFixed(size) - self.ca = bp.dyn.CalciumFirstOrder(size) + self.k = bp.dyn.PotassiumFixed(size) + self.ca = bp.dyn.CalciumFirstOrder(size) - self.kca = bp.dyn.mix_ions(self.k, self.ca) - self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) - self.kca.add_elem(na=bp.dyn.INa_Ba2002(size)) + self.kca = bp.dyn.mix_ions(self.k, self.ca) + self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) + self.kca.add_elem(na=bp.dyn.INa_Ba2002(size)) - bm.random.seed() - with self.assertRaises(TypeError): - HH(10) + bm.random.seed() + with self.assertRaises(TypeError): + HH(10) - def test_init3(self): - class HH(bp.dyn.CondNeuGroup): - def __init__(self, size): - super().__init__(size) + def test_init3(self): + class HH(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) - self.na = bp.dyn.SodiumFixed(size) - self.ca = bp.dyn.CalciumFirstOrder(size) + self.na = bp.dyn.SodiumFixed(size) + self.ca = bp.dyn.CalciumFirstOrder(size) - self.kca = bp.dyn.mix_ions(self.na, self.ca) - self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) - self.kca.add_elem(na=bp.dyn.INa_Ba2002(size)) + self.kca = bp.dyn.mix_ions(self.na, self.ca) + self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) + self.kca.add_elem(na=bp.dyn.INa_Ba2002(size)) - bm.random.seed() - with self.assertRaises(TypeError): - HH(10) + bm.random.seed() + with self.assertRaises(TypeError): + HH(10) - def test_init4(self): - class HH(bp.dyn.CondNeuGroup): - def __init__(self, size): - super().__init__(size) + def test_init4(self): + class HH(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) - self.na = bp.dyn.SodiumFixed(size) - self.k = bp.dyn.PotassiumFixed(size) - self.ca = bp.dyn.CalciumFirstOrder(size) + self.na = bp.dyn.SodiumFixed(size) + self.k = bp.dyn.PotassiumFixed(size) + self.ca = bp.dyn.CalciumFirstOrder(size) - self.kca = bp.dyn.mix_ions(self.na, self.k, 
self.ca) - self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) + self.kca = bp.dyn.mix_ions(self.na, self.k, self.ca) + self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) - bm.random.seed() - HH(10) + bm.random.seed() + HH(10) class TestMixIons2(unittest.TestCase): - def test_current1(self): - class HH(bp.dyn.CondNeuGroup): - def __init__(self, size): - super().__init__(size) - - self.k = bp.dyn.PotassiumFixed(size) - self.na = bp.dyn.SodiumFixed(size) - self.ca = bp.dyn.CalciumFirstOrder(size) - self.kca = bp.dyn.MixIons(self.na, self.k, self.ca) - - self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) - - bm.random.seed() - hh = HH(10) - - hh.reset_state() + def test_current1(self): + class HH(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) - ICa = hh.ca.current(hh.V, external=True) - INa = hh.na.current(hh.V, external=True) - IK = hh.k.current(hh.V, external=True) - print(ICa, INa, IK) + self.k = bp.dyn.PotassiumFixed(size) + self.na = bp.dyn.SodiumFixed(size) + self.ca = bp.dyn.CalciumFirstOrder(size) + self.kca = bp.dyn.MixIons(self.na, self.k, self.ca) - self.assertTrue(bm.allclose(INa, 0.)) - self.assertTrue(bm.allclose(ICa, IK)) + self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) + bm.random.seed() + hh = HH(10) + hh.reset_state() + ICa = hh.ca.current(hh.V, external=True) + INa = hh.na.current(hh.V, external=True) + IK = hh.k.current(hh.V, external=True) + print(ICa, INa, IK) + self.assertTrue(bm.allclose(INa, 0.)) + self.assertTrue(bm.allclose(ICa, IK)) diff --git a/brainpy/_src/dyn/neurons/__init__.py b/brainpy/_src/dyn/neurons/__init__.py index dee13d730..12beba70f 100644 --- a/brainpy/_src/dyn/neurons/__init__.py +++ b/brainpy/_src/dyn/neurons/__init__.py @@ -1,3 +1 @@ from .lif import * - - diff --git a/brainpy/_src/dyn/neurons/base.py b/brainpy/_src/dyn/neurons/base.py index 7a05d625e..580ba9981 100644 --- a/brainpy/_src/dyn/neurons/base.py +++ b/brainpy/_src/dyn/neurons/base.py @@ -9,73 +9,73 @@ class GradNeuDyn(NeuDyn): - 
"""Differentiable and Parallelizable Neuron Group. + """Differentiable and Parallelizable Neuron Group. - Args: - {pneu} - {dpneu} - """ + Args: + {pneu} + {dpneu} + """ - supported_modes = (bm.TrainingMode, bm.NonBatchingMode) + supported_modes = (bm.TrainingMode, bm.NonBatchingMode) - def __init__( - self, - size: Union[int, Sequence[int]], - sharding: Any = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - method: str = 'exp_auto', - scaling: Optional[bm.Scaling] = None, + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + method: str = 'exp_auto', + scaling: Optional[bm.Scaling] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - ): - super().__init__(size=size, - mode=mode, - keep_size=keep_size, - name=name, - sharding=sharding, - method=method) + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + ): + super().__init__(size=size, + mode=mode, + keep_size=keep_size, + name=name, + sharding=sharding, + method=method) - self.spk_reset = spk_reset - self.spk_fun = is_callable(spk_fun) - self.detach_spk = detach_spk - self._spk_dtype = spk_dtype - if scaling is None: - self.scaling = bm.get_membrane_scaling() - else: - self.scaling = scaling + self.spk_reset = spk_reset + self.spk_fun = is_callable(spk_fun) + self.detach_spk = detach_spk + self._spk_dtype = spk_dtype + if scaling is None: + self.scaling = bm.get_membrane_scaling() + else: + self.scaling = scaling - @property - def spk_dtype(self): - if self._spk_dtype is None: - return bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool_ - else: - return self._spk_dtype + @property + def spk_dtype(self): + if self._spk_dtype is None: + return bm.float_ if 
isinstance(self.mode, bm.TrainingMode) else bm.bool_ + else: + return self._spk_dtype - def offset_scaling(self, x, bias=None, scale=None): - s = self.scaling.offset_scaling(x, bias=bias, scale=scale) - if isinstance(x, bm.BaseArray): - x.value = s - return x - return s + def offset_scaling(self, x, bias=None, scale=None): + s = self.scaling.offset_scaling(x, bias=bias, scale=scale) + if isinstance(x, bm.BaseArray): + x.value = s + return x + return s - def std_scaling(self, x, scale=None): - s = self.scaling.std_scaling(x, scale=scale) - if isinstance(x, bm.BaseArray): - x.value = s - return x - return s + def std_scaling(self, x, scale=None): + s = self.scaling.std_scaling(x, scale=scale) + if isinstance(x, bm.BaseArray): + x.value = s + return x + return s - def inv_scaling(self, x, scale=None): - s = self.scaling.inv_scaling(x, scale=scale) - if isinstance(x, bm.BaseArray): - x.value = s - return x - return s + def inv_scaling(self, x, scale=None): + s = self.scaling.inv_scaling(x, scale=scale) + if isinstance(x, bm.BaseArray): + x.value = s + return x + return s GradNeuDyn.__doc__ = GradNeuDyn.__doc__.format(pneu=pneu_doc, dpneu=dpneu_doc) diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py index 6052a589d..6e0f9fc3a 100644 --- a/brainpy/_src/dyn/neurons/hh.py +++ b/brainpy/_src/dyn/neurons/hh.py @@ -15,1105 +15,1105 @@ from brainpy.types import Shape __all__ = [ - 'HHTypedNeuron', - 'CondNeuGroupLTC', - 'CondNeuGroup', - 'HHLTC', - 'HH', - 'MorrisLecarLTC', - 'MorrisLecar', - 'WangBuzsakiHHLTC', - 'WangBuzsakiHH' + 'HHTypedNeuron', + 'CondNeuGroupLTC', + 'CondNeuGroup', + 'HHLTC', + 'HH', + 'MorrisLecarLTC', + 'MorrisLecar', + 'WangBuzsakiHHLTC', + 'WangBuzsakiHH' ] class HHTypedNeuron(NeuDyn): - pass + pass class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode): - r"""Base class to model conductance-based neuron group. - - The standard formulation for a conductance-based model is given as - - .. 
math:: - - C_m {dV \over dt} = \sum_jg_j(E - V) + I_{ext} - - where :math:`g_j=\bar{g}_{j} M^x N^y` is the channel conductance, :math:`E` is the - reversal potential, :math:`M` is the activation variable, and :math:`N` is the - inactivation variable. - - :math:`M` and :math:`N` have the dynamics of - - .. math:: - - {dx \over dt} = \phi_x {x_\infty (V) - x \over \tau_x(V)} - - where :math:`x \in [M, N]`, :math:`\phi_x` is a temperature-dependent factor, - :math:`x_\infty` is the steady state, and :math:`\tau_x` is the time constant. - Equivalently, the above equation can be written as: - - .. math:: - - \frac{d x}{d t}=\phi_{x}\left(\alpha_{x}(1-x)-\beta_{x} x\right) - - where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants. - - .. versionadded:: 2.1.9 - Modeling the conductance-based neuron model. - - Parameters:: - - size : int, sequence of int - The network size of this neuron group. - method: str - The numerical integration method. - name : optional, str - The neuron group name. 
- - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - C: Union[float, ArrayType, Callable] = 1., - A: Union[float, ArrayType, Callable] = 1e-3, - V_th: Union[float, ArrayType, Callable] = 0., - V_initializer: Union[Callable, ArrayType] = Uniform(-70, -60.), - noise: Optional[Union[float, ArrayType, Callable]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - init_var: bool = True, - input_var: bool = True, - spk_type: Optional[type] = None, - **channels - ): - super().__init__(size, keep_size=keep_size, mode=mode, name=name, ) - - # attribute for ``Container`` - self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels)) - - # parameters for neurons - self.input_var = input_var - self.C = C - self.A = A - self.V_th = V_th - self.noise = init_noise(noise, self.varshape, num_vars=1) - self._V_initializer = V_initializer - self.spk_type = ((bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool) - if (spk_type is None) else spk_type) - - # function - if self.noise is None: - self.integral = odeint(f=self.derivative, method=method) - else: - self.integral = sdeint(f=self.derivative, g=self.noise, method=method) - - if init_var: - self.reset_state(self.mode) - - def derivative(self, V, t, I): - # synapses - I = self.sum_current_inputs(V, init=I) - # channels - for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): - I = I + ch.current(V) - return I / self.C - - def reset_state(self, batch_size=None): - self.V = variable_(self._V_initializer, self.varshape, batch_size) - self.spike = variable_(partial(bm.zeros, dtype=self.spk_type), self.varshape, batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - for channel in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): - channel.reset_state(self.V.value, batch_size=batch_size) - - def update(self, x=None): - # inputs - x = 0. 
if x is None else x - if self.input_var: - self.input += x - x = self.input.value - x = x * (1e-3 / self.A) - - # integral - V = self.integral(self.V.value, share['t'], x, share['dt']) + self.sum_delta_inputs() - - # check whether the children channels have the correct parents. - channels = self.nodes(level=1, include_self=False).subset(IonChaDyn).unique() - self.check_hierarchies(self.__class__, **channels) - - # update channels - for node in channels.values(): - node(self.V.value) - - # update variables - if self.spike.dtype == bool: - self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th) - else: - self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th).astype(self.spike.dtype) - self.V.value = V - return self.spike.value - - def clear_input(self): - """Useful for monitoring inputs. """ - if self.input_var: - self.input.value = bm.zeros_like(self.input) - - def return_info(self): - return self.spike + r"""Base class to model conductance-based neuron group. + + The standard formulation for a conductance-based model is given as + + .. math:: + + C_m {dV \over dt} = \sum_jg_j(E - V) + I_{ext} + + where :math:`g_j=\bar{g}_{j} M^x N^y` is the channel conductance, :math:`E` is the + reversal potential, :math:`M` is the activation variable, and :math:`N` is the + inactivation variable. + + :math:`M` and :math:`N` have the dynamics of + + .. math:: + + {dx \over dt} = \phi_x {x_\infty (V) - x \over \tau_x(V)} + + where :math:`x \in [M, N]`, :math:`\phi_x` is a temperature-dependent factor, + :math:`x_\infty` is the steady state, and :math:`\tau_x` is the time constant. + Equivalently, the above equation can be written as: + + .. math:: + + \frac{d x}{d t}=\phi_{x}\left(\alpha_{x}(1-x)-\beta_{x} x\right) + + where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants. + + .. versionadded:: 2.1.9 + Modeling the conductance-based neuron model. + + Parameters:: + + size : int, sequence of int + The network size of this neuron group. 
+ method: str + The numerical integration method. + name : optional, str + The neuron group name. + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + C: Union[float, ArrayType, Callable] = 1., + A: Union[float, ArrayType, Callable] = 1e-3, + V_th: Union[float, ArrayType, Callable] = 0., + V_initializer: Union[Callable, ArrayType] = Uniform(-70, -60.), + noise: Optional[Union[float, ArrayType, Callable]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + init_var: bool = True, + input_var: bool = True, + spk_type: Optional[type] = None, + **channels + ): + super().__init__(size, keep_size=keep_size, mode=mode, name=name, ) + + # attribute for ``Container`` + self.children = bm.node_dict(self.format_elements(IonChaDyn, **channels)) + + # parameters for neurons + self.input_var = input_var + self.C = C + self.A = A + self.V_th = V_th + self.noise = init_noise(noise, self.varshape, num_vars=1) + self._V_initializer = V_initializer + self.spk_type = ((bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool) + if (spk_type is None) else spk_type) + + # function + if self.noise is None: + self.integral = odeint(f=self.derivative, method=method) + else: + self.integral = sdeint(f=self.derivative, g=self.noise, method=method) + + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + # synapses + I = self.sum_current_inputs(V, init=I) + # channels + for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): + I = I + ch.current(V) + return I / self.C + + def reset_state(self, batch_size=None): + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.spike = variable_(partial(bm.zeros, dtype=self.spk_type), self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + for channel in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): + 
channel.reset_state(self.V.value, batch_size=batch_size) + + def update(self, x=None): + # inputs + x = 0. if x is None else x + if self.input_var: + self.input += x + x = self.input.value + x = x * (1e-3 / self.A) + + # integral + V = self.integral(self.V.value, share['t'], x, share['dt']) + self.sum_delta_inputs() + + # check whether the children channels have the correct parents. + channels = self.nodes(level=1, include_self=False).subset(IonChaDyn).unique() + self.check_hierarchies(self.__class__, **channels) + + # update channels + for node in channels.values(): + node(self.V.value) + + # update variables + if self.spike.dtype == bool: + self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th) + else: + self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th).astype(self.spike.dtype) + self.V.value = V + return self.spike.value + + def clear_input(self): + """Useful for monitoring inputs. """ + if self.input_var: + self.input.value = bm.zeros_like(self.input) + + def return_info(self): + return self.spike class CondNeuGroup(CondNeuGroupLTC): - def derivative(self, V, t, I): - for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): - I = I + ch.current(V) - return I / self.C + def derivative(self, V, t, I): + for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): + I = I + ch.current(V) + return I / self.C - def update(self, x=None): - # inputs - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + # inputs + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) class HHLTC(NeuDyn): - r"""Hodgkin–Huxley neuron model with liquid time constant. + r"""Hodgkin–Huxley neuron model with liquid time constant. 
- **Model Descriptions** + **Model Descriptions** - The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model [1]_ for the generation of - the nerve action potential is one of the most successful mathematical models of - a complex biological process that has ever been formulated. The basic concepts - expressed in the model have proved a valid approach to the study of bio-electrical - activity from the most primitive single-celled organisms such as *Paramecium*, - right through to the neurons within our own brains. + The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model [1]_ for the generation of + the nerve action potential is one of the most successful mathematical models of + a complex biological process that has ever been formulated. The basic concepts + expressed in the model have proved a valid approach to the study of bio-electrical + activity from the most primitive single-celled organisms such as *Paramecium*, + right through to the neurons within our own brains. - Mathematically, the model is given by, + Mathematically, the model is given by, - .. math:: + .. 
math:: - C \frac {dV} {dt} = -(\bar{g}_{Na} m^3 h (V &-E_{Na}) - + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) + C \frac {dV} {dt} = -(\bar{g}_{Na} m^3 h (V &-E_{Na}) + + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) - \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x, \quad x\in {\rm{\{m, h, n\}}} + \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x, \quad x\in {\rm{\{m, h, n\}}} - &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} + &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} - &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) + &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) - &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) + &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) - &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} + &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} - &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} + &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} - &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) + &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) - The illustrated example of HH neuron model please see `this notebook <../neurons/HH_model.ipynb>`_. + The illustrated example of HH neuron model please see `this notebook <../neurons/HH_model.ipynb>`_. - The Hodgkin–Huxley model can be thought of as a differential equation system with - four state variables, :math:`V_{m}(t),n(t),m(t)`, and :math:`h(t)`, that change - with respect to time :math:`t`. The system is difficult to study because it is a - nonlinear system and cannot be solved analytically. However, there are many numeric - methods available to analyze the system. Certain properties and general behaviors, - such as limit cycles, can be proven to exist. + The Hodgkin–Huxley model can be thought of as a differential equation system with + four state variables, :math:`V_{m}(t),n(t),m(t)`, and :math:`h(t)`, that change + with respect to time :math:`t`. 
The system is difficult to study because it is a + nonlinear system and cannot be solved analytically. However, there are many numeric + methods available to analyze the system. Certain properties and general behaviors, + such as limit cycles, can be proven to exist. - References:: + References:: - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description - of membrane current and its application to conduction and excitation - in nerve." The Journal of physiology 117.4 (1952): 500. - .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model - .. [3] Ashwin, Peter, Stephen Coombes, and Rachel Nicks. "Mathematical - frameworks for oscillatory network dynamics in neuroscience." - The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92. - - **Examples** - - Here is a simple usage example: - - .. code-block:: python - - import brainpy as bp - - neu = bp.dyn.HHLTC(1) - - # raise input current from 4 mA to 40 mA - inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600,) - - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) - - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - - - - Parameters:: - - size: sequence of int, int - The size of the neuron group. - ENa: float, ArrayType, Initializer, callable - The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable - The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable - The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable - The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable - The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable - The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable - The threshold of the membrane spike. 
Default is 20 mV. - C: float, ArrayType, Initializer, callable - The membrane capacitance. Default is 1 ufarad. - V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - m_initializer: ArrayType, Initializer, callable - The initializer of m channel. - h_initializer: ArrayType, Initializer, callable - The initializer of h channel. - n_initializer: ArrayType, Initializer, callable - The initializer of n channel. - method: str - The numerical integration method. - name: str - The group name. - - - - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - sharding: Any = None, - keep_size: bool = False, - mode: bm.Mode = None, - name: str = None, - method: str = 'exp_auto', - init_var: bool = True, - - # neuron parameters - ENa: Union[float, ArrayType, Callable] = 50., - gNa: Union[float, ArrayType, Callable] = 120., - EK: Union[float, ArrayType, Callable] = -77., - gK: Union[float, ArrayType, Callable] = 36., - EL: Union[float, ArrayType, Callable] = -54.387, - gL: Union[float, ArrayType, Callable] = 0.03, - V_th: Union[float, ArrayType, Callable] = 20., - C: Union[float, ArrayType, Callable] = 1.0, - V_initializer: Union[Callable, ArrayType] = Uniform(-70, -60.), - m_initializer: Optional[Union[Callable, ArrayType]] = None, - h_initializer: Optional[Union[Callable, ArrayType]] = None, - n_initializer: Optional[Union[Callable, ArrayType]] = None, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - sharding=sharding, - keep_size=keep_size, - mode=mode, - name=name, - method=method) - - # parameters - self.ENa = self.init_param(ENa) - self.EK = self.init_param(EK) - self.EL = self.init_param(EL) - self.gNa = self.init_param(gNa) - self.gK = self.init_param(gK) - self.gL = self.init_param(gL) - self.C = self.init_param(C) - self.V_th = self.init_param(V_th) - - # initializers - self._m_initializer = is_initializer(m_initializer, allow_none=True) - 
self._h_initializer = is_initializer(h_initializer, allow_none=True) - self._n_initializer = is_initializer(n_initializer, allow_none=True) - self._V_initializer = is_initializer(V_initializer) - - # noise - self.noise = init_noise(noise, self.varshape, num_vars=4) - - # integral - if self.noise is None: - self.integral = odeint(method=method, f=self.derivative) - else: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - - # model - if init_var: - self.reset_state(self.mode) - - # m channel - # m_alpha = lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) - m_alpha = lambda self, V: 1. / bm.exprel(-(V + 40) / 10) - m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18) - m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) - dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m - - # h channel - h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.) - h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10)) - h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) - dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h - - # n channel - # n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) - n_alpha = lambda self, V: 0.1 / bm.exprel(-(V + 55) / 10) - n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80) - n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) - dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.init_variable(self._V_initializer, batch_size) - if self._m_initializer is None: - self.m = bm.Variable(self.m_inf(self.V.value), batch_axis=self.V.batch_axis) - else: - self.m = self.init_variable(self._m_initializer, batch_size) - if self._h_initializer is None: - self.h = bm.Variable(self.h_inf(self.V.value), batch_axis=self.V.batch_axis) - else: - self.h = self.init_variable(self._h_initializer, 
batch_size) - if self._n_initializer is None: - self.n = bm.Variable(self.n_inf(self.V.value), batch_axis=self.V.batch_axis) - else: - self.n = self.init_variable(self._n_initializer, batch_size) - self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def dV(self, V, t, m, h, n, I): - I = self.sum_current_inputs(V, init=I) - I_Na = (self.gNa * m * m * m * h) * (V - self.ENa) - n2 = n * n - I_K = (self.gK * n2 * n2) * (V - self.EK) - I_leak = self.gL * (V - self.EL) - dVdt = (- I_Na - I_K - I_leak + I) / self.C - return dVdt - - @property - def derivative(self): - return JointEq(self.dV, self.dm, self.dh, self.dn) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt) - V += self.sum_delta_inputs() - self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) - self.V.value = V - self.m.value = m - self.h.value = h - self.n.value = n - return self.spike.value - - def return_info(self): - return self.spike + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description + of membrane current and its application to conduction and excitation + in nerve." The Journal of physiology 117.4 (1952): 500. + .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model + .. [3] Ashwin, Peter, Stephen Coombes, and Rachel Nicks. "Mathematical + frameworks for oscillatory network dynamics in neuroscience." + The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92. + + **Examples** + + Here is a simple usage example: + + .. 
code-block:: python + + import brainpy as bp + + neu = bp.dyn.HHLTC(1) + + # raise input current from 4 mA to 40 mA + inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600,) + + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + + + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium. Default is 50 mV. + gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. Default is 120 msiemens. + EK: float, ArrayType, Initializer, callable + The reversal potential of potassium. Default is -77 mV. + gK: float, ArrayType, Initializer, callable + The maximum conductance of potassium channel. Default is 36 msiemens. + EL: float, ArrayType, Initializer, callable + The reversal potential of learky channel. Default is -54.387 mV. + gL: float, ArrayType, Initializer, callable + The conductance of learky channel. Default is 0.03 msiemens. + V_th: float, ArrayType, Initializer, callable + The threshold of the membrane spike. Default is 20 mV. + C: float, ArrayType, Initializer, callable + The membrane capacitance. Default is 1 ufarad. + V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + m_initializer: ArrayType, Initializer, callable + The initializer of m channel. + h_initializer: ArrayType, Initializer, callable + The initializer of h channel. + n_initializer: ArrayType, Initializer, callable + The initializer of n channel. + method: str + The numerical integration method. + name: str + The group name. 
+ + + + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + ENa: Union[float, ArrayType, Callable] = 50., + gNa: Union[float, ArrayType, Callable] = 120., + EK: Union[float, ArrayType, Callable] = -77., + gK: Union[float, ArrayType, Callable] = 36., + EL: Union[float, ArrayType, Callable] = -54.387, + gL: Union[float, ArrayType, Callable] = 0.03, + V_th: Union[float, ArrayType, Callable] = 20., + C: Union[float, ArrayType, Callable] = 1.0, + V_initializer: Union[Callable, ArrayType] = Uniform(-70, -60.), + m_initializer: Optional[Union[Callable, ArrayType]] = None, + h_initializer: Optional[Union[Callable, ArrayType]] = None, + n_initializer: Optional[Union[Callable, ArrayType]] = None, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + sharding=sharding, + keep_size=keep_size, + mode=mode, + name=name, + method=method) + + # parameters + self.ENa = self.init_param(ENa) + self.EK = self.init_param(EK) + self.EL = self.init_param(EL) + self.gNa = self.init_param(gNa) + self.gK = self.init_param(gK) + self.gL = self.init_param(gL) + self.C = self.init_param(C) + self.V_th = self.init_param(V_th) + + # initializers + self._m_initializer = is_initializer(m_initializer, allow_none=True) + self._h_initializer = is_initializer(h_initializer, allow_none=True) + self._n_initializer = is_initializer(n_initializer, allow_none=True) + self._V_initializer = is_initializer(V_initializer) + + # noise + self.noise = init_noise(noise, self.varshape, num_vars=4) + + # integral + if self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + + # model + if init_var: + self.reset_state(self.mode) + + # m channel + # m_alpha = 
lambda self, V: 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + m_alpha = lambda self, V: 1. / bm.exprel(-(V + 40) / 10) + m_beta = lambda self, V: 4.0 * bm.exp(-(V + 65) / 18) + m_inf = lambda self, V: self.m_alpha(V) / (self.m_alpha(V) + self.m_beta(V)) + dm = lambda self, m, t, V: self.m_alpha(V) * (1 - m) - self.m_beta(V) * m + + # h channel + h_alpha = lambda self, V: 0.07 * bm.exp(-(V + 65) / 20.) + h_beta = lambda self, V: 1 / (1 + bm.exp(-(V + 35) / 10)) + h_inf = lambda self, V: self.h_alpha(V) / (self.h_alpha(V) + self.h_beta(V)) + dh = lambda self, h, t, V: self.h_alpha(V) * (1 - h) - self.h_beta(V) * h + + # n channel + # n_alpha = lambda self, V: 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + n_alpha = lambda self, V: 0.1 / bm.exprel(-(V + 55) / 10) + n_beta = lambda self, V: 0.125 * bm.exp(-(V + 65) / 80) + n_inf = lambda self, V: self.n_alpha(V) / (self.n_alpha(V) + self.n_beta(V)) + dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.init_variable(self._V_initializer, batch_size) + if self._m_initializer is None: + self.m = bm.Variable(self.m_inf(self.V.value), batch_axis=self.V.batch_axis) + else: + self.m = self.init_variable(self._m_initializer, batch_size) + if self._h_initializer is None: + self.h = bm.Variable(self.h_inf(self.V.value), batch_axis=self.V.batch_axis) + else: + self.h = self.init_variable(self._h_initializer, batch_size) + if self._n_initializer is None: + self.n = bm.Variable(self.n_inf(self.V.value), batch_axis=self.V.batch_axis) + else: + self.n = self.init_variable(self._n_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def dV(self, V, t, m, h, n, I): + I = self.sum_current_inputs(V, init=I) + I_Na = (self.gNa * m * m * m * h) * (V - self.ENa) + n2 = n * n + I_K = (self.gK * n2 * n2) * (V - self.EK) + I_leak = self.gL * (V - self.EL) + dVdt = (- I_Na - I_K - I_leak + I) / 
self.C + return dVdt + + @property + def derivative(self): + return JointEq(self.dV, self.dm, self.dh, self.dn) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt) + V += self.sum_delta_inputs() + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.m.value = m + self.h.value = h + self.n.value = n + return self.spike.value + + def return_info(self): + return self.spike class HH(HHLTC): - r"""Hodgkin–Huxley neuron model. - - **Model Descriptions** + r"""Hodgkin–Huxley neuron model. - The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model [1]_ for the generation of - the nerve action potential is one of the most successful mathematical models of - a complex biological process that has ever been formulated. The basic concepts - expressed in the model have proved a valid approach to the study of bio-electrical - activity from the most primitive single-celled organisms such as *Paramecium*, - right through to the neurons within our own brains. + **Model Descriptions** - Mathematically, the model is given by, + The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model [1]_ for the generation of + the nerve action potential is one of the most successful mathematical models of + a complex biological process that has ever been formulated. The basic concepts + expressed in the model have proved a valid approach to the study of bio-electrical + activity from the most primitive single-celled organisms such as *Paramecium*, + right through to the neurons within our own brains. - .. math:: + Mathematically, the model is given by, - C \frac {dV} {dt} = -(\bar{g}_{Na} m^3 h (V &-E_{Na}) - + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) + .. 
math:: - \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x, \quad x\in {\rm{\{m, h, n\}}} + C \frac {dV} {dt} = -(\bar{g}_{Na} m^3 h (V &-E_{Na}) + + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) - &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} + \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x, \quad x\in {\rm{\{m, h, n\}}} - &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) + &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} - &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) + &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) - &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} + &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) - &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} + &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} - &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) + &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} - References:: + &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description - of membrane current and its application to conduction and excitation - in nerve." The Journal of physiology 117.4 (1952): 500. - .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model - .. [3] Ashwin, Peter, Stephen Coombes, and Rachel Nicks. "Mathematical - frameworks for oscillatory network dynamics in neuroscience." - The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92. + References:: - **Examples** + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description + of membrane current and its application to conduction and excitation + in nerve." The Journal of physiology 117.4 (1952): 500. + .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model + .. [3] Ashwin, Peter, Stephen Coombes, and Rachel Nicks. "Mathematical + frameworks for oscillatory network dynamics in neuroscience." + The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92. - Here is a simple usage example: + **Examples** - .. 
code-block:: python + Here is a simple usage example: - import brainpy as bp - import matplotlib.pyplot as plt - - neu = bp.dyn.HH(1,) - - inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600, ) - - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs = inputs) - - plt.plot(runner.mon['ts'], runner.mon['V']) - plt.plot(runner.mon.ts, inputs.value) # show input current - plt.legend(['Membrane potential/mA', 'Input current/mA'], loc='upper right') - - plt.tight_layout() - plt.show() - - The illustrated example of HH neuron model please see `this notebook <../neurons/HH_model.ipynb>`_. + .. code-block:: python + import brainpy as bp + import matplotlib.pyplot as plt - Parameters:: + neu = bp.dyn.HH(1,) - size: sequence of int, int - The size of the neuron group. - ENa: float, ArrayType, Initializer, callable - The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable - The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable - The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable - The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable - The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable - The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable - The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable - The membrane capacitance. Default is 1 ufarad. - V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - m_initializer: ArrayType, Initializer, callable - The initializer of m channel. - h_initializer: ArrayType, Initializer, callable - The initializer of h channel. - n_initializer: ArrayType, Initializer, callable - The initializer of n channel. 
- method: str - The numerical integration method. - name: str - The group name. + inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600, ) - """ + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs = inputs) - def dV(self, V, t, m, h, n, I): - I_Na = (self.gNa * m * m * m * h) * (V - self.ENa) - n2 = n * n - I_K = (self.gK * n2 * n2) * (V - self.EK) - I_leak = self.gL * (V - self.EL) - dVdt = (- I_Na - I_K - I_leak + I) / self.C - return dVdt + plt.plot(runner.mon['ts'], runner.mon['V']) + plt.plot(runner.mon.ts, inputs.value) # show input current + plt.legend(['Membrane potential/mA', 'Input current/mA'], loc='upper right') - @property - def derivative(self): - return JointEq(self.dV, self.dm, self.dh, self.dn) + plt.tight_layout() + plt.show() - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + The illustrated example of HH neuron model please see `this notebook <../neurons/HH_model.ipynb>`_. + + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium. Default is 50 mV. + gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. Default is 120 msiemens. + EK: float, ArrayType, Initializer, callable + The reversal potential of potassium. Default is -77 mV. + gK: float, ArrayType, Initializer, callable + The maximum conductance of potassium channel. Default is 36 msiemens. + EL: float, ArrayType, Initializer, callable + The reversal potential of learky channel. Default is -54.387 mV. + gL: float, ArrayType, Initializer, callable + The conductance of learky channel. Default is 0.03 msiemens. + V_th: float, ArrayType, Initializer, callable + The threshold of the membrane spike. Default is 20 mV. + C: float, ArrayType, Initializer, callable + The membrane capacitance. Default is 1 ufarad. 
+ V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + m_initializer: ArrayType, Initializer, callable + The initializer of m channel. + h_initializer: ArrayType, Initializer, callable + The initializer of h channel. + n_initializer: ArrayType, Initializer, callable + The initializer of n channel. + method: str + The numerical integration method. + name: str + The group name. + + """ + + def dV(self, V, t, m, h, n, I): + I_Na = (self.gNa * m * m * m * h) * (V - self.ENa) + n2 = n * n + I_K = (self.gK * n2 * n2) * (V - self.EK) + I_leak = self.gL * (V - self.EL) + dVdt = (- I_Na - I_K - I_leak + I) / self.C + return dVdt + + @property + def derivative(self): + return JointEq(self.dV, self.dm, self.dh, self.dn) + + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) class MorrisLecarLTC(NeuDyn): - r"""The Morris-Lecar neuron model with liquid time constant. - - **Model Descriptions** - - The Morris-Lecar model [4]_ (Also known as :math:`I_{Ca}+I_K`-model) - is a two-dimensional "reduced" excitation model applicable to - systems having two non-inactivating voltage-sensitive conductances. - This model was named after Cathy Morris and Harold Lecar, who - derived it in 1981. Because it is two-dimensional, the Morris-Lecar - model is one of the favorite conductance-based models in computational neuroscience. - - The original form of the model employed an instantaneously - responding voltage-sensitive Ca2+ conductance for excitation and a delayed - voltage-dependent K+ conductance for recovery. The equations of the model are: - - .. 
math:: - - \begin{aligned} - C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - - g_{Leak} (V - V_{Leak}) + I_{ext} \\ - \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} - \end{aligned} - - Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", - which is almost invariably the normalized :math:`K^+`-ion conductance, and - :math:`I_{ext}` is the applied current stimulus. - - - **Model Parameters** - - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_Ca 130 mV Equilibrium potentials of Ca+.(mV) - g_Ca 4.4 \ Maximum conductance of corresponding Ca+.(mS/cm2) - V_K -84 mV Equilibrium potentials of K+.(mV) - g_K 8 \ Maximum conductance of corresponding K+.(mS/cm2) - V_Leak -60 mV Equilibrium potentials of leak current.(mV) - g_Leak 2 \ Maximum conductance of leak current.(mS/cm2) - C 20 \ Membrane capacitance.(uF/cm2) - V1 -1.2 \ Potential at which M_inf = 0.5.(mV) - V2 18 \ Reciprocal of slope of voltage dependence of M_inf.(mV) - V3 2 \ Potential at which W_inf = 0.5.(mV) - V4 30 \ Reciprocal of slope of voltage dependence of W_inf.(mV) - phi 0.04 \ A temperature factor. (1/s) - V_th 10 mV The spike threshold. - ============= ============== ======== ======================================================= - - References:: - - .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. - .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model - .. 
[6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model - """ - - supported_modes = (bm.NonBatchingMode, bm.BatchingMode) - - def __init__( - self, - size: Union[int, Sequence[int]], - sharding: Any = None, - keep_size: bool = False, - mode: bm.Mode = None, - name: str = None, - method: str = 'exp_auto', - init_var: bool = True, - - # neuron parameters - V_Ca: Union[float, ArrayType, Callable] = 130., - g_Ca: Union[float, ArrayType, Callable] = 4.4, - V_K: Union[float, ArrayType, Callable] = -84., - g_K: Union[float, ArrayType, Callable] = 8., - V_leak: Union[float, ArrayType, Callable] = -60., - g_leak: Union[float, ArrayType, Callable] = 2., - C: Union[float, ArrayType, Callable] = 20., - V1: Union[float, ArrayType, Callable] = -1.2, - V2: Union[float, ArrayType, Callable] = 18., - V3: Union[float, ArrayType, Callable] = 2., - V4: Union[float, ArrayType, Callable] = 30., - phi: Union[float, ArrayType, Callable] = 0.04, - V_th: Union[float, ArrayType, Callable] = 10., - W_initializer: Union[Callable, ArrayType] = OneInit(0.02), - V_initializer: Union[Callable, ArrayType] = Uniform(-70., -60.), - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - sharding=sharding, - keep_size=keep_size, - mode=mode, - name=name, - method=method) - - # parameters - self.V_Ca = self.init_param(V_Ca) - self.g_Ca = self.init_param(g_Ca) - self.V_K = self.init_param(V_K) - self.g_K = self.init_param(g_K) - self.V_leak = self.init_param(V_leak) - self.g_leak = self.init_param(g_leak) - self.C = self.init_param(C) - self.V1 = self.init_param(V1) - self.V2 = self.init_param(V2) - self.V3 = self.init_param(V3) - self.V4 = self.init_param(V4) - self.phi = self.init_param(phi) - self.V_th = self.init_param(V_th) - - # initializers - self._W_initializer = is_initializer(W_initializer) - self._V_initializer = is_initializer(V_initializer) - - # noise - self.noise = init_noise(noise, self.varshape, num_vars=2) - # integral - 
if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # model - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.V = self.init_variable(self._V_initializer, batch_or_mode) - self.W = self.init_variable(self._W_initializer, batch_or_mode) - self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_or_mode) - - def dV(self, V, t, W, I): - I = self.sum_current_inputs(V, init=I) - M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) - I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) - I_K = self.g_K * W * (V - self.V_K) - I_Leak = self.g_leak * (V - self.V_leak) - dVdt = (- I_Ca - I_K - I_Leak + I) / self.C - return dVdt - - def dW(self, W, t, V): - tau_W = 1 / (self.phi * bm.cosh((V - self.V3) / (2 * self.V4))) - W_inf = (1 / 2) * (1 + bm.tanh((V - self.V3) / self.V4)) - dWdt = (W_inf - W) / tau_W - return dWdt - - @property - def derivative(self): - return JointEq(self.dV, self.dW) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - V, W = self.integral(self.V.value, self.W.value, t, x, dt) - V += self.sum_delta_inputs() - spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) - self.V.value = V - self.W.value = W - self.spike.value = spike - return spike - - def return_info(self): - return self.spike + r"""The Morris-Lecar neuron model with liquid time constant. + + **Model Descriptions** + + The Morris-Lecar model [4]_ (Also known as :math:`I_{Ca}+I_K`-model) + is a two-dimensional "reduced" excitation model applicable to + systems having two non-inactivating voltage-sensitive conductances. + This model was named after Cathy Morris and Harold Lecar, who + derived it in 1981. Because it is two-dimensional, the Morris-Lecar + model is one of the favorite conductance-based models in computational neuroscience. 
+ + The original form of the model employed an instantaneously + responding voltage-sensitive Ca2+ conductance for excitation and a delayed + voltage-dependent K+ conductance for recovery. The equations of the model are: + + .. math:: + + \begin{aligned} + C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - + g_{Leak} (V - V_{Leak}) + I_{ext} \\ + \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} + \end{aligned} + + Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", + which is almost invariably the normalized :math:`K^+`-ion conductance, and + :math:`I_{ext}` is the applied current stimulus. + + + **Model Parameters** + + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_Ca 130 mV Equilibrium potentials of Ca+.(mV) + g_Ca 4.4 \ Maximum conductance of corresponding Ca+.(mS/cm2) + V_K -84 mV Equilibrium potentials of K+.(mV) + g_K 8 \ Maximum conductance of corresponding K+.(mS/cm2) + V_Leak -60 mV Equilibrium potentials of leak current.(mV) + g_Leak 2 \ Maximum conductance of leak current.(mS/cm2) + C 20 \ Membrane capacitance.(uF/cm2) + V1 -1.2 \ Potential at which M_inf = 0.5.(mV) + V2 18 \ Reciprocal of slope of voltage dependence of M_inf.(mV) + V3 2 \ Potential at which W_inf = 0.5.(mV) + V4 30 \ Reciprocal of slope of voltage dependence of W_inf.(mV) + phi 0.04 \ A temperature factor. (1/s) + V_th 10 mV The spike threshold. + ============= ============== ======== ======================================================= + + References:: + + .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. + .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model + .. 
[6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model + """ + + supported_modes = (bm.NonBatchingMode, bm.BatchingMode) + + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_Ca: Union[float, ArrayType, Callable] = 130., + g_Ca: Union[float, ArrayType, Callable] = 4.4, + V_K: Union[float, ArrayType, Callable] = -84., + g_K: Union[float, ArrayType, Callable] = 8., + V_leak: Union[float, ArrayType, Callable] = -60., + g_leak: Union[float, ArrayType, Callable] = 2., + C: Union[float, ArrayType, Callable] = 20., + V1: Union[float, ArrayType, Callable] = -1.2, + V2: Union[float, ArrayType, Callable] = 18., + V3: Union[float, ArrayType, Callable] = 2., + V4: Union[float, ArrayType, Callable] = 30., + phi: Union[float, ArrayType, Callable] = 0.04, + V_th: Union[float, ArrayType, Callable] = 10., + W_initializer: Union[Callable, ArrayType] = OneInit(0.02), + V_initializer: Union[Callable, ArrayType] = Uniform(-70., -60.), + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + sharding=sharding, + keep_size=keep_size, + mode=mode, + name=name, + method=method) + + # parameters + self.V_Ca = self.init_param(V_Ca) + self.g_Ca = self.init_param(g_Ca) + self.V_K = self.init_param(V_K) + self.g_K = self.init_param(g_K) + self.V_leak = self.init_param(V_leak) + self.g_leak = self.init_param(g_leak) + self.C = self.init_param(C) + self.V1 = self.init_param(V1) + self.V2 = self.init_param(V2) + self.V3 = self.init_param(V3) + self.V4 = self.init_param(V4) + self.phi = self.init_param(phi) + self.V_th = self.init_param(V_th) + + # initializers + self._W_initializer = is_initializer(W_initializer) + self._V_initializer = is_initializer(V_initializer) + + # noise + self.noise = init_noise(noise, self.varshape, num_vars=2) + # integral + 
if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # model + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.V = self.init_variable(self._V_initializer, batch_or_mode) + self.W = self.init_variable(self._W_initializer, batch_or_mode) + self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_or_mode) + + def dV(self, V, t, W, I): + I = self.sum_current_inputs(V, init=I) + M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) + I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) + I_K = self.g_K * W * (V - self.V_K) + I_Leak = self.g_leak * (V - self.V_leak) + dVdt = (- I_Ca - I_K - I_Leak + I) / self.C + return dVdt + + def dW(self, W, t, V): + tau_W = 1 / (self.phi * bm.cosh((V - self.V3) / (2 * self.V4))) + W_inf = (1 / 2) * (1 + bm.tanh((V - self.V3) / self.V4)) + dWdt = (W_inf - W) / tau_W + return dWdt + + @property + def derivative(self): + return JointEq(self.dV, self.dW) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + V, W = self.integral(self.V.value, self.W.value, t, x, dt) + V += self.sum_delta_inputs() + spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.W.value = W + self.spike.value = spike + return spike + + def return_info(self): + return self.spike class MorrisLecar(MorrisLecarLTC): - r"""The Morris-Lecar neuron model. - - **Model Descriptions** - - The Morris-Lecar model [4]_ (Also known as :math:`I_{Ca}+I_K`-model) - is a two-dimensional "reduced" excitation model applicable to - systems having two non-inactivating voltage-sensitive conductances. - This model was named after Cathy Morris and Harold Lecar, who - derived it in 1981. Because it is two-dimensional, the Morris-Lecar - model is one of the favorite conductance-based models in computational neuroscience. 
- - The original form of the model employed an instantaneously - responding voltage-sensitive Ca2+ conductance for excitation and a delayed - voltage-dependent K+ conductance for recovery. The equations of the model are: - - .. math:: - - \begin{aligned} - C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - - g_{Leak} (V - V_{Leak}) + I_{ext} \\ - \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} - \end{aligned} - - Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", - which is almost invariably the normalized :math:`K^+`-ion conductance, and - :math:`I_{ext}` is the applied current stimulus. - - **Model Parameters** - - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_Ca 130 mV Equilibrium potentials of Ca+.(mV) - g_Ca 4.4 \ Maximum conductance of corresponding Ca+.(mS/cm2) - V_K -84 mV Equilibrium potentials of K+.(mV) - g_K 8 \ Maximum conductance of corresponding K+.(mS/cm2) - V_Leak -60 mV Equilibrium potentials of leak current.(mV) - g_Leak 2 \ Maximum conductance of leak current.(mS/cm2) - C 20 \ Membrane capacitance.(uF/cm2) - V1 -1.2 \ Potential at which M_inf = 0.5.(mV) - V2 18 \ Reciprocal of slope of voltage dependence of M_inf.(mV) - V3 2 \ Potential at which W_inf = 0.5.(mV) - V4 30 \ Reciprocal of slope of voltage dependence of W_inf.(mV) - phi 0.04 \ A temperature factor. (1/s) - V_th 10 mV The spike threshold. - ============= ============== ======== ======================================================= - - References:: - - .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. - .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model - .. 
[6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model - """ - - def dV(self, V, t, W, I): - M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) - I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) - I_K = self.g_K * W * (V - self.V_K) - I_Leak = self.g_leak * (V - self.V_leak) - dVdt = (- I_Ca - I_K - I_Leak + I) / self.C - return dVdt - - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + r"""The Morris-Lecar neuron model. + + **Model Descriptions** + + The Morris-Lecar model [4]_ (Also known as :math:`I_{Ca}+I_K`-model) + is a two-dimensional "reduced" excitation model applicable to + systems having two non-inactivating voltage-sensitive conductances. + This model was named after Cathy Morris and Harold Lecar, who + derived it in 1981. Because it is two-dimensional, the Morris-Lecar + model is one of the favorite conductance-based models in computational neuroscience. + + The original form of the model employed an instantaneously + responding voltage-sensitive Ca2+ conductance for excitation and a delayed + voltage-dependent K+ conductance for recovery. The equations of the model are: + + .. math:: + + \begin{aligned} + C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - + g_{Leak} (V - V_{Leak}) + I_{ext} \\ + \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} + \end{aligned} + + Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", + which is almost invariably the normalized :math:`K^+`-ion conductance, and + :math:`I_{ext}` is the applied current stimulus. 
+ + **Model Parameters** + + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_Ca 130 mV Equilibrium potentials of Ca+.(mV) + g_Ca 4.4 \ Maximum conductance of corresponding Ca+.(mS/cm2) + V_K -84 mV Equilibrium potentials of K+.(mV) + g_K 8 \ Maximum conductance of corresponding K+.(mS/cm2) + V_Leak -60 mV Equilibrium potentials of leak current.(mV) + g_Leak 2 \ Maximum conductance of leak current.(mS/cm2) + C 20 \ Membrane capacitance.(uF/cm2) + V1 -1.2 \ Potential at which M_inf = 0.5.(mV) + V2 18 \ Reciprocal of slope of voltage dependence of M_inf.(mV) + V3 2 \ Potential at which W_inf = 0.5.(mV) + V4 30 \ Reciprocal of slope of voltage dependence of W_inf.(mV) + phi 0.04 \ A temperature factor. (1/s) + V_th 10 mV The spike threshold. + ============= ============== ======== ======================================================= + + References:: + + .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. + .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model + .. [6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model + """ + + def dV(self, V, t, W, I): + M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) + I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) + I_K = self.g_K * W * (V - self.V_K) + I_Leak = self.g_leak * (V - self.V_leak) + dVdt = (- I_Ca - I_K - I_Leak + I) / self.C + return dVdt + + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) class WangBuzsakiHHLTC(NeuDyn): - r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model with liquid time constant. - - Each model is described by a single compartment and obeys the current balance equation: - - .. 
math:: - - C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}} - - where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the - injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current - :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance - :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant - :math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`. - - The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion - currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the - Hodgkin-Huxley type (Hodgkin and Huxley, 1952). The transient sodium current - :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, - where the activation variable :math:`m` is assumed fast and substituted by its steady-state - function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ; - :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`. - The inactivation variable :math:`h` obeys a first-order kinetics: - - .. math:: - - \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) + r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model with liquid time constant. - where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and - :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ; - :math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .` + Each model is described by a single compartment and obeys the current balance equation: - The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, - where the activation variable :math:`n` obeys the following equation: + .. math:: - .. 
math:: - - \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) - - with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and - :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and - :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. - - - References:: - - .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic - inhibition in a hippocampal interneuronal network model. Journal of - neuroscience, 16(20), pp.6402-6413. - - **Examples** - - Here is a simple usage example: - - .. code-block:: python - - import brainpy as bp - import matplotlib.pyplot as plt - - neu = bp.dyn.WangBuzsakiHHLTC(1, ) - - inputs = bp.inputs.ramp_input(.1, 1, 700, 100, 600, ) - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) - plt.plot(runner.mon['ts'], runner.mon['V']) - plt.legend(['Membrane potential/mA', loc='upper right') - plt.tight_layout() - plt.show() - - Parameters:: - - size: sequence of int, int - The size of the neuron group. - ENa: float, ArrayType, Initializer, callable - The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable - The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable - The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable - The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable - The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable - The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable - The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable - The membrane capacitance. Default is 1 ufarad. - phi: float, ArrayType, Initializer, callable - The temperature regulator constant. 
- V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - h_initializer: ArrayType, Initializer, callable - The initializer of h channel. - n_initializer: ArrayType, Initializer, callable - The initializer of n channel. - method: str - The numerical integration method. - name: str - The group name. - - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - sharding: Any = None, - keep_size: bool = False, - mode: bm.Mode = None, - name: str = None, - method: str = 'exp_auto', - init_var: bool = True, - - # neuron parameters - ENa: Union[float, ArrayType, Callable] = 55., - gNa: Union[float, ArrayType, Callable] = 35., - EK: Union[float, ArrayType, Callable] = -90., - gK: Union[float, ArrayType, Callable] = 9., - EL: Union[float, ArrayType, Callable] = -65, - gL: Union[float, ArrayType, Callable] = 0.1, - V_th: Union[float, ArrayType, Callable] = 20., - phi: Union[float, ArrayType, Callable] = 5.0, - C: Union[float, ArrayType, Callable] = 1.0, - V_initializer: Union[Callable, ArrayType] = OneInit(-65.), - h_initializer: Union[Callable, ArrayType] = OneInit(0.6), - n_initializer: Union[Callable, ArrayType] = OneInit(0.32), - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - sharding=sharding, - keep_size=keep_size, - mode=mode, - name=name, - method=method) - - # parameters - self.ENa = self.init_param(ENa) - self.EK = self.init_param(EK) - self.EL = self.init_param(EL) - self.gNa = self.init_param(gNa) - self.gK = self.init_param(gK) - self.gL = self.init_param(gL) - self.phi = self.init_param(phi) - self.C = self.init_param(C) - self.V_th = self.init_param(V_th) - - # initializers - self._h_initializer = is_initializer(h_initializer) - self._n_initializer = is_initializer(n_initializer) - self._V_initializer = is_initializer(V_initializer) - - # noise - self.noise = init_noise(noise, self.varshape, num_vars=3) - # integral - if self.noise is not None: 
- self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # model - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - self.V = self.init_variable(self._V_initializer, batch_size) - self.h = self.init_variable(self._h_initializer, batch_size) - self.n = self.init_variable(self._n_initializer, batch_size) - self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def m_inf(self, V): - # alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) - alpha = 1. / bm.exprel(-0.1 * (V + 35)) - beta = 4. * bm.exp(-(V + 60.) / 18.) - return alpha / (alpha + beta) - - def dh(self, h, t, V): - alpha = 0.07 * bm.exp(-(V + 58) / 20) - beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) - dhdt = alpha * (1 - h) - beta * h - return self.phi * dhdt - - def dn(self, n, t, V): - # alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) - alpha = 1. / bm.exprel(-0.1 * (V + 34)) - beta = 0.125 * bm.exp(-(V + 44) / 80) - dndt = alpha * (1 - n) - beta * n - return self.phi * dndt - - def dV(self, V, t, h, n, I): - I = self.sum_current_inputs(V, init=I) - INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + I) / self.C - return dVdt - - @property - def derivative(self): - return JointEq(self.dV, self.dh, self.dn) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. 
if x is None else x - - V, h, n = self.integral(self.V.value, self.h.value, self.n.value, t, x, dt) - V += self.sum_delta_inputs() - self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) - self.V.value = V - self.h.value = h - self.n.value = n - return self.spike.value - - def return_info(self): - return self.spike + C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}} + + where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the + injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current + :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance + :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant + :math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`. + + The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion + currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the + Hodgkin-Huxley type (Hodgkin and Huxley, 1952). The transient sodium current + :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, + where the activation variable :math:`m` is assumed fast and substituted by its steady-state + function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ; + :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`. + The inactivation variable :math:`h` obeys a first-order kinetics: + + .. 
math:: + + \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) + + where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and + :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ; + :math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .` + + The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, + where the activation variable :math:`n` obeys the following equation: + + .. math:: + + \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) + + with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and + :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and + :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. + + + References:: + + .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic + inhibition in a hippocampal interneuronal network model. Journal of + neuroscience, 16(20), pp.6402-6413. + + **Examples** + + Here is a simple usage example: + + .. code-block:: python + + import brainpy as bp + import matplotlib.pyplot as plt + + neu = bp.dyn.WangBuzsakiHHLTC(1, ) + + inputs = bp.inputs.ramp_input(.1, 1, 700, 100, 600, ) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + plt.plot(runner.mon['ts'], runner.mon['V']) + plt.legend(['Membrane potential/mA', loc='upper right') + plt.tight_layout() + plt.show() + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium. Default is 50 mV. + gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. Default is 120 msiemens. + EK: float, ArrayType, Initializer, callable + The reversal potential of potassium. Default is -77 mV. + gK: float, ArrayType, Initializer, callable + The maximum conductance of potassium channel. Default is 36 msiemens. 
+ EL: float, ArrayType, Initializer, callable + The reversal potential of learky channel. Default is -54.387 mV. + gL: float, ArrayType, Initializer, callable + The conductance of learky channel. Default is 0.03 msiemens. + V_th: float, ArrayType, Initializer, callable + The threshold of the membrane spike. Default is 20 mV. + C: float, ArrayType, Initializer, callable + The membrane capacitance. Default is 1 ufarad. + phi: float, ArrayType, Initializer, callable + The temperature regulator constant. + V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + h_initializer: ArrayType, Initializer, callable + The initializer of h channel. + n_initializer: ArrayType, Initializer, callable + The initializer of n channel. + method: str + The numerical integration method. + name: str + The group name. + + + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + ENa: Union[float, ArrayType, Callable] = 55., + gNa: Union[float, ArrayType, Callable] = 35., + EK: Union[float, ArrayType, Callable] = -90., + gK: Union[float, ArrayType, Callable] = 9., + EL: Union[float, ArrayType, Callable] = -65, + gL: Union[float, ArrayType, Callable] = 0.1, + V_th: Union[float, ArrayType, Callable] = 20., + phi: Union[float, ArrayType, Callable] = 5.0, + C: Union[float, ArrayType, Callable] = 1.0, + V_initializer: Union[Callable, ArrayType] = OneInit(-65.), + h_initializer: Union[Callable, ArrayType] = OneInit(0.6), + n_initializer: Union[Callable, ArrayType] = OneInit(0.32), + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + sharding=sharding, + keep_size=keep_size, + mode=mode, + name=name, + method=method) + + # parameters + self.ENa = self.init_param(ENa) + self.EK = self.init_param(EK) + self.EL = 
self.init_param(EL) + self.gNa = self.init_param(gNa) + self.gK = self.init_param(gK) + self.gL = self.init_param(gL) + self.phi = self.init_param(phi) + self.C = self.init_param(C) + self.V_th = self.init_param(V_th) + + # initializers + self._h_initializer = is_initializer(h_initializer) + self._n_initializer = is_initializer(n_initializer) + self._V_initializer = is_initializer(V_initializer) + + # noise + self.noise = init_noise(noise, self.varshape, num_vars=3) + # integral + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # model + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.V = self.init_variable(self._V_initializer, batch_size) + self.h = self.init_variable(self._h_initializer, batch_size) + self.n = self.init_variable(self._n_initializer, batch_size) + self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def m_inf(self, V): + # alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) + alpha = 1. / bm.exprel(-0.1 * (V + 35)) + beta = 4. * bm.exp(-(V + 60.) / 18.) + return alpha / (alpha + beta) + + def dh(self, h, t, V): + alpha = 0.07 * bm.exp(-(V + 58) / 20) + beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) + dhdt = alpha * (1 - h) - beta * h + return self.phi * dhdt + + def dn(self, n, t, V): + # alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + alpha = 1. 
/ bm.exprel(-0.1 * (V + 34)) + beta = 0.125 * bm.exp(-(V + 44) / 80) + dndt = alpha * (1 - n) - beta * n + return self.phi * dndt + + def dV(self, V, t, h, n, I): + I = self.sum_current_inputs(V, init=I) + INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + I) / self.C + return dVdt + + @property + def derivative(self): + return JointEq(self.dV, self.dh, self.dn) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + V, h, n = self.integral(self.V.value, self.h.value, self.n.value, t, x, dt) + V += self.sum_delta_inputs() + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.h.value = h + self.n.value = n + return self.spike.value + + def return_info(self): + return self.spike class WangBuzsakiHH(WangBuzsakiHHLTC): - r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model. - - Each model is described by a single compartment and obeys the current balance equation: - - .. math:: - - C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}} - - where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the - injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current - :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance - :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant - :math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`. - - The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion - currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the - Hodgkin-Huxley type (Hodgkin and Huxley, 1952). 
The transient sodium current - :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, - where the activation variable :math:`m` is assumed fast and substituted by its steady-state - function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ; - :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`. - The inactivation variable :math:`h` obeys a first-order kinetics: - - .. math:: - - \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) - - where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and - :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ; - :math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .` - - The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, - where the activation variable :math:`n` obeys the following equation: - - .. math:: - - \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) - - with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and - :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and - :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. - - - References:: - - .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic - inhibition in a hippocampal interneuronal network model. Journal of - neuroscience, 16(20), pp.6402-6413. - - **Examples** - - Here is an example: - - .. code-block:: python - - import brainpy as bp - import matplotlib.pyplot as plt - - neu = bp.dyn.WangBuzsakiHH(1, ) - - inputs = bp.inputs.ramp_input(.1, 1, 700, 100, 600, ) - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) - plt.plot(runner.mon['ts'], runner.mon['V']) - plt.legend(['Membrane potential/mA', loc='upper right') - plt.tight_layout() - plt.show() - - Parameters:: - - size: sequence of int, int - The size of the neuron group. 
- ENa: float, ArrayType, Initializer, callable - The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable - The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable - The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable - The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable - The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable - The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable - The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable - The membrane capacitance. Default is 1 ufarad. - phi: float, ArrayType, Initializer, callable - The temperature regulator constant. - V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - h_initializer: ArrayType, Initializer, callable - The initializer of h channel. - n_initializer: ArrayType, Initializer, callable - The initializer of n channel. - method: str - The numerical integration method. - name: str - The group name. - - """ - - def dV(self, V, t, h, n, I): - INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + I) / self.C - return dVdt - - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model. + + Each model is described by a single compartment and obeys the current balance equation: + + .. 
math:: + + C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}} + + where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the + injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current + :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance + :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant + :math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`. + + The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion + currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the + Hodgkin-Huxley type (Hodgkin and Huxley, 1952). The transient sodium current + :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, + where the activation variable :math:`m` is assumed fast and substituted by its steady-state + function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ; + :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`. + The inactivation variable :math:`h` obeys a first-order kinetics: + + .. math:: + + \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) + + where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and + :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ; + :math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .` + + The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, + where the activation variable :math:`n` obeys the following equation: + + .. 
math:: + + \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) + + with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and + :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and + :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. + + + References:: + + .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic + inhibition in a hippocampal interneuronal network model. Journal of + neuroscience, 16(20), pp.6402-6413. + + **Examples** + + Here is an example: + + .. code-block:: python + + import brainpy as bp + import matplotlib.pyplot as plt + + neu = bp.dyn.WangBuzsakiHH(1, ) + + inputs = bp.inputs.ramp_input(.1, 1, 700, 100, 600, ) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + plt.plot(runner.mon['ts'], runner.mon['V']) + plt.legend(['Membrane potential/mV'], loc='upper right') + plt.tight_layout() + plt.show() + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium. Default is 55 mV. + gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. Default is 35 msiemens. + EK: float, ArrayType, Initializer, callable + The reversal potential of potassium. Default is -90 mV. + gK: float, ArrayType, Initializer, callable + The maximum conductance of potassium channel. Default is 9 msiemens. + EL: float, ArrayType, Initializer, callable + The reversal potential of leaky channel. Default is -65 mV. + gL: float, ArrayType, Initializer, callable + The conductance of leaky channel. Default is 0.1 msiemens. + V_th: float, ArrayType, Initializer, callable + The threshold of the membrane spike. Default is 20 mV. + C: float, ArrayType, Initializer, callable + The membrane capacitance. Default is 1 ufarad. + phi: float, ArrayType, Initializer, callable + The temperature regulator constant.
+ V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + h_initializer: ArrayType, Initializer, callable + The initializer of h channel. + n_initializer: ArrayType, Initializer, callable + The initializer of n channel. + method: str + The numerical integration method. + name: str + The group name. + + """ + + def dV(self, V, t, h, n, I): + INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + I) / self.C + return dVdt + + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index 30b8b29ca..ef0d2d19c 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -13,141 +13,141 @@ from brainpy.types import Shape, ArrayType, Sharding __all__ = [ - 'IF', - 'IFLTC', - 'Lif', - 'LifLTC', - 'LifRef', - 'LifRefLTC', - 'ExpIF', - 'ExpIFLTC', - 'ExpIFRef', - 'ExpIFRefLTC', - 'AdExIF', - 'AdExIFLTC', - 'AdExIFRef', - 'AdExIFRefLTC', - 'QuaIF', - 'QuaIFLTC', - 'QuaIFRef', - 'QuaIFRefLTC', - 'AdQuaIF', - 'AdQuaIFLTC', - 'AdQuaIFRef', - 'AdQuaIFRefLTC', - 'Gif', - 'GifLTC', - 'GifRef', - 'GifRefLTC', - 'Izhikevich', - 'IzhikevichLTC', - 'IzhikevichRef', - 'IzhikevichRefLTC', + 'IF', + 'IFLTC', + 'Lif', + 'LifLTC', + 'LifRef', + 'LifRefLTC', + 'ExpIF', + 'ExpIFLTC', + 'ExpIFRef', + 'ExpIFRefLTC', + 'AdExIF', + 'AdExIFLTC', + 'AdExIFRef', + 'AdExIFRefLTC', + 'QuaIF', + 'QuaIFLTC', + 'QuaIFRef', + 'QuaIFRefLTC', + 'AdQuaIF', + 'AdQuaIFLTC', + 'AdQuaIFRef', + 'AdQuaIFRefLTC', + 'Gif', + 'GifLTC', + 'GifRef', + 'GifRefLTC', + 'Izhikevich', + 'IzhikevichLTC', + 'IzhikevichRef', + 'IzhikevichRefLTC', ] class IFLTC(GradNeuDyn): - r"""Leaky Integrator Model %s. + r"""Leaky Integrator Model %s. 
- **Model Descriptions** + **Model Descriptions** - This class implements a leaky integrator model, in which its dynamics is - given by: + This class implements a leaky integrator model, in which its dynamics is + given by: - .. math:: + .. math:: - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`\tau` is the time constant, and :math:`R` is the - resistance. + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`\tau` is the time constant, and :math:`R` is the + resistance. - Args: - %s - %s - %s - """ + Args: + %s + %s + %s + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_rest: Union[float, ArrayType, Callable] = 0., - R: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - - # parameters - self.V_rest = self.offset_scaling(self.init_param(V_rest)) - self.tau = self.init_param(tau) - self.R = self.init_param(R) - - # initializers - self._V_initializer = is_initializer(V_initializer) - - # integral - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def derivative(self, V, 
t, I): - I = self.sum_current_inputs(V, init=I) - return (-V + self.V_rest + self.R * I) / self.tau - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - self.V.value = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() - - return self.V.value - - def return_info(self): - return self.V + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = 0., + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + + # parameters + self.V_rest = self.offset_scaling(self.init_param(V_rest)) + self.tau = self.init_param(tau) + self.R = self.init_param(R) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + I = self.sum_current_inputs(V, init=I) + return (-V + self.V_rest + self.R * I) / self.tau + + def reset_state(self, 
batch_size=None, **kwargs): + self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + self.V.value = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() + + return self.V.value + + def return_info(self): + return self.V class IF(IFLTC): - def derivative(self, V, t, I): - return (-V + self.V_rest + self.R * I) / self.tau + def derivative(self, V, t, I): + return (-V + self.V_rest + self.R * I) / self.tau - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) IF.__doc__ = IFLTC.__doc__ % ('', if_doc, pneu_doc, dpneu_doc) @@ -155,199 +155,199 @@ def update(self, x=None): class LifLTC(GradNeuDyn): - r"""Leaky integrate-and-fire neuron model with liquid time-constant. + r"""Leaky integrate-and-fire neuron model with liquid time-constant. - The formal equations of a LIF model [1]_ is given by: + The formal equations of a LIF model [1]_ is given by: - .. math:: + .. math:: - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ - \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`V_{reset}` is the reset membrane potential, - :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, - and :math:`I` is the time-variant synaptic inputs. 
+ where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + and :math:`I` is the time-variant synaptic inputs. - .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model - neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. - **Examples** + **Examples** - There is an example usage: mustang u r lvd by the blonde boy + There is an example usage: - .. code-block:: python + .. code-block:: python - import brainpy as bp + import brainpy as bp - lif = bp.dyn.LifLTC(1) + lif = bp.dyn.LifLTC(1) - # raise input current from 4 mA to 40 mA - inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600,) + # raise input current from 4 mA to 40 mA + inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600,) - runner = bp.DSRunner(lif, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(lif, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s + Args: + %s + %s + %s - """ + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_rest: Union[float, ArrayType, Callable] = 0., - V_reset: Union[float, ArrayType, Callable] = -5., - V_th: Union[float, ArrayType, Callable] =
20., - R: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # noise - noise: Optional[Union[float, ArrayType, Callable]] = None, - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - - # parameters - self.V_rest = self.offset_scaling(self.init_param(V_rest)) - self.V_reset = self.offset_scaling(self.init_param(V_reset)) - self.V_th = self.offset_scaling(self.init_param(V_th)) - self.tau = self.init_param(tau) - self.R = self.init_param(R) - - # initializers - self._V_initializer = is_initializer(V_initializer) - - # noise - self.noise = init_noise(noise, self.varshape) - - # integral - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def derivative(self, V, t, I): - I = self.sum_current_inputs(V, init=I) - return (-V + self.V_rest + self.R * I) / self.tau - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. 
if x is None else x - - # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() - - # spike, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike - else: - raise ValueError - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - - self.V.value = V - self.spike.value = spike - return spike - - def return_info(self): - return self.spike + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = 0., + V_reset: Union[float, ArrayType, Callable] = -5., + V_th: Union[float, ArrayType, Callable] = 20., + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Optional[Union[float, ArrayType, Callable]] = None, + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + + # parameters + self.V_rest = self.offset_scaling(self.init_param(V_rest)) + self.V_reset = self.offset_scaling(self.init_param(V_reset)) + self.V_th = self.offset_scaling(self.init_param(V_th)) + self.tau = self.init_param(tau) + self.R = self.init_param(R) + + # 
initializers + self._V_initializer = is_initializer(V_initializer) + + # noise + self.noise = init_noise(noise, self.varshape) + + # integral + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + I = self.sum_current_inputs(V, init=I) + return (-V + self.V_rest + self.R * I) / self.tau + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike + else: + raise ValueError + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + + self.V.value = V + self.spike.value = spike + return spike + + def return_info(self): + return self.spike class Lif(LifLTC): - r"""Leaky integrate-and-fire neuron model. + r"""Leaky integrate-and-fire neuron model. - The formal equations of a LIF model [1]_ is given by: + The formal equations of a LIF model [1]_ is given by: - .. math:: + .. 
math:: - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ - \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`V_{reset}` is the reset membrane potential, - :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, - and :math:`I` is the time-variant synaptic inputs. + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + and :math:`I` is the time-variant synaptic inputs. - .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model - neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. 
code-block:: python - import brainpy as bp + import brainpy as bp - lif = bp.dyn.Lif(1) + lif = bp.dyn.Lif(1) - # raise input current from 4 mA to 40 mA - inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600,) - runner = bp.DSRunner(lif, monitors=['V']) - runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + # raise input current from 4 mA to 40 mA + inputs = bp.inputs.ramp_input(4, 40, 700, 100, 600,) + runner = bp.DSRunner(lif, monitors=['V']) + runner.run(inputs=inputs) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s + Args: + %s + %s + %s - """ + """ - def derivative(self, V, t, I): - return (-V + self.V_rest + self.R * I) / self.tau + def derivative(self, V, t, I): + return (-V + self.V_rest + self.R * I) / self.tau - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) Lif.__doc__ = Lif.__doc__ % (lif_doc, pneu_doc, dpneu_doc) @@ -355,219 +355,219 @@ def update(self, x=None): class LifRefLTC(LifLTC): - r"""Leaky integrate-and-fire neuron model with liquid time-constant which has refractory periods . + r"""Leaky integrate-and-fire neuron model with liquid time-constant which has refractory periods . - The formal equations of a LIF model [1]_ is given by: + The formal equations of a LIF model [1]_ is given by: - .. math:: + .. 
math:: - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ - \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad - \text{last} \quad \tau_{ref} \quad \text{ms} + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad + \text{last} \quad \tau_{ref} \quad \text{ms} - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`V_{reset}` is the reset membrane potential, - :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, - :math:`\tau_{ref}` is the refractory time period, - and :math:`I` is the time-variant synaptic inputs. + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + :math:`\tau_{ref}` is the refractory time period, + and :math:`I` is the time-variant synaptic inputs. - .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model - neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. 
code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.LifRefLTC(1, ) + neu = bp.dyn.LifRefLTC(1, ) - # example for section input - inputs = bp.inputs.section_input([0., 21., 0.], [100., 300., 100.]) + # example for section input + inputs = bp.inputs.section_input([0., 21., 0.], [100., 300., 100.]) - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s - %s + Args: + %s + %s + %s + %s - """ + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sharding] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - detach_spk: bool = False, - spk_reset: str = 'soft', - method: str = 'exp_auto', - name: Optional[str] = None, - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # old neuron parameter - V_rest: Union[float, ArrayType, Callable] = 0., - V_reset: Union[float, ArrayType, Callable] = -5., - V_th: Union[float, ArrayType, Callable] = 20., - R: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # new neuron parameter - tau_ref: Union[float, ArrayType, Callable] = 0., - ref_var: bool = False, - - # noise - noise: Optional[Union[float, ArrayType, Callable]] = None, - ): - # initialization - super().__init__( - size=size, - name=name, - keep_size=keep_size, - mode=mode, - method=method, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - - init_var=False, - scaling=scaling, - - V_rest=V_rest, - V_reset=V_reset, - V_th=V_th, - R=R, - tau=tau, - V_initializer=V_initializer, - - noise=noise, - ) - - # parameters - 
self.ref_var = ref_var - self.tau_ref = self.init_param(tau_ref) - - # variables - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - super().reset_state(batch_size, **kwargs) - self.t_last_spike = self.init_variable(bm.ones, batch_size) - self.t_last_spike.fill_(-1e7) - if self.ref_var: - self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() - - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike_no_grad - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike_no_grad - else: - raise ValueError - spike_ = spike_no_grad > 0. 
- # will be used in other place, like Delta Synapse, so stop its gradient - if self.ref_var: - self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - if self.ref_var: - self.refractory.value = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.V.value = V - self.spike.value = spike - self.t_last_spike.value = t_last_spike - return spike + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + detach_spk: bool = False, + spk_reset: str = 'soft', + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = 0., + V_reset: Union[float, ArrayType, Callable] = -5., + V_th: Union[float, ArrayType, Callable] = 20., + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + + # noise + noise: Optional[Union[float, ArrayType, Callable]] = None, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + + init_var=False, + scaling=scaling, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + R=R, + tau=tau, + V_initializer=V_initializer, + + noise=noise, + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # variables + if init_var: + 
self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + super().reset_state(batch_size, **kwargs) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike_no_grad + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike_no_grad + else: + raise ValueError + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike class LifRef(LifRefLTC): - r"""Leaky integrate-and-fire neuron model %s which has refractory periods. + r"""Leaky integrate-and-fire neuron model %s which has refractory periods. 
- The formal equations of a LIF model [1]_ is given by: + The formal equations of a LIF model [1]_ is given by: - .. math:: + .. math:: - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ - \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad - \text{last} \quad \tau_{ref} \quad \text{ms} + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad + \text{last} \quad \tau_{ref} \quad \text{ms} - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`V_{reset}` is the reset membrane potential, - :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, - :math:`\tau_{ref}` is the refractory time period, - and :math:`I` is the time-variant synaptic inputs. + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + :math:`\tau_{ref}` is the refractory time period, + and :math:`I` is the time-variant synaptic inputs. - .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model - neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. 
code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.LifRef(1, ) + neu = bp.dyn.LifRef(1, ) - # example for section input - inputs = bp.inputs.section_input([0., 21., 0.], [100., 300., 100.]) + # example for section input + inputs = bp.inputs.section_input([0., 21., 0.], [100., 300., 100.]) - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s - %s + Args: + %s + %s + %s + %s - """ + """ - def derivative(self, V, t, I): - return (-V + self.V_rest + self.R * I) / self.tau + def derivative(self, V, t, I): + return (-V + self.V_rest + self.R * I) / self.tau - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) LifRef.__doc__ = LifRefLTC.__doc__ % (lif_doc, pneu_doc, dpneu_doc, ref_doc) @@ -575,7 +575,218 @@ def update(self, x=None): class ExpIFLTC(GradNeuDyn): - r"""Exponential integrate-and-fire neuron model with liquid time-constant. + r"""Exponential integrate-and-fire neuron model with liquid time-constant. + + **Model Descriptions** + + In the exponential integrate-and-fire model [1]_, the differential + equation for the membrane potential is given by + + .. math:: + + \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ + \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} + + This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` + and "threshold" :math:`\vartheta_{rh}`. 
+ + The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` + defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to + :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, + where :math:`\tau_{\rm ref}` is an absolute refractory time. + If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, + its exact value does not play any role. The reason is that the upswing of the action + potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in + an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical + convenience. For a formal mathematical analysis of the model, the threshold can be pushed + to infinity. + + The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk + and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. + It is one of the prominent examples of a precise theoretical prediction in computational + neuroscience that was later confirmed by experimental neuroscience. + + Two important remarks: + + - (i) The right-hand side of the above equation contains a nonlinearity + that can be directly extracted from experimental data [3]_. In this sense the exponential + nonlinearity is not an arbitrary choice but directly supported by experimental evidence. + - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing + rate for constant input, and the linear response to fluctuations, even in the presence + of input noise [4]_. + + **References** + + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). + Neuronal dynamics: From single neurons to networks and models + of cognition. 
Cambridge University Press. + .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, + Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves + are reliable predictors of naturalistic pyramidal-neuron voltage + traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. + .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear + integrate-and-fire neurons to modulated current-based and + conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. + .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire + + **Examples** + + There is a simple usage example:: + + import brainpy as bp + + neu = bp.dyn.ExpIFLTC(1, ) + + # example for section input + inputs = bp.inputs.section_input([0., 5., 0.], [100., 300., 100.]) + + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + + + **Model Parameters** + + ============= ============== ======== =================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- --------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + R 1 \ Membrane resistance. + tau 10 \ Membrane time constant. Compute by R * C. + tau_ref 1.7 \ Refractory period length. + ============= ============== ======== =================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. 
+ input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + + """ + + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -55., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + + # parameters + self.V_rest = self.offset_scaling(self.init_param(V_rest)) + self.V_reset = self.offset_scaling(self.init_param(V_reset)) + self.V_th = self.offset_scaling(self.init_param(V_th)) + self.V_T = self.offset_scaling(self.init_param(V_T)) + self.delta_T = self.std_scaling(self.init_param(delta_T)) + self.tau = self.init_param(tau) + self.R = self.init_param(R) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # noise + self.noise = init_noise(noise, self.varshape) + # 
integral + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + I = self.sum_current_inputs(V, init=I) + exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau + return dvdt + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike + else: + raise ValueError + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + + self.V.value = V + self.spike.value = spike + return spike + + def return_info(self): + return self.spike + + +class ExpIF(ExpIFLTC): + r"""Exponential integrate-and-fire neuron model. 
**Model Descriptions** @@ -638,7 +849,7 @@ class ExpIFLTC(GradNeuDyn): import brainpy as bp - neu = bp.dyn.ExpIFLTC(1, ) + neu = bp.dyn.ExpIF(1, ) # example for section input inputs = bp.inputs.section_input([0., 5., 0.], [100., 300., 100.]) @@ -677,1276 +888,1065 @@ class ExpIFLTC(GradNeuDyn): ================== ================= ========================================================= - """ - - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -55., - V_T: Union[float, ArrayType, Callable] = -59.9, - delta_T: Union[float, ArrayType, Callable] = 3.48, - R: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - - # parameters - self.V_rest = self.offset_scaling(self.init_param(V_rest)) - self.V_reset = self.offset_scaling(self.init_param(V_reset)) - self.V_th = self.offset_scaling(self.init_param(V_th)) - self.V_T = self.offset_scaling(self.init_param(V_T)) - self.delta_T = self.std_scaling(self.init_param(delta_T)) - self.tau = self.init_param(tau) - self.R = self.init_param(R) - - # initializers - self._V_initializer = is_initializer(V_initializer) - - # 
noise - self.noise = init_noise(noise, self.varshape) - # integral - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def derivative(self, V, t, I): - I = self.sum_current_inputs(V, init=I) - exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau - return dvdt - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() - - # spike, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike - else: - raise ValueError - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - - self.V.value = V - self.spike.value = spike - return spike - - def return_info(self): - return self.spike -class ExpIF(ExpIFLTC): - r"""Exponential integrate-and-fire neuron model. - - **Model Descriptions** - - In the exponential integrate-and-fire model [1]_, the differential - equation for the membrane potential is given by - - .. 
math:: - - \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ - \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} - - This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` - and "threshold" :math:`\vartheta_{rh}`. - - The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` - defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to - :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, - where :math:`\tau_{\rm ref}` is an absolute refractory time. - If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, - its exact value does not play any role. The reason is that the upswing of the action - potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in - an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical - convenience. For a formal mathematical analysis of the model, the threshold can be pushed - to infinity. - - The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk - and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. - It is one of the prominent examples of a precise theoretical prediction in computational - neuroscience that was later confirmed by experimental neuroscience. - - Two important remarks: - - - (i) The right-hand side of the above equation contains a nonlinearity - that can be directly extracted from experimental data [3]_. In this sense the exponential - nonlinearity is not an arbitrary choice but directly supported by experimental evidence. - - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing - rate for constant input, and the linear response to fluctuations, even in the presence - of input noise [4]_. - - **References** - - .. 
[1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). - Neuronal dynamics: From single neurons to networks and models - of cognition. Cambridge University Press. - .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, - Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves - are reliable predictors of naturalistic pyramidal-neuron voltage - traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. - .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear - integrate-and-fire neurons to modulated current-based and - conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. - .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire - - **Examples** - - There is a simple usage example:: - - import brainpy as bp + Args: + %s + %s + """ - neu = bp.dyn.ExpIF(1, ) + def derivative(self, V, t, I): + exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau + return dvdt - # example for section input - inputs = bp.inputs.section_input([0., 5., 0.], [100., 300., 100.]) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) +class ExpIFRefLTC(ExpIFLTC): + r"""Exponential integrate-and-fire neuron model with liquid time-constant. 
+ **Model Descriptions** - **Model Parameters** + In the exponential integrate-and-fire model [1]_, the differential + equation for the membrane potential is given by - ============= ============== ======== =================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- --------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - R 1 \ Membrane resistance. - tau 10 \ Membrane time constant. Compute by R * C. - tau_ref 1.7 \ Refractory period length. - ============= ============== ======== =================================================== + .. math:: - **Model Variables** + \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ + \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` + and "threshold" :math:`\vartheta_{rh}`. + The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` + defines the firing time :math:`t^{(f)}`. 
After firing, the membrane potential is reset to + :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, + where :math:`\tau_{\rm ref}` is an absolute refractory time. + If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, + its exact value does not play any role. The reason is that the upswing of the action + potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in + an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical + convenience. For a formal mathematical analysis of the model, the threshold can be pushed + to infinity. + The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk + and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. + It is one of the prominent examples of a precise theoretical prediction in computational + neuroscience that was later confirmed by experimental neuroscience. + Two important remarks: - Args: - %s - %s - """ + - (i) The right-hand side of the above equation contains a nonlinearity + that can be directly extracted from experimental data [3]_. In this sense the exponential + nonlinearity is not an arbitrary choice but directly supported by experimental evidence. + - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing + rate for constant input, and the linear response to fluctuations, even in the presence + of input noise [4]_. - def derivative(self, V, t, I): - exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau - return dvdt + **References** - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." 
Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). + Neuronal dynamics: From single neurons to networks and models + of cognition. Cambridge University Press. + .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, + Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves + are reliable predictors of naturalistic pyramidal-neuron voltage + traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. + .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear + integrate-and-fire neurons to modulated current-based and + conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. + .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire + **Examples** -class ExpIFRefLTC(ExpIFLTC): - r"""Exponential integrate-and-fire neuron model with liquid time-constant. - - **Model Descriptions** - - In the exponential integrate-and-fire model [1]_, the differential - equation for the membrane potential is given by - - .. math:: - - \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ - \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} - - This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` - and "threshold" :math:`\vartheta_{rh}`. - - The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` - defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to - :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, - where :math:`\tau_{\rm ref}` is an absolute refractory time. - If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, - its exact value does not play any role. 
The reason is that the upswing of the action - potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in - an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical - convenience. For a formal mathematical analysis of the model, the threshold can be pushed - to infinity. - - The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk - and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. - It is one of the prominent examples of a precise theoretical prediction in computational - neuroscience that was later confirmed by experimental neuroscience. - - Two important remarks: - - - (i) The right-hand side of the above equation contains a nonlinearity - that can be directly extracted from experimental data [3]_. In this sense the exponential - nonlinearity is not an arbitrary choice but directly supported by experimental evidence. - - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing - rate for constant input, and the linear response to fluctuations, even in the presence - of input noise [4]_. - - **References** - - .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). - Neuronal dynamics: From single neurons to networks and models - of cognition. Cambridge University Press. - .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, - Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves - are reliable predictors of naturalistic pyramidal-neuron voltage - traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. - .. [4] Richardson, Magnus JE. 
"Firing-rate response of linear and nonlinear - integrate-and-fire neurons to modulated current-based and - conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. - .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire - - **Examples** - - There is a simple usage example:: + There is a simple usage example:: - import brainpy as bp + import brainpy as bp - neu = bp.dyn.ExpIFRefLTC(1, ) + neu = bp.dyn.ExpIFRefLTC(1, ) - # example for section input - inputs = bp.inputs.section_input([0., 5., 0.], [100., 300., 100.]) + # example for section input + inputs = bp.inputs.section_input([0., 5., 0.], [100., 300., 100.]) - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - **Model Parameters** + **Model Parameters** - ============= ============== ======== =================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- --------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - R 1 \ Membrane resistance. - tau 10 \ Membrane time constant. Compute by R * C. - tau_ref 1.7 \ Refractory period length. - ============= ============== ======== =================================================== + ============= ============== ======== =================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- --------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. 
+ V_th -30 mV Threshold potential of spike. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + R 1 \ Membrane resistance. + tau 10 \ Membrane time constant. Compute by R * C. + tau_ref 1.7 \ Refractory period length. + ============= ============== ======== =================================================== - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. 
+ ================== ================= ========================================================= - Args: - %s - %s - %s + Args: + %s + %s + %s - """ + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sharding] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - detach_spk: bool = False, - spk_reset: str = 'soft', - method: str = 'exp_auto', - name: Optional[str] = None, - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # old neuron parameter - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -55., - V_T: Union[float, ArrayType, Callable] = -59.9, - delta_T: Union[float, ArrayType, Callable] = 3.48, - R: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # new neuron parameter - tau_ref: Union[float, ArrayType, Callable] = 0., - ref_var: bool = False, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__( - size=size, - name=name, - keep_size=keep_size, - mode=mode, - method=method, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - - init_var=False, - scaling=scaling, - - V_rest=V_rest, - V_reset=V_reset, - V_th=V_th, - V_T=V_T, - delta_T=delta_T, - R=R, - tau=tau, - V_initializer=V_initializer, - noise=noise, - ) - - # parameters - self.ref_var = ref_var - self.tau_ref = self.init_param(tau_ref) - - # initializers - self._V_initializer = is_initializer(V_initializer) - - # integral - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - super().reset_state(batch_size, **kwargs) - self.t_last_spike = 
self.init_variable(bm.ones, batch_size) - self.t_last_spike.fill_(-1e7) - if self.ref_var: - self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() - - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike_no_grad - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike_no_grad - else: - raise ValueError - spike_ = spike_no_grad > 0. 
- # will be used in other place, like Delta Synapse, so stop its gradient - if self.ref_var: - self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - if self.ref_var: - self.refractory.value = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.V.value = V - self.spike.value = spike - self.t_last_spike.value = t_last_spike - return spike + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + detach_spk: bool = False, + spk_reset: str = 'soft', + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -55., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + + init_var=False, + scaling=scaling, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_T=V_T, + delta_T=delta_T, + R=R, + tau=tau, + V_initializer=V_initializer, + noise=noise, + ) + + # 
parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + super().reset_state(batch_size, **kwargs) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike_no_grad + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike_no_grad + else: + raise ValueError + spike_ = spike_no_grad > 0. 
+ # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike class ExpIFRef(ExpIFRefLTC): - r"""Exponential integrate-and-fire neuron model . - - **Model Descriptions** - - In the exponential integrate-and-fire model [1]_, the differential - equation for the membrane potential is given by - - .. math:: - - \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ - \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} - - This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` - and "threshold" :math:`\vartheta_{rh}`. - - The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` - defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to - :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, - where :math:`\tau_{\rm ref}` is an absolute refractory time. - If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, - its exact value does not play any role. The reason is that the upswing of the action - potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in - an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical - convenience. For a formal mathematical analysis of the model, the threshold can be pushed - to infinity. 
- - The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk - and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. - It is one of the prominent examples of a precise theoretical prediction in computational - neuroscience that was later confirmed by experimental neuroscience. - - Two important remarks: - - - (i) The right-hand side of the above equation contains a nonlinearity - that can be directly extracted from experimental data [3]_. In this sense the exponential - nonlinearity is not an arbitrary choice but directly supported by experimental evidence. - - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing - rate for constant input, and the linear response to fluctuations, even in the presence - of input noise [4]_. - - **References** - - .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). - Neuronal dynamics: From single neurons to networks and models - of cognition. Cambridge University Press. - .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, - Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves - are reliable predictors of naturalistic pyramidal-neuron voltage - traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. - .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear - integrate-and-fire neurons to modulated current-based and - conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. - .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire - - **Examples** - - There is a simple usage example:: - - import brainpy as bp + r"""Exponential integrate-and-fire neuron model . 
- neu = bp.dyn.ExpIFRef(1, ) - - # example for section input - inputs = bp.inputs.section_input([0., 5., 0.], [100., 300., 100.]) + **Model Descriptions** - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + In the exponential integrate-and-fire model [1]_, the differential + equation for the membrane potential is given by - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) + .. math:: + \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ + \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} - **Model Parameters** + This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` + and "threshold" :math:`\vartheta_{rh}`. - ============= ============== ======== =================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- --------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - R 1 \ Membrane resistance. - tau 10 \ Membrane time constant. Compute by R * C. - tau_ref 1.7 \ Refractory period length. - ============= ============== ======== =================================================== + The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` + defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to + :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, + where :math:`\tau_{\rm ref}` is an absolute refractory time. + If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, + its exact value does not play any role. 
The reason is that the upswing of the action + potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in + an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical + convenience. For a formal mathematical analysis of the model, the threshold can be pushed + to infinity. - **Model Variables** + The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk + and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. + It is one of the prominent examples of a precise theoretical prediction in computational + neuroscience that was later confirmed by experimental neuroscience. - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + Two important remarks: + - (i) The right-hand side of the above equation contains a nonlinearity + that can be directly extracted from experimental data [3]_. In this sense the exponential + nonlinearity is not an arbitrary choice but directly supported by experimental evidence. + - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing + rate for constant input, and the linear response to fluctuations, even in the presence + of input noise [4]_. + **References** - Args: - %s - %s - %s - """ + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." 
Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). + Neuronal dynamics: From single neurons to networks and models + of cognition. Cambridge University Press. + .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, + Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves + are reliable predictors of naturalistic pyramidal-neuron voltage + traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. + .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear + integrate-and-fire neurons to modulated current-based and + conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. + .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire - def derivative(self, V, t, I): - exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau - return dvdt + **Examples** - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + There is a simple usage example:: + import brainpy as bp -ExpIF.__doc__ = ExpIF.__doc__ % (pneu_doc, dpneu_doc) -ExpIFRefLTC.__doc__ = ExpIFRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) -ExpIFRef.__doc__ = ExpIFRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) -ExpIFLTC.__doc__ = ExpIFLTC.__doc__ % () + neu = bp.dyn.ExpIFRef(1, ) + # example for section input + inputs = bp.inputs.section_input([0., 5., 0.], [100., 300., 100.]) -class AdExIFLTC(GradNeuDyn): - r"""Adaptive exponential integrate-and-fire neuron model with liquid time-constant. + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - **Model Descriptions** + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - The **adaptive exponential integrate-and-fire model**, also called AdEx, is a - spiking neuron model with two variables [1]_ [2]_. - .. 
math:: + **Model Parameters** - \begin{aligned} - \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ - \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w - \end{aligned} + ============= ============== ======== =================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- --------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + R 1 \ Membrane resistance. + tau 10 \ Membrane time constant. Compute by R * C. + tau_ref 1.7 \ Refractory period length. + ============= ============== ======== =================================================== - once the membrane potential reaches the spike threshold, + **Model Variables** - .. math:: + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. - The first equation describes the dynamics of the membrane potential and includes - an activation term with an exponential voltage dependence. Voltage is coupled to - a second equation which describes adaptation. Both variables are reset if an action - potential has been triggered. 
The combination of adaptation and exponential voltage - dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. - The adaptive exponential integrate-and-fire model is capable of describing known - neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, - initial bursting, fast spiking, and regular spiking. + Args: + %s + %s + %s + """ - **References** + def derivative(self, V, t, I): + exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau + return dvdt - .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) - **Examples** - An example usage: +ExpIF.__doc__ = ExpIF.__doc__ % (pneu_doc, dpneu_doc) +ExpIFRefLTC.__doc__ = ExpIFRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +ExpIFRef.__doc__ = ExpIFRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +ExpIFLTC.__doc__ = ExpIFLTC.__doc__ % () - .. code-block:: python - import brainpy as bp +class AdExIFLTC(GradNeuDyn): + r"""Adaptive exponential integrate-and-fire neuron model with liquid time-constant. - neu = bp.dyn.AdExIFLTC(2) + **Model Descriptions** - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + The **adaptive exponential integrate-and-fire model**, also called AdEx, is a + spiking neuron model with two variables [1]_ [2]_. - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + .. 
math:: - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + \begin{aligned} + \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ + \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w + \end{aligned} - **Model Examples** + once the membrane potential reaches the spike threshold, - - `Examples for different firing patterns `_ + .. math:: - **Model Parameters** + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` - b 1 \ The increment of :math:`w` produced by a spike. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_w 30 ms Time constant of the adaptation current. - tau_ref 0. ms Refractory time. - ============= ============== ======== ======================================================================================================================== + The first equation describes the dynamics of the membrane potential and includes + an activation term with an exponential voltage dependence. Voltage is coupled to + a second equation which describes adaptation. Both variables are reset if an action + potential has been triggered. 
The combination of adaptation and exponential voltage + dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. - **Model Variables** + The adaptive exponential integrate-and-fire model is capable of describing known + neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, + initial bursting, fast spiking, and regular spiking. - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + **References** + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. 
[2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + **Examples** - """ + An example usage: - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -55., - V_T: Union[float, ArrayType, Callable] = -59.9, - delta_T: Union[float, ArrayType, Callable] = 3.48, - a: Union[float, ArrayType, Callable] = 1., - b: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - tau_w: Union[float, ArrayType, Callable] = 30., - R: Union[float, ArrayType, Callable] = 1., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - w_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - # parameters - self.V_rest = self.offset_scaling(self.init_param(V_rest)) - self.V_reset = self.offset_scaling(self.init_param(V_reset)) - self.V_th = self.offset_scaling(self.init_param(V_th)) - self.V_T = self.offset_scaling(self.init_param(V_T)) - self.a = self.init_param(a) - self.b = self.std_scaling(self.init_param(b)) - self.R = self.init_param(R) - self.delta_T = self.std_scaling(self.init_param(delta_T)) - self.tau = self.init_param(tau) - self.tau_w = self.init_param(tau_w) - - # initializers - 
self._V_initializer = is_initializer(V_initializer) - self._w_initializer = is_initializer(w_initializer) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def dV(self, V, t, w, I): - I = self.sum_current_inputs(V, init=I) - exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau - return dVdt - - def dw(self, w, t, V): - dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w - return dwdt - - @property - def derivative(self): - return JointEq([self.dV, self.dw]) - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) - V += self.sum_delta_inputs() - - # spike, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike - else: - raise ValueError - w += self.b * spike - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - w = bm.where(spike, w + self.b, w) - - self.V.value = V - self.w.value = w - self.spike.value = spike - return spike - - def return_info(self): - return self.spike + .. 
code-block:: python + import brainpy as bp -class AdExIF(AdExIFLTC): - r"""Adaptive exponential integrate-and-fire neuron model. + neu = bp.dyn.AdExIFLTC(2) - **Model Descriptions** + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - The **adaptive exponential integrate-and-fire model**, also called AdEx, is a - spiking neuron model with two variables [1]_ [2]_. + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - .. math:: + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - \begin{aligned} - \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ - \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w - \end{aligned} + **Model Examples** - once the membrane potential reaches the spike threshold, + - `Examples for different firing patterns `_ - .. math:: + **Model Parameters** - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` + b 1 \ The increment of :math:`w` produced by a spike. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. 
+ tau_w 30 ms Time constant of the adaptation current. + tau_ref 0. ms Refractory time. + ============= ============== ======== ======================================================================================================================== - The first equation describes the dynamics of the membrane potential and includes - an activation term with an exponential voltage dependence. Voltage is coupled to - a second equation which describes adaptation. Both variables are reset if an action - potential has been triggered. The combination of adaptation and exponential voltage - dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. + **Model Variables** - The adaptive exponential integrate-and-fire model is capable of describing known - neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, - initial bursting, fast spiking, and regular spiking. + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - **References** - .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. 
[2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model - **Examples** + """ - An example usage: + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -55., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 30., + R: Union[float, ArrayType, Callable] = 1., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + # parameters + self.V_rest = self.offset_scaling(self.init_param(V_rest)) + self.V_reset = self.offset_scaling(self.init_param(V_reset)) + self.V_th = self.offset_scaling(self.init_param(V_th)) + self.V_T = self.offset_scaling(self.init_param(V_T)) + self.a = self.init_param(a) + self.b = self.std_scaling(self.init_param(b)) + self.R = self.init_param(R) + self.delta_T = self.std_scaling(self.init_param(delta_T)) + self.tau = self.init_param(tau) + self.tau_w = self.init_param(tau_w) + + # initializers + 
self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def dV(self, V, t, w, I): + I = self.sum_current_inputs(V, init=I) + exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) + self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size)) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike + else: + raise ValueError + w += self.b * spike + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + + self.V.value = V + self.w.value = w + self.spike.value = spike + return spike + + def return_info(self): + return self.spike - .. 
code-block:: python - import brainpy as bp +class AdExIF(AdExIFLTC): + r"""Adaptive exponential integrate-and-fire neuron model. - neu = bp.dyn.AdExIF(2) + **Model Descriptions** - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + The **adaptive exponential integrate-and-fire model**, also called AdEx, is a + spiking neuron model with two variables [1]_ [2]_. - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + .. math:: - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + \begin{aligned} + \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ + \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w + \end{aligned} - **Model Examples** - - - `Examples for different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` - b 1 \ The increment of :math:`w` produced by a spike. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_w 30 ms Time constant of the adaptation current. - tau_ref 0. ms Refractory time. 
- ============= ============== ======== ======================================================================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - - - Args: - %s - %s - """ + once the membrane potential reaches the spike threshold, + + .. math:: + + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. - def dV(self, V, t, w, I): - exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau - return dVdt + The first equation describes the dynamics of the membrane potential and includes + an activation term with an exponential voltage dependence. Voltage is coupled to + a second equation which describes adaptation. Both variables are reset if an action + potential has been triggered. The combination of adaptation and exponential voltage + dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + The adaptive exponential integrate-and-fire model is capable of describing known + neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, + initial bursting, fast spiking, and regular spiking. + + **References** + + .. [1] Fourcaud-Trocmé, Nicolas, et al. 
"How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + + **Examples** + + An example usage: + + .. code-block:: python + + import brainpy as bp + + neu = bp.dyn.AdExIF(2) + + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + + **Model Examples** + + - `Examples for different firing patterns `_ + + **Model Parameters** + + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` + b 1 \ The increment of :math:`w` produced by a spike. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_w 30 ms Time constant of the adaptation current. + tau_ref 0. ms Refractory time. 
+ ============= ============== ======== ======================================================================================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + + + Args: + %s + %s + """ + + def dV(self, V, t, w, I): + exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau + return dVdt + + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) class AdExIFRefLTC(AdExIFLTC): - r"""Adaptive exponential integrate-and-fire neuron model with liquid time-constant. + r"""Adaptive exponential integrate-and-fire neuron model with liquid time-constant. - **Model Descriptions** + **Model Descriptions** - The **adaptive exponential integrate-and-fire model**, also called AdEx, is a - spiking neuron model with two variables [1]_ [2]_. + The **adaptive exponential integrate-and-fire model**, also called AdEx, is a + spiking neuron model with two variables [1]_ [2]_. - .. math:: + .. 
math:: - \begin{aligned} - \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ - \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w - \end{aligned} + \begin{aligned} + \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ + \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w + \end{aligned} - once the membrane potential reaches the spike threshold, + once the membrane potential reaches the spike threshold, - .. math:: + .. math:: - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. - The first equation describes the dynamics of the membrane potential and includes - an activation term with an exponential voltage dependence. Voltage is coupled to - a second equation which describes adaptation. Both variables are reset if an action - potential has been triggered. The combination of adaptation and exponential voltage - dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. + The first equation describes the dynamics of the membrane potential and includes + an activation term with an exponential voltage dependence. Voltage is coupled to + a second equation which describes adaptation. Both variables are reset if an action + potential has been triggered. The combination of adaptation and exponential voltage + dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. - The adaptive exponential integrate-and-fire model is capable of describing known - neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, - initial bursting, fast spiking, and regular spiking. + The adaptive exponential integrate-and-fire model is capable of describing known + neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, + initial bursting, fast spiking, and regular spiking. - **References** + **References** - .. [1] Fourcaud-Trocmé, Nicolas, et al. 
"How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model - **Examples** + **Examples** - An example usage: + An example usage: - .. code-block:: python + .. code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.AdExIFRefLTC(2) + neu = bp.dyn.AdExIFRefLTC(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Examples** + **Model Examples** - - `Examples for different firing patterns `_ + - `Examples for different firing patterns `_ - **Model Parameters** + **Model Parameters** - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- 
------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` - b 1 \ The increment of :math:`w` produced by a spike. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_w 30 ms Time constant of the adaptation current. - tau_ref 0. ms Refractory time. - ============= ============== ======== ======================================================================================================================== + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` + b 1 \ The increment of :math:`w` produced by a spike. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_w 30 ms Time constant of the adaptation current. + tau_ref 0. ms Refractory time. 
+ ============= ============== ======== ======================================================================================================================== - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. 
+ ================== ================= ========================================================= - Args: - %s - %s - %s - """ + Args: + %s + %s + %s + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sharding] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - name: Optional[str] = None, - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # old neuron parameter - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -55., - V_T: Union[float, ArrayType, Callable] = -59.9, - delta_T: Union[float, ArrayType, Callable] = 3.48, - a: Union[float, ArrayType, Callable] = 1., - b: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - tau_w: Union[float, ArrayType, Callable] = 30., - R: Union[float, ArrayType, Callable] = 1., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - w_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # new neuron parameter - tau_ref: Union[float, ArrayType, Callable] = 0., - ref_var: bool = False, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__( - size=size, - name=name, - keep_size=keep_size, - mode=mode, - method=method, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - - init_var=False, - scaling=scaling, - - V_rest=V_rest, - V_reset=V_reset, - V_th=V_th, - V_T=V_T, - delta_T=delta_T, - a=a, - b=b, - R=R, - tau=tau, - tau_w=tau_w, - V_initializer=V_initializer, - w_initializer=w_initializer - ) - - # parameters - self.ref_var = ref_var - self.tau_ref = self.init_param(tau_ref) - - # initializers - self._V_initializer = is_initializer(V_initializer) - self._w_initializer = 
is_initializer(w_initializer) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - super().reset_state(batch_size, **kwargs) - self.t_last_spike = self.init_variable(bm.ones, batch_size) - self.t_last_spike.fill_(-1e8) - if self.ref_var: - self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) - V += self.sum_delta_inputs() - - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike_no_grad - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike_no_grad - else: - raise ValueError - w += self.b * spike_no_grad - spike_ = spike_no_grad > 0. 
- # will be used in other place, like Delta Synapse, so stop its gradient - if self.ref_var: - self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - w = bm.where(spike, w + self.b, w) - if self.ref_var: - self.refractory.value = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.V.value = V - self.w.value = w - self.spike.value = spike - self.t_last_spike.value = t_last_spike - return spike + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -55., + V_T: Union[float, ArrayType, Callable] = -59.9, + delta_T: Union[float, ArrayType, Callable] = 3.48, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 30., + R: Union[float, ArrayType, Callable] = 1., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + 
detach_spk=detach_spk, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + + init_var=False, + scaling=scaling, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_T=V_T, + delta_T=delta_T, + a=a, + b=b, + R=R, + tau=tau, + tau_w=tau_w, + V_initializer=V_initializer, + w_initializer=w_initializer + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + super().reset_state(batch_size, **kwargs) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e8) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike_no_grad + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike_no_grad + else: + raise ValueError + w += self.b * spike_no_grad + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.w.value = w + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike class AdExIFRef(AdExIFRefLTC): - r"""Adaptive exponential integrate-and-fire neuron model. + r"""Adaptive exponential integrate-and-fire neuron model. - **Model Descriptions** + **Model Descriptions** - The **adaptive exponential integrate-and-fire model**, also called AdEx, is a - spiking neuron model with two variables [1]_ [2]_. + The **adaptive exponential integrate-and-fire model**, also called AdEx, is a + spiking neuron model with two variables [1]_ [2]_. - .. math:: + .. 
math:: - \begin{aligned} - \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ - \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w - \end{aligned} + \begin{aligned} + \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ + \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w + \end{aligned} - once the membrane potential reaches the spike threshold, + once the membrane potential reaches the spike threshold, - .. math:: + .. math:: - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. - The first equation describes the dynamics of the membrane potential and includes - an activation term with an exponential voltage dependence. Voltage is coupled to - a second equation which describes adaptation. Both variables are reset if an action - potential has been triggered. The combination of adaptation and exponential voltage - dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. + The first equation describes the dynamics of the membrane potential and includes + an activation term with an exponential voltage dependence. Voltage is coupled to + a second equation which describes adaptation. Both variables are reset if an action + potential has been triggered. The combination of adaptation and exponential voltage + dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. - The adaptive exponential integrate-and-fire model is capable of describing known - neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, - initial bursting, fast spiking, and regular spiking. + The adaptive exponential integrate-and-fire model is capable of describing known + neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, + initial bursting, fast spiking, and regular spiking. - **References** + **References** - .. [1] Fourcaud-Trocmé, Nicolas, et al. 
"How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model - **Examples** + **Examples** - Here is an example usage: + Here is an example usage: - .. code-block:: python + .. code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.AdExIFRef(2) + neu = bp.dyn.AdExIFRef(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Examples** - - - `Examples for different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. 
- V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` - b 1 \ The increment of :math:`w` produced by a spike. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_w 30 ms Time constant of the adaptation current. - tau_ref 0. ms Refractory time. - ============= ============== ======== ======================================================================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - Args: - %s - %s - %s - """ + **Model Examples** - def dV(self, V, t, w, I): - exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau - return dVdt + - `Examples for different firing patterns `_ - def update(self, x=None): - x = 0. 
if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + **Model Parameters** + + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` + b 1 \ The increment of :math:`w` produced by a spike. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_w 30 ms Time constant of the adaptation current. + tau_ref 0. ms Refractory time. + ============= ============== ======== ======================================================================================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. 
+ ================== ================= ========================================================= + + Args: + %s + %s + %s + """ + + def dV(self, V, t, w, I): + exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) + dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau + return dVdt + + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) AdExIF.__doc__ = AdExIF.__doc__ % (pneu_doc, dpneu_doc) @@ -1956,560 +1956,560 @@ def update(self, x=None): class QuaIFLTC(GradNeuDyn): - r"""Quadratic Integrate-and-Fire neuron model with liquid time-constant. + r"""Quadratic Integrate-and-Fire neuron model with liquid time-constant. - **Model Descriptions** + **Model Descriptions** - In contrast to physiologically accurate but computationally expensive - neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only - to produce **action potential-like patterns** and ignores subtleties - like gating variables, which play an important role in generating action - potentials in a real neuron. However, the QIF model is incredibly easy - to implement and compute, and relatively straightforward to study and - understand, thus has found ubiquitous use in computational neuroscience. + In contrast to physiologically accurate but computationally expensive + neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only + to produce **action potential-like patterns** and ignores subtleties + like gating variables, which play an important role in generating action + potentials in a real neuron. However, the QIF model is incredibly easy + to implement and compute, and relatively straightforward to study and + understand, thus has found ubiquitous use in computational neuroscience. - .. math:: + .. 
math:: - \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) + \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) - where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). + where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). - **References** + **References** - .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg - (2000) Intrinsic dynamics in neuronal networks. I. Theory. - J. Neurophysiology 83, pp. 808–827. + .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg + (2000) Intrinsic dynamics in neuronal networks. I. Theory. + J. Neurophysiology 83, pp. 808–827. - **Examples** + **Examples** - Here is an example usage: + Here is an example usage: - .. code-block:: python + .. code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.QuaIFLTC(2) + neu = bp.dyn.QuaIFLTC(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Parameters** + **Model Parameters** - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- 
------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. - c .07 \ Coefficient describes membrane potential update. Larger than 0. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_ref 0 ms Refractory period length. - ============= ============== ======== ======================================================================================================================== + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. + c .07 \ Coefficient describes membrane potential update. Larger than 0. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_ref 0 ms Refractory period length. + ============= ============== ======== ======================================================================================================================== - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. 
- input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - """ + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -30., - V_c: Union[float, ArrayType, Callable] = -50.0, - c: Union[float, ArrayType, Callable] = 0.07, - R: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - 
method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - # parameters - self.V_rest = self.offset_scaling(self.init_param(V_rest)) - self.V_reset = self.offset_scaling(self.init_param(V_reset)) - self.V_th = self.offset_scaling(self.init_param(V_th)) - self.V_c = self.offset_scaling(self.init_param(V_c)) - self.c = self.inv_scaling(self.init_param(c)) - self.R = self.init_param(R) - self.tau = self.init_param(tau) - - # initializers - self._V_initializer = is_initializer(V_initializer) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=1) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def derivative(self, V, t, I): - I = self.sum_current_inputs(V, init=I) - dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau - return dVdt - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. 
if x is None else x - - # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() - - # spike, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike - else: - raise ValueError - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - - self.V.value = V - self.spike.value = spike - return spike - - def return_info(self): - return self.spike + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + c: Union[float, ArrayType, Callable] = 0.07, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + # parameters + self.V_rest = self.offset_scaling(self.init_param(V_rest)) + self.V_reset = self.offset_scaling(self.init_param(V_reset)) + self.V_th = 
self.offset_scaling(self.init_param(V_th)) + self.V_c = self.offset_scaling(self.init_param(V_c)) + self.c = self.inv_scaling(self.init_param(c)) + self.R = self.init_param(R) + self.tau = self.init_param(tau) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=1) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + I = self.sum_current_inputs(V, init=I) + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau + return dVdt + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike + else: + raise ValueError + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + + self.V.value = V + self.spike.value = spike + return spike + + def return_info(self): + return self.spike class QuaIF(QuaIFLTC): - r"""Quadratic Integrate-and-Fire neuron model. + r"""Quadratic Integrate-and-Fire neuron model. 
- **Model Descriptions** + **Model Descriptions** - In contrast to physiologically accurate but computationally expensive - neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only - to produce **action potential-like patterns** and ignores subtleties - like gating variables, which play an important role in generating action - potentials in a real neuron. However, the QIF model is incredibly easy - to implement and compute, and relatively straightforward to study and - understand, thus has found ubiquitous use in computational neuroscience. + In contrast to physiologically accurate but computationally expensive + neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only + to produce **action potential-like patterns** and ignores subtleties + like gating variables, which play an important role in generating action + potentials in a real neuron. However, the QIF model is incredibly easy + to implement and compute, and relatively straightforward to study and + understand, thus has found ubiquitous use in computational neuroscience. - .. math:: + .. math:: - \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) + \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) - where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). + where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). - **References** + **References** - .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg - (2000) Intrinsic dynamics in neuronal networks. I. Theory. - J. Neurophysiology 83, pp. 808–827. + .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg + (2000) Intrinsic dynamics in neuronal networks. I. Theory. + J. Neurophysiology 83, pp. 808–827. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. 
code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.QuaIF(2) + neu = bp.dyn.QuaIF(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Parameters** + **Model Parameters** - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. - c .07 \ Coefficient describes membrane potential update. Larger than 0. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_ref 0 ms Refractory period length. 
- ============= ============== ======== ======================================================================================================================== + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. + c .07 \ Coefficient describes membrane potential update. Larger than 0. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_ref 0 ms Refractory period length. + ============= ============== ======== ======================================================================================================================== - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. 
- ================== ================= ========================================================= + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - Args: - %s - %s - """ + Args: + %s + %s + """ - def derivative(self, V, t, I): - dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau - return dVdt + def derivative(self, V, t, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau + return dVdt - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) class QuaIFRefLTC(QuaIFLTC): - r"""Quadratic Integrate-and-Fire neuron model with liquid time-constant. + r"""Quadratic Integrate-and-Fire neuron model with liquid time-constant. - **Model Descriptions** + **Model Descriptions** - In contrast to physiologically accurate but computationally expensive - neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only - to produce **action potential-like patterns** and ignores subtleties - like gating variables, which play an important role in generating action - potentials in a real neuron. 
However, the QIF model is incredibly easy - to implement and compute, and relatively straightforward to study and - understand, thus has found ubiquitous use in computational neuroscience. + In contrast to physiologically accurate but computationally expensive + neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only + to produce **action potential-like patterns** and ignores subtleties + like gating variables, which play an important role in generating action + potentials in a real neuron. However, the QIF model is incredibly easy + to implement and compute, and relatively straightforward to study and + understand, thus has found ubiquitous use in computational neuroscience. - .. math:: + .. math:: - \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) + \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) - where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). + where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). - **References** + **References** - .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg - (2000) Intrinsic dynamics in neuronal networks. I. Theory. - J. Neurophysiology 83, pp. 808–827. + .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg + (2000) Intrinsic dynamics in neuronal networks. I. Theory. + J. Neurophysiology 83, pp. 808–827. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. 
code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.QuaIFRefLTC(2) + neu = bp.dyn.QuaIFRefLTC(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Parameters** + **Model Parameters** - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. - c .07 \ Coefficient describes membrane potential update. Larger than 0. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_ref 0 ms Refractory period length. 
- ============= ============== ======== ======================================================================================================================== + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. + c .07 \ Coefficient describes membrane potential update. Larger than 0. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_ref 0 ms Refractory period length. + ============= ============== ======== ======================================================================================================================== - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. 
- ================== ================= ========================================================= + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - Args: - %s - %s - %s - """ + Args: + %s + %s + %s + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sharding] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - name: Optional[str] = None, - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # old neuron parameter - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -30., - V_c: Union[float, ArrayType, Callable] = -50.0, - c: Union[float, ArrayType, Callable] = 0.07, - R: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # new neuron parameter - tau_ref: Union[float, ArrayType, Callable] = 0., - ref_var: bool = False, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__( - size=size, - name=name, - keep_size=keep_size, - mode=mode, - method=method, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - - init_var=False, - 
scaling=scaling, - - V_rest=V_rest, - V_reset=V_reset, - V_th=V_th, - V_c=V_c, - c=c, - R=R, - tau=tau, - V_initializer=V_initializer, - ) - - # parameters - self.ref_var = ref_var - self.tau_ref = self.init_param(tau_ref) - - # initializers - self._V_initializer = is_initializer(V_initializer) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=1) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - super().reset_state(batch_size, **kwargs) - self.t_last_spike = self.init_variable(bm.ones, batch_size) - self.t_last_spike.fill_(-1e7) - if self.ref_var: - self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() - - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike_no_grad - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike_no_grad - else: - raise ValueError - spike_ = spike_no_grad > 0. 
- # will be used in other place, like Delta Synapse, so stop its gradient - if self.ref_var: - self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - if self.ref_var: - self.refractory.value = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.V.value = V - self.spike.value = spike - self.t_last_spike.value = t_last_spike - return spike + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + c: Union[float, ArrayType, Callable] = 0.07, + R: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + + init_var=False, + scaling=scaling, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_c=V_c, + c=c, + R=R, + tau=tau, + V_initializer=V_initializer, + ) + + # parameters + self.ref_var = ref_var + 
self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=1) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + super().reset_state(batch_size, **kwargs) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike_no_grad + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike_no_grad + else: + raise ValueError + spike_ = spike_no_grad > 0. 
+ # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike class QuaIFRef(QuaIFRefLTC): - r"""Quadratic Integrate-and-Fire neuron model. + r"""Quadratic Integrate-and-Fire neuron model. - **Model Descriptions** + **Model Descriptions** - In contrast to physiologically accurate but computationally expensive - neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only - to produce **action potential-like patterns** and ignores subtleties - like gating variables, which play an important role in generating action - potentials in a real neuron. However, the QIF model is incredibly easy - to implement and compute, and relatively straightforward to study and - understand, thus has found ubiquitous use in computational neuroscience. + In contrast to physiologically accurate but computationally expensive + neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only + to produce **action potential-like patterns** and ignores subtleties + like gating variables, which play an important role in generating action + potentials in a real neuron. However, the QIF model is incredibly easy + to implement and compute, and relatively straightforward to study and + understand, thus has found ubiquitous use in computational neuroscience. - .. math:: + .. 
math:: - \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) + \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) - where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). + where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). - **References** + **References** - .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg - (2000) Intrinsic dynamics in neuronal networks. I. Theory. - J. Neurophysiology 83, pp. 808–827. + .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg + (2000) Intrinsic dynamics in neuronal networks. I. Theory. + J. Neurophysiology 83, pp. 808–827. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.QuaIFRef(2) + neu = bp.dyn.QuaIFRef(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Parameters** + **Model Parameters** - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- 
------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. - c .07 \ Coefficient describes membrane potential update. Larger than 0. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_ref 0 ms Refractory period length. - ============= ============== ======== ======================================================================================================================== + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. + c .07 \ Coefficient describes membrane potential update. Larger than 0. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_ref 0 ms Refractory period length. + ============= ============== ======== ======================================================================================================================== - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. 
- input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - Args: - %s - %s - %s - """ + Args: + %s + %s + %s + """ - def derivative(self, V, t, I): - dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau - return dVdt + def derivative(self, V, t, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau + return dVdt - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) QuaIF.__doc__ = QuaIF.__doc__ % (pneu_doc, dpneu_doc) @@ -2519,1432 +2519,1647 @@ def update(self, x=None): class AdQuaIFLTC(GradNeuDyn): - r"""Adaptive quadratic integrate-and-fire neuron model with liquid time-constant. + r"""Adaptive quadratic integrate-and-fire neuron model with liquid time-constant. 
- **Model Descriptions** + **Model Descriptions** - The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: + The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: - .. math:: + .. math:: - \begin{aligned} - \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ - \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, - \end{aligned} + \begin{aligned} + \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ + \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, + \end{aligned} - once the membrane potential reaches the spike threshold, + once the membrane potential reaches the spike threshold, - .. math:: + .. math:: - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. - **References** + **References** - .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking - neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. - .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of - nonlinear integrate-and-fire neurons." SIAM Journal on Applied - Mathematics 68, no. 4 (2008): 1045-1079. + .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking + neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. + .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of + nonlinear integrate-and-fire neurons." SIAM Journal on Applied + Mathematics 68, no. 4 (2008): 1045-1079. - **Examples** + **Examples** - Here is an example usage: + Here is an example usage: - .. code-block:: python + .. 
code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.AdQuaIFLTC(2) + neu = bp.dyn.AdQuaIFLTC(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + **Model Parameters** + + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger + than :math:`V_{rest}`. + a 1 \ The sensitivity of the recovery variable :math:`u` to + the sub-threshold fluctuations of the membrane + potential :math:`v` + b .1 \ The increment of :math:`w` produced by a spike. + c .07 \ Coefficient describes membrane potential update. + Larger than 0. + tau 10 ms Membrane time constant. + tau_w 10 ms Time constant of the adaptation current. 
+ ============= ============== ======== ======================================================= + **Model Variables** - **Model Parameters** - - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger - than :math:`V_{rest}`. - a 1 \ The sensitivity of the recovery variable :math:`u` to - the sub-threshold fluctuations of the membrane - potential :math:`v` - b .1 \ The increment of :math:`w` produced by a spike. - c .07 \ Coefficient describes membrane potential update. - Larger than 0. - tau 10 ms Membrane time constant. - tau_w 10 ms Time constant of the adaptation current. - ============= ============== ======== ======================================================= - - **Model Variables** - - ================== ================= ========================================================== - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- ---------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================== + ================== ================= ========================================================== + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- ---------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. 
+ spike False Flag to mark whether the neuron is spiking. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================== - """ + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -30., - V_c: Union[float, ArrayType, Callable] = -50.0, - a: Union[float, ArrayType, Callable] = 1., - b: Union[float, ArrayType, Callable] = .1, - c: Union[float, ArrayType, Callable] = .07, - tau: Union[float, ArrayType, Callable] = 10., - tau_w: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - w_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - # parameters - self.V_rest = self.offset_scaling(self.init_param(V_rest)) - self.V_reset = self.offset_scaling(self.init_param(V_reset)) - self.V_th = self.offset_scaling(self.init_param(V_th)) - self.V_c = self.offset_scaling(self.init_param(V_c)) - self.a = self.init_param(a) - self.b = self.std_scaling(self.init_param(b)) - self.c = self.inv_scaling(self.init_param(c)) - self.tau = self.init_param(tau) - self.tau_w = self.init_param(tau_w) - - # initializers - self._V_initializer = 
is_initializer(V_initializer) - self._w_initializer = is_initializer(w_initializer) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def dV(self, V, t, w, I): - I = self.sum_current_inputs(V, init=I) - dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau - return dVdt - - def dw(self, w, t, V): - dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w - return dwdt - - @property - def derivative(self): - return JointEq([self.dV, self.dw]) - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. 
if x is None else x - - # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) - V += self.sum_delta_inputs() - - # spike, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike - else: - raise ValueError - w += self.b * spike - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - w = bm.where(spike, w + self.b, w) - - self.V.value = V - self.w.value = w - self.spike.value = spike - return spike - - def return_info(self): - return self.spike + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = .1, + c: Union[float, ArrayType, Callable] = .07, + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + 
method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + # parameters + self.V_rest = self.offset_scaling(self.init_param(V_rest)) + self.V_reset = self.offset_scaling(self.init_param(V_reset)) + self.V_th = self.offset_scaling(self.init_param(V_th)) + self.V_c = self.offset_scaling(self.init_param(V_c)) + self.a = self.init_param(a) + self.b = self.std_scaling(self.init_param(b)) + self.c = self.inv_scaling(self.init_param(c)) + self.tau = self.init_param(tau) + self.tau_w = self.init_param(tau_w) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def dV(self, V, t, w, I): + I = self.sum_current_inputs(V, init=I) + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau + return dVdt + + def dw(self, w, t, V): + dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w + return dwdt + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) + self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size)) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike + else: + raise ValueError + w += self.b * spike + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + + self.V.value = V + self.w.value = w + self.spike.value = spike + return spike + + def return_info(self): + return self.spike class AdQuaIF(AdQuaIFLTC): - r"""Adaptive quadratic integrate-and-fire neuron model. + r"""Adaptive quadratic integrate-and-fire neuron model. - **Model Descriptions** + **Model Descriptions** - The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: + The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: - .. math:: + .. math:: - \begin{aligned} - \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ - \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, - \end{aligned} + \begin{aligned} + \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ + \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, + \end{aligned} - once the membrane potential reaches the spike threshold, + once the membrane potential reaches the spike threshold, - .. math:: + .. math:: - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. - **References** + **References** - .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking - neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. - .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of - nonlinear integrate-and-fire neurons." 
SIAM Journal on Applied - Mathematics 68, no. 4 (2008): 1045-1079. + .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking + neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. + .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of + nonlinear integrate-and-fire neurons." SIAM Journal on Applied + Mathematics 68, no. 4 (2008): 1045-1079. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. code-block:: python - import brainpy as bp + import brainpy as bp - neu = bp.dyn.AdQuaIF(2) + neu = bp.dyn.AdQuaIF(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Parameters** + **Model Parameters** - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger - than :math:`V_{rest}`. 
- a 1 \ The sensitivity of the recovery variable :math:`u` to - the sub-threshold fluctuations of the membrane - potential :math:`v` - b .1 \ The increment of :math:`w` produced by a spike. - c .07 \ Coefficient describes membrane potential update. - Larger than 0. - tau 10 ms Membrane time constant. - tau_w 10 ms Time constant of the adaptation current. - ============= ============== ======== ======================================================= + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger + than :math:`V_{rest}`. + a 1 \ The sensitivity of the recovery variable :math:`u` to + the sub-threshold fluctuations of the membrane + potential :math:`v` + b .1 \ The increment of :math:`w` produced by a spike. + c .07 \ Coefficient describes membrane potential update. + Larger than 0. + tau 10 ms Membrane time constant. + tau_w 10 ms Time constant of the adaptation current. + ============= ============== ======== ======================================================= - **Model Variables** + **Model Variables** - ================== ================= ========================================================== - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- ---------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - t_last_spike -1e7 Last spike time stamp. 
- ================== ================= ========================================================== + ================== ================= ========================================================== + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- ---------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================== - Args: - %s - %s - """ + Args: + %s + %s + """ - def dV(self, V, t, w, I): - dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau - return dVdt + def dV(self, V, t, w, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau + return dVdt - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) class AdQuaIFRefLTC(AdQuaIFLTC): - r"""Adaptive quadratic integrate-and-fire neuron model with liquid time-constant. + r"""Adaptive quadratic integrate-and-fire neuron model with liquid time-constant. - **Model Descriptions** + **Model Descriptions** - The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: + The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: - .. math:: + .. 
math:: - \begin{aligned} - \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ - \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, - \end{aligned} + \begin{aligned} + \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ + \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, + \end{aligned} - once the membrane potential reaches the spike threshold, + once the membrane potential reaches the spike threshold, - .. math:: + .. math:: - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. - **References** + **References** - .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking - neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. - .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of - nonlinear integrate-and-fire neurons." SIAM Journal on Applied - Mathematics 68, no. 4 (2008): 1045-1079. + .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking + neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. + .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of + nonlinear integrate-and-fire neurons." SIAM Journal on Applied + Mathematics 68, no. 4 (2008): 1045-1079. - **Examples** + **Examples** - There is an example usage: + There is an example usage: - .. code-block:: python + .. 
code-block:: python - import brainpy as bp + import brainpy as bp + + neu = bp.dyn.AdQuaIFRefLTC(2) + + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + + + + + **Model Parameters** + + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger + than :math:`V_{rest}`. + a 1 \ The sensitivity of the recovery variable :math:`u` to + the sub-threshold fluctuations of the membrane + potential :math:`v` + b .1 \ The increment of :math:`w` produced by a spike. + c .07 \ Coefficient describes membrane potential update. + Larger than 0. + tau 10 ms Membrane time constant. + tau_w 10 ms Time constant of the adaptation current. + ============= ============== ======== ======================================================= + + **Model Variables** + + ================== ================= ========================================================== + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- ---------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + t_last_spike -1e7 Last spike time stamp. 
+ ================== ================= ========================================================== + + Args: + %s + %s + %s + """ + + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -65., + V_reset: Union[float, ArrayType, Callable] = -68., + V_th: Union[float, ArrayType, Callable] = -30., + V_c: Union[float, ArrayType, Callable] = -50.0, + a: Union[float, ArrayType, Callable] = 1., + b: Union[float, ArrayType, Callable] = .1, + c: Union[float, ArrayType, Callable] = .07, + tau: Union[float, ArrayType, Callable] = 10., + tau_w: Union[float, ArrayType, Callable] = 10., + V_initializer: Union[Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + + init_var=False, + scaling=scaling, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + V_c=V_c, + a=a, + b=b, + c=c, + tau=tau, + tau_w=tau_w, + V_initializer=V_initializer, + w_initializer=w_initializer + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, 
num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + super().reset_state(batch_size, **kwargs) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e8) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike_no_grad + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike_no_grad + else: + raise ValueError + w += self.b * spike_no_grad + spike_ = spike_no_grad > 0. 
+ # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + w = bm.where(spike, w + self.b, w) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.w.value = w + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike + + +class AdQuaIFRef(AdQuaIFRefLTC): + r"""Adaptive quadratic integrate-and-fire neuron model. + + **Model Descriptions** + + The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: + + .. math:: + + \begin{aligned} + \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ + \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, + \end{aligned} + + once the membrane potential reaches the spike threshold, + + .. math:: + + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. + + **References** + + .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking + neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. + .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of + nonlinear integrate-and-fire neurons." SIAM Journal on Applied + Mathematics 68, no. 4 (2008): 1045-1079. + + **Examples** + + There is an example usage: + + .. 
code-block:: python + + import brainpy as bp + + neu = bp.dyn.AdQuaIFRef(2) + + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) + + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + + + **Model Parameters** + + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger + than :math:`V_{rest}`. + a 1 \ The sensitivity of the recovery variable :math:`u` to + the sub-threshold fluctuations of the membrane + potential :math:`v` + b .1 \ The increment of :math:`w` produced by a spike. + c .07 \ Coefficient describes membrane potential update. + Larger than 0. + tau 10 ms Membrane time constant. + tau_w 10 ms Time constant of the adaptation current. + ============= ============== ======== ======================================================= + + **Model Variables** + + ================== ================= ========================================================== + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- ---------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + t_last_spike -1e7 Last spike time stamp. 
+ ================== ================= ========================================================== + + + + Args: + %s + %s + %s + """ + + def dV(self, V, t, w, I): + dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau + return dVdt + + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) + + +AdQuaIF.__doc__ = AdQuaIF.__doc__ % (pneu_doc, dpneu_doc) +AdQuaIFRefLTC.__doc__ = AdQuaIFRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +AdQuaIFRef.__doc__ = AdQuaIFRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +AdQuaIFLTC.__doc__ = AdQuaIFLTC.__doc__ % () + + +class GifLTC(GradNeuDyn): + r"""Generalized Integrate-and-Fire model with liquid time-constant. + + **Model Descriptions** + + The generalized integrate-and-fire model [1]_ is given by + + .. math:: + + &\frac{d I_j}{d t} = - k_j I_j + + &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau + + &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) + + When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: + + .. math:: + + &I_j \leftarrow R_j I_j + A_j + + &V \leftarrow V_{reset} + + &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) + + Note that :math:`I_j` refers to arbitrary number of internal currents. + + + **References** + + .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear + integrate-and-fire neural model produces diverse spiking + behaviors." Neural computation 21.3 (2009): 704-718. + .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan + Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized + leaky integrate-and-fire models classify multiple neuron types." + Nature communications 9, no. 1 (2018): 1-15. + + **Examples** + + There is a simple usage: you r bound to be together, roy and edward + + .. 
code-block:: python + + import brainpy as bp + import matplotlib.pyplot as plt + + # Tonic Spiking + neu = bp.dyn.Gif(1) + inputs = bp.inputs.ramp_input(.2, 2, 400, 0, 400) + + runner = bp.DSRunner(neu, monitors=['V', 'V_th']) + runner.run(inputs=inputs) + + ts = runner.mon.ts + + fig, gs = bp.visualize.get_figure(1, 1, 4, 8) + ax1 = fig.add_subplot(gs[0, 0]) + + ax1.plot(ts, runner.mon.V[:, 0], label='V') + ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') + + plt.show() + + **Model Examples** + + - `Detailed examples to reproduce different firing patterns `_ + + **Model Parameters** + + ============= ============== ======== ==================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- -------------------------------------------------------------------- + V_rest -70 mV Resting potential. + V_reset -70 mV Reset potential after spike. + V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. + V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. + R 20 \ Membrane resistance. + tau 20 ms Membrane time constant. Compute by :math:`R * C`. + a 0 \ Coefficient describes the dependence of + :math:`V_{th}` on membrane potential. + b 0.01 \ Coefficient describes :math:`V_{th}` update. + k1 0.2 \ Constant pf :math:`I1`. + k2 0.02 \ Constant of :math:`I2`. + R1 0 \ Free parameter. + Describes dependence of :math:`I_1` reset value on + :math:`I_1` value before spiking. + R2 1 \ Free parameter. + Describes dependence of :math:`I_2` reset value on + :math:`I_2` value before spiking. + A1 0 \ Free parameter. + A2 0 \ Free parameter. 
+ ============= ============== ======== ==================================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V -70 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + V_th -50 Spiking threshold potential. + I1 0 Internal current 1. + I2 0 Internal current 2. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + + """ + + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_rest: Union[float, ArrayType, Callable] = -70., + V_reset: Union[float, ArrayType, Callable] = -70., + V_th_inf: Union[float, ArrayType, Callable] = -50., + V_th_reset: Union[float, ArrayType, Callable] = -60., + R: Union[float, ArrayType, Callable] = 20., + tau: Union[float, ArrayType, Callable] = 20., + a: Union[float, ArrayType, Callable] = 0., + b: Union[float, ArrayType, Callable] = 0.01, + k1: Union[float, ArrayType, Callable] = 0.2, + k2: Union[float, ArrayType, Callable] = 0.02, + R1: Union[float, ArrayType, Callable] = 0., + R2: Union[float, ArrayType, Callable] = 1., + A1: Union[float, ArrayType, Callable] = 0., + A2: Union[float, ArrayType, Callable] = 0., + V_initializer: Union[Callable, ArrayType] = OneInit(-70.), + I1_initializer: Union[Callable, ArrayType] = ZeroInit(), + 
I2_initializer: Union[Callable, ArrayType] = ZeroInit(), + Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + # parameters + self.V_rest = self.offset_scaling(self.init_param(V_rest)) + self.V_reset = self.offset_scaling(self.init_param(V_reset)) + self.V_th_inf = self.offset_scaling(self.init_param(V_th_inf)) + self.V_th_reset = self.offset_scaling(self.init_param(V_th_reset)) + self.R = self.init_param(R) + self.a = self.init_param(a) + self.b = self.init_param(b) + self.k1 = self.init_param(k1) + self.k2 = self.init_param(k2) + self.R1 = self.init_param(R1) + self.R2 = self.init_param(R2) + self.A1 = self.std_scaling(self.init_param(A1)) + self.A2 = self.std_scaling(self.init_param(A2)) + self.tau = self.init_param(tau) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._I1_initializer = is_initializer(I1_initializer) + self._I2_initializer = is_initializer(I2_initializer) + self._Vth_initializer = is_initializer(Vth_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=4) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def dI1(self, I1, t): + return - self.k1 * I1 + + def dI2(self, I2, t): + return - self.k2 * I2 + + def dVth(self, V_th, t, V): + return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) + + def dV(self, V, t, I1, I2, I): + I = self.sum_current_inputs(V, init=I) + return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau + + @property + def derivative(self): + 
return JointEq(self.dI1, self.dI2, self.dVth, self.dV) + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) + self.V_th = self.offset_scaling(self.init_variable(self._Vth_initializer, batch_size)) + self.I1 = self.std_scaling(self.init_variable(self._I1_initializer, batch_size)) + self.I2 = self.std_scaling(self.init_variable(self._I2_initializer, batch_size)) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + V += self.sum_delta_inputs() + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike + else: + raise ValueError + I1 += spike * (self.R1 * I1 + self.A1 - I1) + I2 += spike * (self.R2 * I2 + self.A2 - I2) + V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike + + else: + spike = self.V_th <= V + V = bm.where(spike, self.V_reset, V) + I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) + I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) + V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th) + self.spike.value = spike + self.I1.value = I1 + self.I2.value = I2 + self.V_th.value = V_th + self.V.value = V + return spike + + def return_info(self): + return self.spike - neu = bp.dyn.AdQuaIFRefLTC(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 +class Gif(GifLTC): + 
r"""Generalized Integrate-and-Fire model. - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + **Model Descriptions** - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + The generalized integrate-and-fire model [1]_ is given by + .. math:: + &\frac{d I_j}{d t} = - k_j I_j + &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau - **Model Parameters** - - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger - than :math:`V_{rest}`. - a 1 \ The sensitivity of the recovery variable :math:`u` to - the sub-threshold fluctuations of the membrane - potential :math:`v` - b .1 \ The increment of :math:`w` produced by a spike. - c .07 \ Coefficient describes membrane potential update. - Larger than 0. - tau 10 ms Membrane time constant. - tau_w 10 ms Time constant of the adaptation current. - ============= ============== ======== ======================================================= - - **Model Variables** - - ================== ================= ========================================================== - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- ---------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - t_last_spike -1e7 Last spike time stamp. 
- ================== ================= ========================================================== - - Args: - %s - %s - %s - """ + &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) - def __init__( - self, - size: Shape, - sharding: Optional[Sharding] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - name: Optional[str] = None, - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # old neuron parameter - V_rest: Union[float, ArrayType, Callable] = -65., - V_reset: Union[float, ArrayType, Callable] = -68., - V_th: Union[float, ArrayType, Callable] = -30., - V_c: Union[float, ArrayType, Callable] = -50.0, - a: Union[float, ArrayType, Callable] = 1., - b: Union[float, ArrayType, Callable] = .1, - c: Union[float, ArrayType, Callable] = .07, - tau: Union[float, ArrayType, Callable] = 10., - tau_w: Union[float, ArrayType, Callable] = 10., - V_initializer: Union[Callable, ArrayType] = ZeroInit(), - w_initializer: Union[Callable, ArrayType] = ZeroInit(), - - # new neuron parameter - tau_ref: Union[float, ArrayType, Callable] = 0., - ref_var: bool = False, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__( - size=size, - name=name, - keep_size=keep_size, - mode=mode, - method=method, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - - init_var=False, - scaling=scaling, - - V_rest=V_rest, - V_reset=V_reset, - V_th=V_th, - V_c=V_c, - a=a, - b=b, - c=c, - tau=tau, - tau_w=tau_w, - V_initializer=V_initializer, - w_initializer=w_initializer - ) - - # parameters - self.ref_var = ref_var - self.tau_ref = self.init_param(tau_ref) - - # initializers - self._V_initializer = is_initializer(V_initializer) - self._w_initializer = is_initializer(w_initializer) - 
- # integral - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - super().reset_state(batch_size, **kwargs) - self.t_last_spike = self.init_variable(bm.ones, batch_size) - self.t_last_spike.fill_(-1e8) - if self.ref_var: - self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) - V += self.sum_delta_inputs() - - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike_no_grad - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike_no_grad - else: - raise ValueError - w += self.b * spike_no_grad - spike_ = spike_no_grad > 0. 
- # will be used in other place, like Delta Synapse, so stop its gradient - if self.ref_var: - self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - w = bm.where(spike, w + self.b, w) - if self.ref_var: - self.refractory.value = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.V.value = V - self.w.value = w - self.spike.value = spike - self.t_last_spike.value = t_last_spike - return spike + When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: + .. math:: -class AdQuaIFRef(AdQuaIFRefLTC): - r"""Adaptive quadratic integrate-and-fire neuron model. + &I_j \leftarrow R_j I_j + A_j - **Model Descriptions** + &V \leftarrow V_{reset} - The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: + &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) - .. math:: + Note that :math:`I_j` refers to arbitrary number of internal currents. - \begin{aligned} - \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ - \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, - \end{aligned} - once the membrane potential reaches the spike threshold, + **References** - .. math:: + .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear + integrate-and-fire neural model produces diverse spiking + behaviors." Neural computation 21.3 (2009): 704-718. + .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan + Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized + leaky integrate-and-fire models classify multiple neuron types." + Nature communications 9, no. 1 (2018): 1-15. - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. + **Examples** - **References** + There is a simple usage: - .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking - neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. - .. 
[2] Touboul, Jonathan. "Bifurcation analysis of a general class of - nonlinear integrate-and-fire neurons." SIAM Journal on Applied - Mathematics 68, no. 4 (2008): 1045-1079. + .. code-block:: python - **Examples** + import brainpy as bp + import matplotlib.pyplot as plt - There is an example usage: + # Phasic Spiking + neu = bp.dyn.Gif(1, a=0.005) + inputs = bp.inputs.section_input((0, 1.5), (50, 500)) - .. code-block:: python + runner = bp.DSRunner(neu, monitors=['V', 'V_th']) + runner.run(inputs=inputs) - import brainpy as bp + ts = runner.mon.ts - neu = bp.dyn.AdQuaIFRef(2) + fig, gs = bp.visualize.get_figure(1, 1, 4, 8) + ax1 = fig.add_subplot(gs[0, 0]) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + ax1.plot(ts, runner.mon.V[:, 0], label='V') + ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + plt.show() - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + **Model Examples** + - `Detailed examples to reproduce different firing patterns `_ - **Model Parameters** - - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger - than :math:`V_{rest}`. - a 1 \ The sensitivity of the recovery variable :math:`u` to - the sub-threshold fluctuations of the membrane - potential :math:`v` - b .1 \ The increment of :math:`w` produced by a spike. - c .07 \ Coefficient describes membrane potential update. - Larger than 0. 
- tau 10 ms Membrane time constant. - tau_w 10 ms Time constant of the adaptation current. - ============= ============== ======== ======================================================= - - **Model Variables** - - ================== ================= ========================================================== - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- ---------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================== - - - - Args: - %s - %s - %s - """ + **Model Parameters** - def dV(self, V, t, w, I): - dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau - return dVdt + ============= ============== ======== ==================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- -------------------------------------------------------------------- + V_rest -70 mV Resting potential. + V_reset -70 mV Reset potential after spike. + V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. + V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. + R 20 \ Membrane resistance. + tau 20 ms Membrane time constant. Compute by :math:`R * C`. + a 0 \ Coefficient describes the dependence of + :math:`V_{th}` on membrane potential. + b 0.01 \ Coefficient describes :math:`V_{th}` update. + k1 0.2 \ Constant pf :math:`I1`. + k2 0.02 \ Constant of :math:`I2`. + R1 0 \ Free parameter. + Describes dependence of :math:`I_1` reset value on + :math:`I_1` value before spiking. + R2 1 \ Free parameter. + Describes dependence of :math:`I_2` reset value on + :math:`I_2` value before spiking. 
+ A1 0 \ Free parameter. + A2 0 \ Free parameter. + ============= ============== ======== ==================================================================== - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + **Model Variables** + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V -70 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + V_th -50 Spiking threshold potential. + I1 0 Internal current 1. + I2 0 Internal current 2. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= -AdQuaIF.__doc__ = AdQuaIF.__doc__ % (pneu_doc, dpneu_doc) -AdQuaIFRefLTC.__doc__ = AdQuaIFRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) -AdQuaIFRef.__doc__ = AdQuaIFRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) -AdQuaIFLTC.__doc__ = AdQuaIFLTC.__doc__ % () -class GifLTC(GradNeuDyn): - r"""Generalized Integrate-and-Fire model with liquid time-constant. + Args: + %s + %s + """ - **Model Descriptions** + def dV(self, V, t, I1, I2, I): + return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau - The generalized integrate-and-fire model [1]_ is given by + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) - .. math:: - &\frac{d I_j}{d t} = - k_j I_j +class GifRefLTC(GifLTC): + r"""Generalized Integrate-and-Fire model with liquid time-constant. 
- &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau + **Model Descriptions** - &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) + The generalized integrate-and-fire model [1]_ is given by - When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: + .. math:: - .. math:: + &\frac{d I_j}{d t} = - k_j I_j - &I_j \leftarrow R_j I_j + A_j + &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau - &V \leftarrow V_{reset} + &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) - &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) + When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: - Note that :math:`I_j` refers to arbitrary number of internal currents. + .. math:: + &I_j \leftarrow R_j I_j + A_j - **References** + &V \leftarrow V_{reset} - .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear - integrate-and-fire neural model produces diverse spiking - behaviors." Neural computation 21.3 (2009): 704-718. - .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan - Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized - leaky integrate-and-fire models classify multiple neuron types." - Nature communications 9, no. 1 (2018): 1-15. + &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) - **Examples** + Note that :math:`I_j` refers to arbitrary number of internal currents. - There is a simple usage: you r bound to be together, roy and edward - .. code-block:: python + **References** - import brainpy as bp - import matplotlib.pyplot as plt + .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear + integrate-and-fire neural model produces diverse spiking + behaviors." Neural computation 21.3 (2009): 704-718. + .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan + Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized + leaky integrate-and-fire models classify multiple neuron types." + Nature communications 9, no. 1 (2018): 1-15. 
- # Tonic Spiking - neu = bp.dyn.Gif(1) - inputs = bp.inputs.ramp_input(.2, 2, 400, 0, 400) + **Examples** - runner = bp.DSRunner(neu, monitors=['V', 'V_th']) - runner.run(inputs=inputs) + There is a simple usage: mustang i love u - ts = runner.mon.ts - - fig, gs = bp.visualize.get_figure(1, 1, 4, 8) - ax1 = fig.add_subplot(gs[0, 0]) - - ax1.plot(ts, runner.mon.V[:, 0], label='V') - ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') - - plt.show() - - **Model Examples** - - - `Detailed examples to reproduce different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ==================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- -------------------------------------------------------------------- - V_rest -70 mV Resting potential. - V_reset -70 mV Reset potential after spike. - V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. - V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. - R 20 \ Membrane resistance. - tau 20 ms Membrane time constant. Compute by :math:`R * C`. - a 0 \ Coefficient describes the dependence of - :math:`V_{th}` on membrane potential. - b 0.01 \ Coefficient describes :math:`V_{th}` update. - k1 0.2 \ Constant pf :math:`I1`. - k2 0.02 \ Constant of :math:`I2`. - R1 0 \ Free parameter. - Describes dependence of :math:`I_1` reset value on - :math:`I_1` value before spiking. - R2 1 \ Free parameter. - Describes dependence of :math:`I_2` reset value on - :math:`I_2` value before spiking. - A1 0 \ Free parameter. - A2 0 \ Free parameter. 
- ============= ============== ======== ==================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V -70 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - V_th -50 Spiking threshold potential. - I1 0 Internal current 1. - I2 0 Internal current 2. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - -""" - - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_rest: Union[float, ArrayType, Callable] = -70., - V_reset: Union[float, ArrayType, Callable] = -70., - V_th_inf: Union[float, ArrayType, Callable] = -50., - V_th_reset: Union[float, ArrayType, Callable] = -60., - R: Union[float, ArrayType, Callable] = 20., - tau: Union[float, ArrayType, Callable] = 20., - a: Union[float, ArrayType, Callable] = 0., - b: Union[float, ArrayType, Callable] = 0.01, - k1: Union[float, ArrayType, Callable] = 0.2, - k2: Union[float, ArrayType, Callable] = 0.02, - R1: Union[float, ArrayType, Callable] = 0., - R2: Union[float, ArrayType, Callable] = 1., - A1: Union[float, ArrayType, Callable] = 0., - A2: Union[float, ArrayType, Callable] = 0., - V_initializer: Union[Callable, ArrayType] = OneInit(-70.), - I1_initializer: Union[Callable, ArrayType] = ZeroInit(), - 
I2_initializer: Union[Callable, ArrayType] = ZeroInit(), - Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - # parameters - self.V_rest = self.offset_scaling(self.init_param(V_rest)) - self.V_reset = self.offset_scaling(self.init_param(V_reset)) - self.V_th_inf = self.offset_scaling(self.init_param(V_th_inf)) - self.V_th_reset = self.offset_scaling(self.init_param(V_th_reset)) - self.R = self.init_param(R) - self.a = self.init_param(a) - self.b = self.init_param(b) - self.k1 = self.init_param(k1) - self.k2 = self.init_param(k2) - self.R1 = self.init_param(R1) - self.R2 = self.init_param(R2) - self.A1 = self.std_scaling(self.init_param(A1)) - self.A2 = self.std_scaling(self.init_param(A2)) - self.tau = self.init_param(tau) - - # initializers - self._V_initializer = is_initializer(V_initializer) - self._I1_initializer = is_initializer(I1_initializer) - self._I2_initializer = is_initializer(I2_initializer) - self._Vth_initializer = is_initializer(Vth_initializer) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=4) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def dI1(self, I1, t): - return - self.k1 * I1 - - def dI2(self, I2, t): - return - self.k2 * I2 - - def dVth(self, V_th, t, V): - return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) - - def dV(self, V, t, I1, I2, I): - I = self.sum_current_inputs(V, init=I) - return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau - - @property - def derivative(self): - 
return JointEq(self.dI1, self.dI2, self.dVth, self.dV) - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size)) - self.V_th = self.offset_scaling(self.init_variable(self._Vth_initializer, batch_size)) - self.I1 = self.std_scaling(self.init_variable(self._I1_initializer, batch_size)) - self.I2 = self.std_scaling(self.init_variable(self._I2_initializer, batch_size)) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integrate membrane potential - I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) - V += self.sum_delta_inputs() - - # spike, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike - else: - raise ValueError - I1 += spike * (self.R1 * I1 + self.A1 - I1) - I2 += spike * (self.R2 * I2 + self.A2 - I2) - V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike - - else: - spike = self.V_th <= V - V = bm.where(spike, self.V_reset, V) - I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) - I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) - V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th) - self.spike.value = spike - self.I1.value = I1 - self.I2.value = I2 - self.V_th.value = V_th - self.V.value = V - return spike - - def return_info(self): - return self.spike + .. code-block:: python + import brainpy as bp + import matplotlib.pyplot as plt -class Gif(GifLTC): - r"""Generalized Integrate-and-Fire model. + # Hyperpolarization-induced Spiking + neu = bp.dyn.GifRefLTC(1, a=0.005) + neu.V_th[:] = -50. 
+ inputs = bp.inputs.section_input((1.5, 1.7, 1.5, 1.7), (100, 400, 100, 400)) - **Model Descriptions** + runner = bp.DSRunner(neu, monitors=['V', 'V_th']) + runner.run(inputs=inputs) - The generalized integrate-and-fire model [1]_ is given by + ts = runner.mon.ts - .. math:: + fig, gs = bp.visualize.get_figure(1, 1, 4, 8) + ax1 = fig.add_subplot(gs[0, 0]) - &\frac{d I_j}{d t} = - k_j I_j + ax1.plot(ts, runner.mon.V[:, 0], label='V') + ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') - &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau + plt.show() - &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) + **Model Examples** - When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: + - `Detailed examples to reproduce different firing patterns `_ - .. math:: + **Model Parameters** - &I_j \leftarrow R_j I_j + A_j + ============= ============== ======== ==================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- -------------------------------------------------------------------- + V_rest -70 mV Resting potential. + V_reset -70 mV Reset potential after spike. + V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. + V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. + R 20 \ Membrane resistance. + tau 20 ms Membrane time constant. Compute by :math:`R * C`. + a 0 \ Coefficient describes the dependence of + :math:`V_{th}` on membrane potential. + b 0.01 \ Coefficient describes :math:`V_{th}` update. + k1 0.2 \ Constant pf :math:`I1`. + k2 0.02 \ Constant of :math:`I2`. + R1 0 \ Free parameter. + Describes dependence of :math:`I_1` reset value on + :math:`I_1` value before spiking. + R2 1 \ Free parameter. + Describes dependence of :math:`I_2` reset value on + :math:`I_2` value before spiking. + A1 0 \ Free parameter. + A2 0 \ Free parameter. 
+ ============= ============== ======== ==================================================================== - &V \leftarrow V_{reset} + **Model Variables** - &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V -70 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + V_th -50 Spiking threshold potential. + I1 0 Internal current 1. + I2 0 Internal current 2. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - Note that :math:`I_j` refers to arbitrary number of internal currents. - **References** + Args: + %s + %s + %s + """ - .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear - integrate-and-fire neural model produces diverse spiking - behaviors." Neural computation 21.3 (2009): 704-718. - .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan - Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized - leaky integrate-and-fire models classify multiple neuron types." - Nature communications 9, no. 1 (2018): 1-15. 
+ def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # old neuron parameter + V_rest: Union[float, ArrayType, Callable] = -70., + V_reset: Union[float, ArrayType, Callable] = -70., + V_th_inf: Union[float, ArrayType, Callable] = -50., + V_th_reset: Union[float, ArrayType, Callable] = -60., + R: Union[float, ArrayType, Callable] = 20., + tau: Union[float, ArrayType, Callable] = 20., + a: Union[float, ArrayType, Callable] = 0., + b: Union[float, ArrayType, Callable] = 0.01, + k1: Union[float, ArrayType, Callable] = 0.2, + k2: Union[float, ArrayType, Callable] = 0.02, + R1: Union[float, ArrayType, Callable] = 0., + R2: Union[float, ArrayType, Callable] = 1., + A1: Union[float, ArrayType, Callable] = 0., + A2: Union[float, ArrayType, Callable] = 0., + V_initializer: Union[Callable, ArrayType] = OneInit(-70.), + I1_initializer: Union[Callable, ArrayType] = ZeroInit(), + I2_initializer: Union[Callable, ArrayType] = ZeroInit(), + Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + + init_var=False, + scaling=scaling, + + V_rest=V_rest, + V_reset=V_reset, + V_th_inf=V_th_inf, + V_th_reset=V_th_reset, + R=R, + a=a, + b=b, + k1=k1, + k2=k2, + R1=R1, + R2=R2, + A1=A1, + A2=A2, + tau=tau, + V_initializer=V_initializer, + 
I1_initializer=I1_initializer, + I2_initializer=I2_initializer, + Vth_initializer=Vth_initializer, + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._I1_initializer = is_initializer(I1_initializer) + self._I2_initializer = is_initializer(I2_initializer) + self._Vth_initializer = is_initializer(Vth_initializer) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=4) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + super().reset_state(batch_size, **kwargs) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e8) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + V += self.sum_delta_inputs() + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + if self.spk_reset == 'soft': + V -= (self.V_th - self.V_reset) * spike_no_grad + elif self.spk_reset == 'hard': + V += (self.V_reset - V) * spike_no_grad + else: + raise ValueError + I1 += spike * (self.R1 * I1 + self.A1 - I1) + I2 += spike * (self.R2 * I2 + self.A2 - I2) + V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) + I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) + V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.I1.value = I1 + self.I2.value = I2 + self.V_th.value = V_th + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike - **Examples** - There is a simple usage: +class GifRef(GifRefLTC): + r"""Generalized Integrate-and-Fire model. - .. 
code-block:: python + **Model Descriptions** - import brainpy as bp - import matplotlib.pyplot as plt + The generalized integrate-and-fire model [1]_ is given by - # Phasic Spiking - neu = bp.dyn.Gif(1, a=0.005) - inputs = bp.inputs.section_input((0, 1.5), (50, 500)) + .. math:: - runner = bp.DSRunner(neu, monitors=['V', 'V_th']) - runner.run(inputs=inputs) + &\frac{d I_j}{d t} = - k_j I_j - ts = runner.mon.ts - - fig, gs = bp.visualize.get_figure(1, 1, 4, 8) - ax1 = fig.add_subplot(gs[0, 0]) - - ax1.plot(ts, runner.mon.V[:, 0], label='V') - ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') - - plt.show() - - **Model Examples** - - - `Detailed examples to reproduce different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ==================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- -------------------------------------------------------------------- - V_rest -70 mV Resting potential. - V_reset -70 mV Reset potential after spike. - V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. - V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. - R 20 \ Membrane resistance. - tau 20 ms Membrane time constant. Compute by :math:`R * C`. - a 0 \ Coefficient describes the dependence of - :math:`V_{th}` on membrane potential. - b 0.01 \ Coefficient describes :math:`V_{th}` update. - k1 0.2 \ Constant pf :math:`I1`. - k2 0.02 \ Constant of :math:`I2`. - R1 0 \ Free parameter. - Describes dependence of :math:`I_1` reset value on - :math:`I_1` value before spiking. - R2 1 \ Free parameter. - Describes dependence of :math:`I_2` reset value on - :math:`I_2` value before spiking. - A1 0 \ Free parameter. - A2 0 \ Free parameter. 
- ============= ============== ======== ==================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V -70 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - V_th -50 Spiking threshold potential. - I1 0 Internal current 1. - I2 0 Internal current 2. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - - - Args: - %s - %s - """ + &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau - def dV(self, V, t, I1, I2, I): - return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau + &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: + .. math:: -class GifRefLTC(GifLTC): - r"""Generalized Integrate-and-Fire model with liquid time-constant. + &I_j \leftarrow R_j I_j + A_j - **Model Descriptions** + &V \leftarrow V_{reset} - The generalized integrate-and-fire model [1]_ is given by + &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) - .. math:: + Note that :math:`I_j` refers to arbitrary number of internal currents. - &\frac{d I_j}{d t} = - k_j I_j - &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau + **References** - &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) + .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear + integrate-and-fire neural model produces diverse spiking + behaviors." Neural computation 21.3 (2009): 704-718. + .. 
[2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan + Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized + leaky integrate-and-fire models classify multiple neuron types." + Nature communications 9, no. 1 (2018): 1-15. - When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: + **Examples** - .. math:: + There is a simple usage: - &I_j \leftarrow R_j I_j + A_j + .. code-block:: python - &V \leftarrow V_{reset} + import brainpy as bp + import matplotlib.pyplot as plt - &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) + # Tonic Bursting + neu = bp.dyn.GifRef(1, a=0.005, A1=10., A2=-0.6) + neu.V_th[:] = -50. + inputs = bp.inputs.section_input((1.5, 1.7,), (100, 400)) - Note that :math:`I_j` refers to arbitrary number of internal currents. + runner = bp.DSRunner(neu, monitors=['V', 'V_th']) + runner.run(inputs=inputs) + ts = runner.mon.ts - **References** + fig, gs = bp.visualize.get_figure(1, 1, 4, 8) + ax1 = fig.add_subplot(gs[0, 0]) - .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear - integrate-and-fire neural model produces diverse spiking - behaviors." Neural computation 21.3 (2009): 704-718. - .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan - Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized - leaky integrate-and-fire models classify multiple neuron types." - Nature communications 9, no. 1 (2018): 1-15. + ax1.plot(ts, runner.mon.V[:, 0], label='V') + ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') - **Examples** + plt.show() + **Model Examples** - There is a simple usage: mustang i love u + - `Detailed examples to reproduce different firing patterns `_ - .. 
code-block:: python + **Model Parameters** - import brainpy as bp - import matplotlib.pyplot as plt + ============= ============== ======== ==================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- -------------------------------------------------------------------- + V_rest -70 mV Resting potential. + V_reset -70 mV Reset potential after spike. + V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. + V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. + R 20 \ Membrane resistance. + tau 20 ms Membrane time constant. Compute by :math:`R * C`. + a 0 \ Coefficient describes the dependence of + :math:`V_{th}` on membrane potential. + b 0.01 \ Coefficient describes :math:`V_{th}` update. + k1 0.2 \ Constant pf :math:`I1`. + k2 0.02 \ Constant of :math:`I2`. + R1 0 \ Free parameter. + Describes dependence of :math:`I_1` reset value on + :math:`I_1` value before spiking. + R2 1 \ Free parameter. + Describes dependence of :math:`I_2` reset value on + :math:`I_2` value before spiking. + A1 0 \ Free parameter. + A2 0 \ Free parameter. + ============= ============== ======== ==================================================================== - # Hyperpolarization-induced Spiking - neu = bp.dyn.GifRefLTC(1, a=0.005) - neu.V_th[:] = -50. - inputs = bp.inputs.section_input((1.5, 1.7, 1.5, 1.7), (100, 400, 100, 400)) + **Model Variables** - runner = bp.DSRunner(neu, monitors=['V', 'V_th']) - runner.run(inputs=inputs) + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V -70 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. 
+ V_th -50 Spiking threshold potential. + I1 0 Internal current 1. + I2 0 Internal current 2. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - ts = runner.mon.ts - - fig, gs = bp.visualize.get_figure(1, 1, 4, 8) - ax1 = fig.add_subplot(gs[0, 0]) - - ax1.plot(ts, runner.mon.V[:, 0], label='V') - ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') - - plt.show() - - **Model Examples** - - - `Detailed examples to reproduce different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ==================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- -------------------------------------------------------------------- - V_rest -70 mV Resting potential. - V_reset -70 mV Reset potential after spike. - V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. - V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. - R 20 \ Membrane resistance. - tau 20 ms Membrane time constant. Compute by :math:`R * C`. - a 0 \ Coefficient describes the dependence of - :math:`V_{th}` on membrane potential. - b 0.01 \ Coefficient describes :math:`V_{th}` update. - k1 0.2 \ Constant pf :math:`I1`. - k2 0.02 \ Constant of :math:`I2`. - R1 0 \ Free parameter. - Describes dependence of :math:`I_1` reset value on - :math:`I_1` value before spiking. - R2 1 \ Free parameter. - Describes dependence of :math:`I_2` reset value on - :math:`I_2` value before spiking. - A1 0 \ Free parameter. - A2 0 \ Free parameter. 
- ============= ============== ======== ==================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V -70 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - V_th -50 Spiking threshold potential. - I1 0 Internal current 1. - I2 0 Internal current 2. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - - - Args: - %s - %s - %s -""" - - def __init__( - self, - size: Shape, - sharding: Optional[Sharding] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - name: Optional[str] = None, - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # old neuron parameter - V_rest: Union[float, ArrayType, Callable] = -70., - V_reset: Union[float, ArrayType, Callable] = -70., - V_th_inf: Union[float, ArrayType, Callable] = -50., - V_th_reset: Union[float, ArrayType, Callable] = -60., - R: Union[float, ArrayType, Callable] = 20., - tau: Union[float, ArrayType, Callable] = 20., - a: Union[float, ArrayType, Callable] = 0., - b: Union[float, ArrayType, Callable] = 0.01, - k1: Union[float, ArrayType, Callable] = 0.2, - k2: Union[float, ArrayType, Callable] = 0.02, - R1: Union[float, ArrayType, Callable] = 0., - R2: Union[float, ArrayType, Callable] = 1., - A1: Union[float, ArrayType, Callable] = 0., - A2: Union[float, ArrayType, Callable] = 0., - V_initializer: Union[Callable, ArrayType] = OneInit(-70.), - I1_initializer: Union[Callable, ArrayType] = 
ZeroInit(), - I2_initializer: Union[Callable, ArrayType] = ZeroInit(), - Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), - - # new neuron parameter - tau_ref: Union[float, ArrayType, Callable] = 0., - ref_var: bool = False, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__( - size=size, - name=name, - keep_size=keep_size, - mode=mode, - method=method, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - - init_var=False, - scaling=scaling, - - V_rest=V_rest, - V_reset=V_reset, - V_th_inf=V_th_inf, - V_th_reset=V_th_reset, - R=R, - a=a, - b=b, - k1=k1, - k2=k2, - R1=R1, - R2=R2, - A1=A1, - A2=A2, - tau=tau, - V_initializer=V_initializer, - I1_initializer=I1_initializer, - I2_initializer=I2_initializer, - Vth_initializer=Vth_initializer, - ) - - # parameters - self.ref_var = ref_var - self.tau_ref = self.init_param(tau_ref) - - # initializers - self._V_initializer = is_initializer(V_initializer) - self._I1_initializer = is_initializer(I1_initializer) - self._I2_initializer = is_initializer(I2_initializer) - self._Vth_initializer = is_initializer(Vth_initializer) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=4) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - super().reset_state(batch_size, **kwargs) - self.t_last_spike = self.init_variable(bm.ones, batch_size) - self.t_last_spike.fill_(-1e8) - if self.ref_var: - self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. 
if x is None else x - - # integrate membrane potential - I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) - V += self.sum_delta_inputs() - - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - if self.spk_reset == 'soft': - V -= (self.V_th - self.V_reset) * spike_no_grad - elif self.spk_reset == 'hard': - V += (self.V_reset - V) * spike_no_grad - else: - raise ValueError - I1 += spike * (self.R1 * I1 + self.A1 - I1) - I2 += spike * (self.R2 * I2 + self.A2 - I2) - V_th += (bm.maximum(self.V_th_reset, V_th) - V_th) * spike_no_grad - spike_ = spike_no_grad > 0. - # will be used in other place, like Delta Synapse, so stop its gradient - if self.ref_var: - self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - - else: - spike = V >= self.V_th - V = bm.where(spike, self.V_reset, V) - I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) - I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) - V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th) - if self.ref_var: - self.refractory.value = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.V.value = V - self.I1.value = I1 - self.I2.value = I2 - self.V_th.value = V_th - self.spike.value = spike - self.t_last_spike.value = t_last_spike - return spike -class GifRef(GifRefLTC): - r"""Generalized Integrate-and-Fire model. 
+ Args: + %s + %s + %s + """ - **Model Descriptions** + def dV(self, V, t, I1, I2, I): + return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau - The generalized integrate-and-fire model [1]_ is given by + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) - .. math:: - &\frac{d I_j}{d t} = - k_j I_j +Gif.__doc__ = Gif.__doc__ % (pneu_doc, dpneu_doc) +GifRefLTC.__doc__ = GifRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +GifRef.__doc__ = GifRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +GifLTC.__doc__ = GifLTC.__doc__ % () - &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau - &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) +class IzhikevichLTC(GradNeuDyn): + r"""The Izhikevich neuron model with liquid time-constant. - When :math:`V` meet :math:`V_{th}`, Generalized IF neuron fires: + **Model Descriptions** - .. math:: + The dynamics of the Izhikevich neuron model [1]_ [2]_ is given by: - &I_j \leftarrow R_j I_j + A_j + .. math :: - &V \leftarrow V_{reset} + \frac{d V}{d t} &= 0.04 V^{2}+5 V+140-u+I - &V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) + \frac{d u}{d t} &=a(b V-u) - Note that :math:`I_j` refers to arbitrary number of internal currents. + .. math :: + \text{if} v \geq 30 \text{mV}, \text{then} + \begin{cases} v \leftarrow c \\ + u \leftarrow u+d \end{cases} - **References** - .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear - integrate-and-fire neural model produces diverse spiking - behaviors." Neural computation 21.3 (2009): 704-718. - .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan - Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized - leaky integrate-and-fire models classify multiple neuron types." - Nature communications 9, no. 1 (2018): 1-15. + **References** - **Examples** + .. [1] Izhikevich, Eugene M. "Simple model of spiking neurons." 
IEEE + Transactions on neural networks 14.6 (2003): 1569-1572. - There is a simple usage: + .. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?." + IEEE transactions on neural networks 15.5 (2004): 1063-1070. - .. code-block:: python + **Examples** - import brainpy as bp - import matplotlib.pyplot as plt + There is a simple usage example:: - # Tonic Bursting - neu = bp.dyn.GifRef(1, a=0.005, A1=10., A2=-0.6) - neu.V_th[:] = -50. - inputs = bp.inputs.section_input((1.5, 1.7,), (100, 400)) + import brainpy as bp - runner = bp.DSRunner(neu, monitors=['V', 'V_th']) - runner.run(inputs=inputs) + neu = bp.dyn.IzhikevichLTC(2) - ts = runner.mon.ts - - fig, gs = bp.visualize.get_figure(1, 1, 4, 8) - ax1 = fig.add_subplot(gs[0, 0]) - - ax1.plot(ts, runner.mon.V[:, 0], label='V') - ax1.plot(ts, runner.mon.V_th[:, 0], label='V_th') - - plt.show() - **Model Examples** - - - `Detailed examples to reproduce different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ==================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- -------------------------------------------------------------------- - V_rest -70 mV Resting potential. - V_reset -70 mV Reset potential after spike. - V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. - V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. - R 20 \ Membrane resistance. - tau 20 ms Membrane time constant. Compute by :math:`R * C`. - a 0 \ Coefficient describes the dependence of - :math:`V_{th}` on membrane potential. - b 0.01 \ Coefficient describes :math:`V_{th}` update. - k1 0.2 \ Constant pf :math:`I1`. - k2 0.02 \ Constant of :math:`I2`. - R1 0 \ Free parameter. - Describes dependence of :math:`I_1` reset value on - :math:`I_1` value before spiking. - R2 1 \ Free parameter. 
- Describes dependence of :math:`I_2` reset value on - :math:`I_2` value before spiking. - A1 0 \ Free parameter. - A2 0 \ Free parameter. - ============= ============== ======== ==================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V -70 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - V_th -50 Spiking threshold potential. - I1 0 Internal current 1. - I2 0 Internal current 2. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - - - Args: - %s - %s - %s -""" - - def dV(self, V, t, I1, I2, I): - return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau - - def update(self, x=None): - x = 0. 
if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) -Gif.__doc__ = Gif.__doc__ % (pneu_doc, dpneu_doc) -GifRefLTC.__doc__ = GifRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) -GifRef.__doc__ = GifRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) -GifLTC.__doc__ = GifLTC.__doc__ % () + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + + + + **Model Examples** + + - `Detailed examples to reproduce different firing patterns `_ + + **Model Parameters** + + ============= ============== ======== ================================================================================ + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- -------------------------------------------------------------------------------- + a 0.02 \ It determines the time scaling of + the recovery variable :math:`u`. + b 0.2 \ It describes the sensitivity of the + recovery variable :math:`u` to + the sub-threshold fluctuations of the + membrane potential :math:`v`. + c -65 \ It describes the after-spike reset value + of the membrane potential :math:`v` caused by + the fast high-threshold :math:`K^{+}` + conductance. + d 8 \ It describes after-spike reset of the + recovery variable :math:`u` + caused by slow high-threshold + :math:`Na^{+}` and :math:`K^{+}` conductance. + tau_ref 0 ms Refractory period length. [ms] + V_th 30 mV The membrane potential threshold. 
+ ============= ============== ======== ================================================================================ + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V -65 Membrane potential. + u 1 Recovery variable. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + """ + + def __init__( + self, + size: Shape, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # neuron parameters + V_th: Union[float, ArrayType, Callable] = 30., + p1: Union[float, ArrayType, Callable] = 0.04, + p2: Union[float, ArrayType, Callable] = 5., + p3: Union[float, ArrayType, Callable] = 140., + a: Union[float, ArrayType, Callable] = 0.02, + b: Union[float, ArrayType, Callable] = 0.20, + c: Union[float, ArrayType, Callable] = -65., + d: Union[float, ArrayType, Callable] = 8., + tau: Union[float, ArrayType, Callable] = 10., + R: Union[float, ArrayType, Callable] = 1., + V_initializer: Union[Callable, ArrayType] = OneInit(-70.), + u_initializer: Union[Callable, ArrayType] = None, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + sharding=sharding, + spk_fun=spk_fun, + 
detach_spk=detach_spk, + method=method, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + scaling=scaling) + # parameters + self.V_th = self.offset_scaling(self.init_param(V_th)) + self.p1 = self.inv_scaling(self.init_param(p1)) + p2_scaling = self.scaling.clone(bias=-p1 * 2 * self.scaling.bias, scale=1.) + self.p2 = p2_scaling.offset_scaling(self.init_param(p2)) + p3_bias = p1 * self.scaling.bias ** 2 + b * self.scaling.bias - p2 * self.scaling.bias + p3_scaling = self.scaling.clone(bias=p3_bias, scale=self.scaling.scale) + self.p3 = p3_scaling.offset_scaling(self.init_param(p3)) + self.a = self.init_param(a) + self.b = self.init_param(b) + self.c = self.offset_scaling(self.init_param(c)) + self.d = self.std_scaling(self.init_param(d)) + self.R = self.init_param(R) + self.tau = self.init_param(tau) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._u_initializer = is_initializer(u_initializer, allow_none=True) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def dV(self, V, t, u, I): + I = self.sum_current_inputs(V, init=I) + dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I + return dVdt + + def du(self, u, t, V): + dudt = self.a * (self.b * V - u) + return dudt + + @property + def derivative(self): + return JointEq([self.dV, self.du]) + + def reset_state(self, batch_size=None, **kwargs): + self.V = self.init_variable(self._V_initializer, batch_size) + u_initializer = OneInit(self.b * self.V) if self._u_initializer is None else self._u_initializer + self._u_initializer = is_initializer(u_initializer) + self.V = self.offset_scaling(self.V) + self.u = self.offset_scaling(self.init_variable(self._u_initializer, batch_size), + bias=self.b * self.scaling.bias, + 
scale=self.scaling.scale) + self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V, u = self.integral(self.V.value, self.u.value, t, x, dt) + V += self.sum_delta_inputs() + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + V += spike * (self.c - V) + u += spike * self.d + + else: + spike = V >= self.V_th + V = bm.where(spike, self.c, V) + u = bm.where(spike, u + self.d, u) + + self.V.value = V + self.u.value = u + self.spike.value = spike + return spike + + def return_info(self): + return self.spike -class IzhikevichLTC(GradNeuDyn): - r"""The Izhikevich neuron model with liquid time-constant. +class Izhikevich(IzhikevichLTC): + r"""The Izhikevich neuron model. **Model Descriptions** @@ -3977,7 +4192,7 @@ class IzhikevichLTC(GradNeuDyn): import brainpy as bp - neu = bp.dyn.IzhikevichLTC(2) + neu = bp.dyn.Izhikevich(2) # section input with wiener process inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() @@ -3989,7 +4204,6 @@ class IzhikevichLTC(GradNeuDyn): bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Examples** - `Detailed examples to reproduce different firing patterns `_ @@ -4029,567 +4243,354 @@ class IzhikevichLTC(GradNeuDyn): refractory False Flag to mark whether the neuron is in refractory period. t_last_spike -1e7 Last spike time stamp. 
================== ================= ========================================================= - """ - - def __init__( - self, - size: Shape, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # neuron parameters - V_th: Union[float, ArrayType, Callable] = 30., - p1: Union[float, ArrayType, Callable] = 0.04, - p2: Union[float, ArrayType, Callable] = 5., - p3: Union[float, ArrayType, Callable] = 140., - a: Union[float, ArrayType, Callable] = 0.02, - b: Union[float, ArrayType, Callable] = 0.20, - c: Union[float, ArrayType, Callable] = -65., - d: Union[float, ArrayType, Callable] = 8., - tau: Union[float, ArrayType, Callable] = 10., - R: Union[float, ArrayType, Callable] = 1., - V_initializer: Union[Callable, ArrayType] = OneInit(-70.), - u_initializer: Union[Callable, ArrayType] = None, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - method=method, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - scaling=scaling) - # parameters - self.V_th = self.offset_scaling(self.init_param(V_th)) - self.p1 = self.inv_scaling(self.init_param(p1)) - p2_scaling = self.scaling.clone(bias=-p1 * 2 * self.scaling.bias, scale=1.) 
- self.p2 = p2_scaling.offset_scaling(self.init_param(p2)) - p3_bias = p1 * self.scaling.bias ** 2 + b * self.scaling.bias - p2 * self.scaling.bias - p3_scaling = self.scaling.clone(bias=p3_bias, scale=self.scaling.scale) - self.p3 = p3_scaling.offset_scaling(self.init_param(p3)) - self.a = self.init_param(a) - self.b = self.init_param(b) - self.c = self.offset_scaling(self.init_param(c)) - self.d = self.std_scaling(self.init_param(d)) - self.R = self.init_param(R) - self.tau = self.init_param(tau) - - # initializers - self._V_initializer = is_initializer(V_initializer) - self._u_initializer = is_initializer(u_initializer, allow_none=True) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def dV(self, V, t, u, I): - I = self.sum_current_inputs(V, init=I) - dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I - return dVdt - - def du(self, u, t, V): - dudt = self.a * (self.b * V - u) - return dudt - - @property - def derivative(self): - return JointEq([self.dV, self.du]) - - def reset_state(self, batch_size=None, **kwargs): - self.V = self.init_variable(self._V_initializer, batch_size) - u_initializer = OneInit(self.b * self.V) if self._u_initializer is None else self._u_initializer - self._u_initializer = is_initializer(u_initializer) - self.V = self.offset_scaling(self.V) - self.u = self.offset_scaling(self.init_variable(self._u_initializer, batch_size), bias=self.b * self.scaling.bias, - scale=self.scaling.scale) - self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. 
if x is None else x - - # integrate membrane potential - V, u = self.integral(self.V.value, self.u.value, t, x, dt) - V += self.sum_delta_inputs() - - # spike, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike = stop_gradient(spike) if self.detach_spk else spike - V += spike * (self.c - V) - u += spike * self.d - - else: - spike = V >= self.V_th - V = bm.where(spike, self.c, V) - u = bm.where(spike, u + self.d, u) - - self.V.value = V - self.u.value = u - self.spike.value = spike - return spike - - def return_info(self): - return self.spike - - -class Izhikevich(IzhikevichLTC): - r"""The Izhikevich neuron model. - - **Model Descriptions** - - The dynamics of the Izhikevich neuron model [1]_ [2]_ is given by: - - .. math :: - - \frac{d V}{d t} &= 0.04 V^{2}+5 V+140-u+I - - \frac{d u}{d t} &=a(b V-u) - - .. math :: - - \text{if} v \geq 30 \text{mV}, \text{then} - \begin{cases} v \leftarrow c \\ - u \leftarrow u+d \end{cases} - - - **References** - - .. [1] Izhikevich, Eugene M. "Simple model of spiking neurons." IEEE - Transactions on neural networks 14.6 (2003): 1569-1572. - - .. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?." - IEEE transactions on neural networks 15.5 (2004): 1063-1070. 
- - **Examples** - - There is a simple usage example:: - - import brainpy as bp - - neu = bp.dyn.Izhikevich(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + Args: + %s + %s - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - - - **Model Examples** - - - `Detailed examples to reproduce different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ================================================================================ - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- -------------------------------------------------------------------------------- - a 0.02 \ It determines the time scaling of - the recovery variable :math:`u`. - b 0.2 \ It describes the sensitivity of the - recovery variable :math:`u` to - the sub-threshold fluctuations of the - membrane potential :math:`v`. - c -65 \ It describes the after-spike reset value - of the membrane potential :math:`v` caused by - the fast high-threshold :math:`K^{+}` - conductance. - d 8 \ It describes after-spike reset of the - recovery variable :math:`u` - caused by slow high-threshold - :math:`Na^{+}` and :math:`K^{+}` conductance. - tau_ref 0 ms Refractory period length. [ms] - V_th 30 mV The membrane potential threshold. - ============= ============== ======== ================================================================================ - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V -65 Membrane potential. - u 1 Recovery variable. 
- input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - - Args: - %s - %s - - """ + """ - def dV(self, V, t, u, I): - dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I - return dVdt + def dV(self, V, t, u, I): + dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I + return dVdt - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) class IzhikevichRefLTC(IzhikevichLTC): - r"""The Izhikevich neuron model with liquid time-constant. + r"""The Izhikevich neuron model with liquid time-constant. - **Model Descriptions** + **Model Descriptions** - The dynamics of the Izhikevich neuron model [1]_ [2]_ is given by: + The dynamics of the Izhikevich neuron model [1]_ [2]_ is given by: - .. math :: + .. math :: - \frac{d V}{d t} &= 0.04 V^{2}+5 V+140-u+I + \frac{d V}{d t} &= 0.04 V^{2}+5 V+140-u+I - \frac{d u}{d t} &=a(b V-u) + \frac{d u}{d t} &=a(b V-u) - .. math :: + .. math :: - \text{if} v \geq 30 \text{mV}, \text{then} - \begin{cases} v \leftarrow c \\ - u \leftarrow u+d \end{cases} + \text{if} v \geq 30 \text{mV}, \text{then} + \begin{cases} v \leftarrow c \\ + u \leftarrow u+d \end{cases} - **References** + **References** - .. [1] Izhikevich, Eugene M. "Simple model of spiking neurons." IEEE - Transactions on neural networks 14.6 (2003): 1569-1572. + .. [1] Izhikevich, Eugene M. "Simple model of spiking neurons." IEEE + Transactions on neural networks 14.6 (2003): 1569-1572. - .. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?." 
- IEEE transactions on neural networks 15.5 (2004): 1063-1070. + .. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?." + IEEE transactions on neural networks 15.5 (2004): 1063-1070. - **Examples** + **Examples** - There is a simple usage example:: + There is a simple usage example:: - import brainpy as bp + import brainpy as bp - neu = bp.dyn.IzhikevichRefLTC(2) + neu = bp.dyn.IzhikevichRefLTC(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Examples** + **Model Examples** - - `Detailed examples to reproduce different firing patterns `_ + - `Detailed examples to reproduce different firing patterns `_ - **Model Parameters** + **Model Parameters** - ============= ============== ======== ================================================================================ - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- -------------------------------------------------------------------------------- - a 0.02 \ It determines the time scaling of - the recovery variable :math:`u`. - b 0.2 \ It describes the sensitivity of the - recovery variable :math:`u` to - the sub-threshold fluctuations of the - membrane potential :math:`v`. 
- c -65 \ It describes the after-spike reset value - of the membrane potential :math:`v` caused by - the fast high-threshold :math:`K^{+}` - conductance. - d 8 \ It describes after-spike reset of the - recovery variable :math:`u` - caused by slow high-threshold - :math:`Na^{+}` and :math:`K^{+}` conductance. - tau_ref 0 ms Refractory period length. [ms] - V_th 30 mV The membrane potential threshold. - ============= ============== ======== ================================================================================ + ============= ============== ======== ================================================================================ + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- -------------------------------------------------------------------------------- + a 0.02 \ It determines the time scaling of + the recovery variable :math:`u`. + b 0.2 \ It describes the sensitivity of the + recovery variable :math:`u` to + the sub-threshold fluctuations of the + membrane potential :math:`v`. + c -65 \ It describes the after-spike reset value + of the membrane potential :math:`v` caused by + the fast high-threshold :math:`K^{+}` + conductance. + d 8 \ It describes after-spike reset of the + recovery variable :math:`u` + caused by slow high-threshold + :math:`Na^{+}` and :math:`K^{+}` conductance. + tau_ref 0 ms Refractory period length. [ms] + V_th 30 mV The membrane potential threshold. + ============= ============== ======== ================================================================================ - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V -65 Membrane potential. - u 1 Recovery variable. - input 0 External and synaptic input current. 
- spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V -65 Membrane potential. + u 1 Recovery variable. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - Args: - %s - %s - %s + Args: + %s + %s + %s - """ + """ - def __init__( - self, - size: Shape, - sharding: Optional[Sharding] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - spk_fun: Callable = bm.surrogate.InvSquareGrad(), - spk_dtype: Any = None, - spk_reset: str = 'soft', - detach_spk: bool = False, - method: str = 'exp_auto', - name: Optional[str] = None, - init_var: bool = True, - scaling: Optional[bm.Scaling] = None, - - # old neuron parameter - V_th: Union[float, ArrayType, Callable] = 30., - p1: Union[float, ArrayType, Callable] = 0.04, - p2: Union[float, ArrayType, Callable] = 5., - p3: Union[float, ArrayType, Callable] = 140., - a: Union[float, ArrayType, Callable] = 0.02, - b: Union[float, ArrayType, Callable] = 0.20, - c: Union[float, ArrayType, Callable] = -65., - d: Union[float, ArrayType, Callable] = 8., - tau: Union[float, ArrayType, Callable] = 10., - R: Union[float, ArrayType, Callable] = 1., - V_initializer: Union[Callable, ArrayType] = OneInit(-70.), - u_initializer: Union[Callable, ArrayType] = None, - - # new neuron parameter - tau_ref: 
Union[float, ArrayType, Callable] = 0., - ref_var: bool = False, - - # noise - noise: Union[float, ArrayType, Callable] = None, - ): - # initialization - super().__init__( - size=size, - name=name, - keep_size=keep_size, - mode=mode, - method=method, - sharding=sharding, - spk_fun=spk_fun, - detach_spk=detach_spk, - spk_dtype=spk_dtype, - spk_reset=spk_reset, - - init_var=False, - scaling=scaling, - - V_th=V_th, - p1=p1, - p2=p2, - p3=p3, - a=a, - b=b, - c=c, - d=d, - R=R, - tau=tau, - V_initializer=V_initializer, - u_initializer=u_initializer - ) - - # parameters - self.ref_var = ref_var - self.tau_ref = self.init_param(tau_ref) - - # initializers - self._V_initializer = is_initializer(V_initializer) - self._u_initializer = is_initializer(u_initializer, allow_none=True) - - # integral - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) - else: - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - super().reset_state(batch_size, **kwargs) - self.t_last_spike = self.init_variable(bm.ones, batch_size) - self.t_last_spike.fill_(-1e7) - if self.ref_var: - self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. 
if x is None else x - - # integrate membrane potential - V, u = self.integral(self.V.value, self.u.value, t, x, dt) - V += self.sum_delta_inputs() - - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spk_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) if self.detach_spk else spike - V += spike * (self.c - V) - u += spike * self.d - spike_ = spike_no_grad > 0. - # will be used in other place, like Delta Synapse, so stop its gradient - if self.ref_var: - self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - - else: - spike = V >= self.V_th - V = bm.where(spike, self.c, V) - u = bm.where(spike, u + self.d, u) - if self.ref_var: - self.refractory.value = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.V.value = V - self.u.value = u - self.spike.value = spike - self.t_last_spike.value = t_last_spike - return spike + def __init__( + self, + size: Shape, + sharding: Optional[Sharding] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_dtype: Any = None, + spk_reset: str = 'soft', + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + scaling: Optional[bm.Scaling] = None, + + # old neuron parameter + V_th: Union[float, ArrayType, Callable] = 30., + p1: Union[float, ArrayType, Callable] = 0.04, + p2: Union[float, ArrayType, Callable] = 5., + p3: Union[float, ArrayType, Callable] = 140., + a: Union[float, ArrayType, Callable] = 0.02, + b: Union[float, ArrayType, Callable] = 0.20, + c: Union[float, ArrayType, 
Callable] = -65., + d: Union[float, ArrayType, Callable] = 8., + tau: Union[float, ArrayType, Callable] = 10., + R: Union[float, ArrayType, Callable] = 1., + V_initializer: Union[Callable, ArrayType] = OneInit(-70.), + u_initializer: Union[Callable, ArrayType] = None, + + # new neuron parameter + tau_ref: Union[float, ArrayType, Callable] = 0., + ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + sharding=sharding, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_dtype=spk_dtype, + spk_reset=spk_reset, + + init_var=False, + scaling=scaling, + + V_th=V_th, + p1=p1, + p2=p2, + p3=p3, + a=a, + b=b, + c=c, + d=d, + R=R, + tau=tau, + V_initializer=V_initializer, + u_initializer=u_initializer + ) + + # parameters + self.ref_var = ref_var + self.tau_ref = self.init_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._u_initializer = is_initializer(u_initializer, allow_none=True) + + # integral + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + super().reset_state(batch_size, **kwargs) + self.t_last_spike = self.init_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.ref_var: + self.refractory = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integrate membrane potential + V, u = self.integral(self.V.value, self.u.value, t, x, dt) + V += self.sum_delta_inputs() + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += spike * (self.c - V) + u += spike * self.d + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.c, V) + u = bm.where(spike, u + self.d, u) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.u.value = u + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike class IzhikevichRef(IzhikevichRefLTC): - r"""The Izhikevich neuron model. + r"""The Izhikevich neuron model. - **Model Descriptions** + **Model Descriptions** - The dynamics of the Izhikevich neuron model [1]_ [2]_ is given by: + The dynamics of the Izhikevich neuron model [1]_ [2]_ is given by: - .. math :: + .. math :: - \frac{d V}{d t} &= 0.04 V^{2}+5 V+140-u+I + \frac{d V}{d t} &= 0.04 V^{2}+5 V+140-u+I - \frac{d u}{d t} &=a(b V-u) + \frac{d u}{d t} &=a(b V-u) - .. math :: + .. 
math :: - \text{if} v \geq 30 \text{mV}, \text{then} - \begin{cases} v \leftarrow c \\ - u \leftarrow u+d \end{cases} + \text{if} v \geq 30 \text{mV}, \text{then} + \begin{cases} v \leftarrow c \\ + u \leftarrow u+d \end{cases} - **References** + **References** - .. [1] Izhikevich, Eugene M. "Simple model of spiking neurons." IEEE - Transactions on neural networks 14.6 (2003): 1569-1572. + .. [1] Izhikevich, Eugene M. "Simple model of spiking neurons." IEEE + Transactions on neural networks 14.6 (2003): 1569-1572. - .. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?." - IEEE transactions on neural networks 15.5 (2004): 1063-1070. + .. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?." + IEEE transactions on neural networks 15.5 (2004): 1063-1070. - **Examples** + **Examples** - There is a simple usage example:: + There is a simple usage example:: - import brainpy as bp + import brainpy as bp - neu = bp.dyn.IzhikevichRef(2) + neu = bp.dyn.IzhikevichRef(2) - # section input with wiener process - inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() - inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 + # section input with wiener process + inp1 = bp.inputs.wiener_process(500., n=1, t_start=100., t_end=400.).flatten() + inputs = bp.inputs.section_input([0., 22., 0.], [100., 300., 100.]) + inp1 - runner = bp.DSRunner(neu, monitors=['V']) - runner.run(inputs=inputs) + runner = bp.DSRunner(neu, monitors=['V']) + runner.run(inputs=inputs) - bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) + bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], plot_ids=(0, 1), show=True) - **Model Examples** + **Model Examples** - - `Detailed examples to reproduce different firing patterns `_ + - `Detailed examples to reproduce different firing patterns `_ - **Model Parameters** + **Model Parameters** - ============= ============== ======== 
================================================================================ - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- -------------------------------------------------------------------------------- - a 0.02 \ It determines the time scaling of - the recovery variable :math:`u`. - b 0.2 \ It describes the sensitivity of the - recovery variable :math:`u` to - the sub-threshold fluctuations of the - membrane potential :math:`v`. - c -65 \ It describes the after-spike reset value - of the membrane potential :math:`v` caused by - the fast high-threshold :math:`K^{+}` - conductance. - d 8 \ It describes after-spike reset of the - recovery variable :math:`u` - caused by slow high-threshold - :math:`Na^{+}` and :math:`K^{+}` conductance. - tau_ref 0 ms Refractory period length. [ms] - V_th 30 mV The membrane potential threshold. - ============= ============== ======== ================================================================================ + ============= ============== ======== ================================================================================ + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- -------------------------------------------------------------------------------- + a 0.02 \ It determines the time scaling of + the recovery variable :math:`u`. + b 0.2 \ It describes the sensitivity of the + recovery variable :math:`u` to + the sub-threshold fluctuations of the + membrane potential :math:`v`. + c -65 \ It describes the after-spike reset value + of the membrane potential :math:`v` caused by + the fast high-threshold :math:`K^{+}` + conductance. + d 8 \ It describes after-spike reset of the + recovery variable :math:`u` + caused by slow high-threshold + :math:`Na^{+}` and :math:`K^{+}` conductance. + tau_ref 0 ms Refractory period length. [ms] + V_th 30 mV The membrane potential threshold. 
+ ============= ============== ======== ================================================================================ - **Model Variables** + **Model Variables** - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V -65 Membrane potential. - u 1 Recovery variable. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V -65 Membrane potential. + u 1 Recovery variable. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= - Args: - %s - %s - %s - """ + Args: + %s + %s + %s + """ - def dV(self, V, t, u, I): - dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I - return dVdt + def dV(self, V, t, u, I): + dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I + return dVdt - def update(self, x=None): - x = 0. if x is None else x - x = self.sum_current_inputs(self.V.value, init=x) - return super().update(x) + def update(self, x=None): + x = 0. 
if x is None else x + x = self.sum_current_inputs(self.V.value, init=x) + return super().update(x) Izhikevich.__doc__ = Izhikevich.__doc__ % (pneu_doc, dpneu_doc) diff --git a/brainpy/_src/dyn/neurons/tests/test_hh.py b/brainpy/_src/dyn/neurons/tests/test_hh.py index 961701f7e..cd4603e7b 100644 --- a/brainpy/_src/dyn/neurons/tests/test_hh.py +++ b/brainpy/_src/dyn/neurons/tests/test_hh.py @@ -1,141 +1,142 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized from brainpy._src.dyn.neurons import hh class Test_HH(parameterized.TestCase): - def test_HH(self): - model = hh.HH(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'm', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - def test_HH_batching_mode(self): - model = hh.HH(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'm', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - - def test_HHLTC(self): - model = hh.HHLTC(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'm', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - def test_HHLTC_batching_mode(self): - model = hh.HHLTC(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'm', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - - def test_MorrisLecar(self): - model = hh.MorrisLecar(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'W', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - def test_MorrisLecar_batching_mode(self): - model = hh.MorrisLecar(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'W', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - - def test_MorrisLecarLTC(self): - model = hh.MorrisLecarLTC(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'W', 'spike'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - def test_MorrisLecarLTC_batching_mode(self): - model = hh.MorrisLecarLTC(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'W', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - - def test_WangBuzsakiModel(self): - model = hh.WangBuzsakiHH(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - def test_WangBuzsakiModel_batching_mode(self): - model = hh.WangBuzsakiHH(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - - def test_WangBuzsakiModelLTC(self): - model = hh.WangBuzsakiHHLTC(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - def test_WangBuzsakiModelLTC_batching_mode(self): - model = hh.WangBuzsakiHHLTC(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + def test_HH(self): + model = hh.HH(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'm', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_HH_batching_mode(self): + model = hh.HH(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'm', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_HHLTC(self): + model = hh.HHLTC(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'm', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_HHLTC_batching_mode(self): + model = hh.HHLTC(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'm', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_MorrisLecar(self): + model = hh.MorrisLecar(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_MorrisLecar_batching_mode(self): + model = hh.MorrisLecar(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_MorrisLecarLTC(self): + model = hh.MorrisLecarLTC(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_MorrisLecarLTC_batching_mode(self): + model = hh.MorrisLecarLTC(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_WangBuzsakiModel(self): + model = hh.WangBuzsakiHH(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_WangBuzsakiModel_batching_mode(self): + model = hh.WangBuzsakiHH(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_WangBuzsakiModelLTC(self): + model = hh.WangBuzsakiHHLTC(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_WangBuzsakiModelLTC_batching_mode(self): + model = hh.WangBuzsakiHHLTC(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) diff --git a/brainpy/_src/dyn/neurons/tests/test_lif.py b/brainpy/_src/dyn/neurons/tests/test_lif.py index 1521f82da..f162b441c 100644 --- a/brainpy/_src/dyn/neurons/tests/test_lif.py +++ b/brainpy/_src/dyn/neurons/tests/test_lif.py @@ -1,62 +1,62 @@ # -*- coding: utf-8 -*- import numpy as np +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized from brainpy._src.dyn.neurons import lif class Test_lif(parameterized.TestCase): - @parameterized.named_parameters( - {'testcase_name': f'{name}', 'neuron': name} - for name in lif.__all__ - ) - def test_run_shape(self, neuron): - model = getattr(lif, neuron)(size=1) - if neuron in ['IF', 'IFLTC']: - runner = bp.DSRunner(model, - monitors=['V'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - else: - runner = bp.DSRunner(model, - monitors=['V', 'spike'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + @parameterized.named_parameters( + {'testcase_name': f'{name}', 'neuron': name} + for name in lif.__all__ + ) + def test_run_shape(self, neuron): + model = getattr(lif, neuron)(size=1) + if neuron in ['IF', 'IFLTC']: + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + else: + runner = bp.DSRunner(model, + monitors=['V', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - @parameterized.named_parameters( - {'testcase_name': f'{name}', 'neuron': name} - for name in lif.__all__ - ) - def test_training_shape(self, neuron): - model = getattr(lif, neuron)(size=10, mode=bm.training_mode) - runner = bp.DSRunner(model, - monitors=['V'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + @parameterized.named_parameters( + {'testcase_name': f'{name}', 'neuron': name} + for name in lif.__all__ + ) + def test_training_shape(self, neuron): + model = getattr(lif, neuron)(size=10, mode=bm.training_mode) + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - @parameterized.named_parameters( - {'testcase_name': f'{name}', 'neuron': name} - for name in lif.__all__ - ) - def test_training_lif(self, neuron): - if neuron not in ['IF', 'IFLTC']: - model1 = getattr(lif, neuron)(size=1, - V_initializer=bp.init.Constant(-70.), - mode=bm.training_mode, - spk_reset='hard', - scaling=bm.Scaling.transform(V_range=[-70, 30], scaled_V_range=[0, 1])) - model2 = getattr(lif, neuron)(size=1, - V_initializer=bp.init.Constant(-70.), - mode=bm.training_mode, - spk_reset='hard', - scaling=bm.Scaling(scale=1, bias=0)) - indices = bm.arange(5000) - spks1 = bm.for_loop(lambda i: model1.step_run(i, 10./model1.scaling.scale), indices, jit=True) - spks2 = bm.for_loop(lambda i: model2.step_run(i, 10./model2.scaling.scale), indices, jit=True) - self.assertTrue(np.allclose(spks1, spks2)) \ No newline at end of file + @parameterized.named_parameters( + {'testcase_name': f'{name}', 'neuron': name} + for name in lif.__all__ + ) + def test_training_lif(self, neuron): + if neuron not in ['IF', 'IFLTC']: + model1 = getattr(lif, neuron)(size=1, + V_initializer=bp.init.Constant(-70.), + mode=bm.training_mode, + spk_reset='hard', + scaling=bm.Scaling.transform(V_range=[-70, 30], scaled_V_range=[0, 1])) + model2 = getattr(lif, neuron)(size=1, + V_initializer=bp.init.Constant(-70.), + mode=bm.training_mode, + spk_reset='hard', + scaling=bm.Scaling(scale=1, bias=0)) + indices = bm.arange(5000) + spks1 = bm.for_loop(lambda i: model1.step_run(i, 10. / model1.scaling.scale), indices, jit=True) + spks2 = bm.for_loop(lambda i: model2.step_run(i, 10. 
/ model2.scaling.scale), indices, jit=True) + self.assertTrue(np.allclose(spks1, spks2)) diff --git a/brainpy/_src/dyn/others/__init__.py b/brainpy/_src/dyn/others/__init__.py index e0b63186d..286bd5a29 100644 --- a/brainpy/_src/dyn/others/__init__.py +++ b/brainpy/_src/dyn/others/__init__.py @@ -1,2 +1,2 @@ def commono(): - return None \ No newline at end of file + return None diff --git a/brainpy/_src/dyn/others/common.py b/brainpy/_src/dyn/others/common.py index 812375787..3613ac889 100644 --- a/brainpy/_src/dyn/others/common.py +++ b/brainpy/_src/dyn/others/common.py @@ -11,153 +11,153 @@ from brainpy.types import ArrayType __all__ = [ - 'Leaky', - 'Integrator', + 'Leaky', + 'Integrator', ] class Leaky(NeuDyn): - r"""Leaky Integrator Model. + r"""Leaky Integrator Model. - **Model Descriptions** + **Model Descriptions** - This class implements a leaky model, in which its dynamics is - given by: + This class implements a leaky model, in which its dynamics is + given by: - .. math:: + .. math:: - x(t + \Delta t) = \exp{-\Delta t/\tau} x(t) + I + x(t + \Delta t) = \exp{-\Delta t/\tau} x(t) + I - Args: - tau: float, ArrayType, Initializer, callable. Membrane time constant. - method: str. The numerical integration method. Default "exp_auto". - init_var: Initialize the variable or not. - %s - """ + Args: + tau: float, ArrayType, Initializer, callable. Membrane time constant. + method: str. The numerical integration method. Default "exp_auto". + init_var: Initialize the variable or not. 
+ %s + """ - supported_modes = (bm.TrainingMode, bm.NonBatchingMode) + supported_modes = (bm.TrainingMode, bm.NonBatchingMode) - def __init__( - self, - size: Union[int, Sequence[int]], - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, - tau: Union[float, ArrayType, Callable] = 10., - method: str = 'exp_auto', - init_var: bool = True - ): - super().__init__(size, - mode=mode, - name=name, - sharding=sharding, - keep_size=keep_size) + tau: Union[float, ArrayType, Callable] = 10., + method: str = 'exp_auto', + init_var: bool = True + ): + super().__init__(size, + mode=mode, + name=name, + sharding=sharding, + keep_size=keep_size) - # parameters - self.sharding = sharding - self.tau = self.init_param(tau) + # parameters + self.sharding = sharding + self.tau = self.init_param(tau) - # integral - self.integral = odeint(method=method, f=self.derivative) + # integral + self.integral = odeint(method=method, f=self.derivative) - # variables - if init_var: - self.reset_state(self.mode) + # variables + if init_var: + self.reset_state(self.mode) - def derivative(self, x, t): - return -x / self.tau + def derivative(self, x, t): + return -x / self.tau - def reset_state(self, batch_size=None, **kwargs): - self.x = self.init_variable(bm.zeros, batch_size) + def reset_state(self, batch_size=None, **kwargs): + self.x = self.init_variable(bm.zeros, batch_size) - def update(self, inp=None): - t = share.load('t') - dt = share.load('dt') - self.x.value = self.integral(self.x.value, t, dt) - if inp is None: inp = 0. 
- inp = self.sum_current_inputs(self.x.value, init=inp) - self.x += inp - return self.x.value + def update(self, inp=None): + t = share.load('t') + dt = share.load('dt') + self.x.value = self.integral(self.x.value, t, dt) + if inp is None: inp = 0. + inp = self.sum_current_inputs(self.x.value, init=inp) + self.x += inp + return self.x.value - def return_info(self): - return self.x + def return_info(self): + return self.x Leaky.__doc__ = Leaky.__doc__ % pneu_doc class Integrator(NeuDyn): - r"""Integrator Model. - - This class implements an integrator model, in which its dynamics is - given by: - - .. math:: - - \tau \frac{dx}{dt} = - x(t) + I(t) - - where :math:`x` is the integrator value, and :math:`\tau` is the time constant. - - Args: - tau: float, ArrayType, Initializer, callable. Membrane time constant. - method: str. The numerical integration method. Default "exp_auto". - x_initializer: ArrayType, Initializer, callable. The initializer of :math:`x`. - %s - """ - - supported_modes = (bm.TrainingMode, bm.NonBatchingMode) - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - tau: Union[float, ArrayType, Callable] = 10., - x_initializer: Union[Callable, ArrayType] = init.ZeroInit(), - method: str = 'exp_auto', - init_var: bool = True, - ): - super().__init__(size, - mode=mode, - name=name, - keep_size=keep_size, - sharding=sharding) - - # parameters - self.size = tools.to_size(size) - self.sharding = sharding - self.tau = init.parameter(tau, self.size, sharding=self.sharding) - - # initializers - self._x_initializer = is_initializer(x_initializer) - - # integral - self.integral = odeint(method=method, f=self.derivative) - - # variables - if init_var: - self.reset_state(self.mode) - - def derivative(self, V, t, I_ext): - return (-V + I_ext) / self.tau - - def reset_state(self, batch_size=None, **kwargs): - self.x = 
self.init_variable(self._x_initializer, batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - self.x.value = self.integral(self.x.value, t, I_ext=x, dt=dt) - return self.x.value - - def return_info(self): - return self.x + r"""Integrator Model. + + This class implements an integrator model, in which its dynamics is + given by: + + .. math:: + + \tau \frac{dx}{dt} = - x(t) + I(t) + + where :math:`x` is the integrator value, and :math:`\tau` is the time constant. + + Args: + tau: float, ArrayType, Initializer, callable. Membrane time constant. + method: str. The numerical integration method. Default "exp_auto". + x_initializer: ArrayType, Initializer, callable. The initializer of :math:`x`. + %s + """ + + supported_modes = (bm.TrainingMode, bm.NonBatchingMode) + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + tau: Union[float, ArrayType, Callable] = 10., + x_initializer: Union[Callable, ArrayType] = init.ZeroInit(), + method: str = 'exp_auto', + init_var: bool = True, + ): + super().__init__(size, + mode=mode, + name=name, + keep_size=keep_size, + sharding=sharding) + + # parameters + self.size = tools.to_size(size) + self.sharding = sharding + self.tau = init.parameter(tau, self.size, sharding=self.sharding) + + # initializers + self._x_initializer = is_initializer(x_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I_ext): + return (-V + I_ext) / self.tau + + def reset_state(self, batch_size=None, **kwargs): + self.x = self.init_variable(self._x_initializer, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + self.x.value = self.integral(self.x.value, t, I_ext=x, dt=dt) + return self.x.value + + def return_info(self): + return self.x Integrator.__doc__ = Integrator.__doc__ % pneu_doc diff --git a/brainpy/_src/dyn/others/input.py b/brainpy/_src/dyn/others/input.py index 4dfcf35dc..b02438e10 100644 --- a/brainpy/_src/dyn/others/input.py +++ b/brainpy/_src/dyn/others/input.py @@ -8,225 +8,225 @@ from brainpy import math as bm from brainpy._src.context import share -from brainpy._src.dyn.utils import get_spk_type from brainpy._src.dyn.base import NeuDyn +from brainpy._src.dyn.utils import get_spk_type from brainpy._src.initialize import parameter, variable_ from brainpy._src.mixin import ReturnInfo from brainpy.types import Shape, ArrayType __all__ = [ - 'InputGroup', - 'OutputGroup', - 'SpikeTimeGroup', - 'PoissonGroup', + 'InputGroup', + 'OutputGroup', + 'SpikeTimeGroup', + 'PoissonGroup', ] class InputGroup(NeuDyn): - """Input neuron group for place holder. - - Args: - size: int, tuple of int - keep_size: bool - mode: Mode - name: str - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - sharding: Any = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, - sharding=sharding, - size=size, - keep_size=keep_size, - mode=mode) - - def update(self, x): - return x - - def return_info(self): - return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros) - - def reset_state(self, batch_or_mode=None, **kwargs): - pass + """Input neuron group for place holder. 
+ + Args: + size: int, tuple of int + keep_size: bool + mode: Mode + name: str + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, + sharding=sharding, + size=size, + keep_size=keep_size, + mode=mode) + + def update(self, x): + return x + + def return_info(self): + return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros) + + def reset_state(self, batch_or_mode=None, **kwargs): + pass class OutputGroup(NeuDyn): - """Output neuron group for place holder. - - Args: - size: int, tuple of int - keep_size: bool - mode: Mode - name: str - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - sharding: Any = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, - sharding=sharding, - size=size, - keep_size=keep_size, - mode=mode) - - def update(self, x): - return x - - def return_info(self): - return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros) - - def reset_state(self, batch_size=None, **kwargs): - pass + """Output neuron group for place holder. + + Args: + size: int, tuple of int + keep_size: bool + mode: Mode + name: str + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + sharding: Any = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, + sharding=sharding, + size=size, + keep_size=keep_size, + mode=mode) + + def update(self, x): + return x + + def return_info(self): + return ReturnInfo(self.varshape, self.sharding, self.mode, bm.zeros) + + def reset_state(self, batch_size=None, **kwargs): + pass class SpikeTimeGroup(NeuDyn): - """The input neuron group characterized by spikes emitting at given times. - - >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms. 
- >>> SpikeTimeGroup(2, times=[10, 20]) - >>> # or - >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms. - >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0]) - >>> # or - >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms. - >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0]) - >>> # or - >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire; - >>> # at 30 ms, neuron 1 fires. - >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) - - Parameters:: - - size : int, tuple, list - The neuron group geometry. - indices : list, tuple, ArrayType - The neuron indices at each time point to emit spikes. - times : list, tuple, ArrayType - The time points which generate the spikes. - name : str, optional - The name of the dynamic system. - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - indices: Union[Sequence, ArrayType], - times: Union[Sequence, ArrayType], - spk_type: Optional[type] = None, - name: Optional[str] = None, - sharding: Optional[Sequence[str]] = None, - keep_size: bool = False, - mode: Optional[bm.Mode] = None, - need_sort: bool = True, - ): - super().__init__(size=size, - sharding=sharding, - name=name, - keep_size=keep_size, - mode=mode) - - # parameters - if keep_size: - raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}') - if len(indices) != len(times): - raise ValueError(f'The length of "indices" and "times" must be the same. 
' - f'However, we got {len(indices)} != {len(times)}.') - self.num_times = len(times) - self.spk_type = get_spk_type(spk_type, self.mode) - - # data about times and indices - self.times = bm.asarray(times) - self.indices = bm.asarray(indices, dtype=bm.int_) - if need_sort: - sort_idx = bm.argsort(self.times) - self.indices.value = self.indices[sort_idx] - self.times.value = self.times[sort_idx] - - # variables - self.reset_state(self.mode) - - def reset_state(self, batch_size=None, **kwargs): - self.i = bm.Variable(bm.asarray(0)) - self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type), - self.varshape, - batch_size, - axis_names=self.sharding, - batch_axis_name=bm.sharding.BATCH_AXIS) - - def update(self): - # self.spike.value = bm.sharding.partition(bm.zeros_like(self.spike), self.spike.sharding) - self.spike.value = bm.zeros_like(self.spike) - bm.while_loop(self._body_fun, self._cond_fun, ()) - return self.spike.value - - def return_info(self): - return self.spike - - # functions - def _cond_fun(self): - i = self.i.value - return bm.logical_and(i < self.num_times, share['t'] >= self.times[i]) - - def _body_fun(self): - i = self.i.value - if isinstance(self.mode, bm.BatchingMode): - self.spike[:, self.indices[i]] = True - else: - self.spike[self.indices[i]] = True - self.i += 1 + """The input neuron group characterized by spikes emitting at given times. + + >>> # Get 2 neurons, firing spikes at 10 ms and 20 ms. + >>> SpikeTimeGroup(2, times=[10, 20]) + >>> # or + >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms. + >>> SpikeTimeGroup(2, times=[10, 20], indices=[0, 0]) + >>> # or + >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms. + >>> SpikeTimeGroup(2, times=[10, 20, 30], indices=[0, 1, 0]) + >>> # or + >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire; + >>> # at 30 ms, neuron 1 fires. 
+ >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) + + Parameters:: + + size : int, tuple, list + The neuron group geometry. + indices : list, tuple, ArrayType + The neuron indices at each time point to emit spikes. + times : list, tuple, ArrayType + The time points which generate the spikes. + name : str, optional + The name of the dynamic system. + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + indices: Union[Sequence, ArrayType], + times: Union[Sequence, ArrayType], + spk_type: Optional[type] = None, + name: Optional[str] = None, + sharding: Optional[Sequence[str]] = None, + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + need_sort: bool = True, + ): + super().__init__(size=size, + sharding=sharding, + name=name, + keep_size=keep_size, + mode=mode) + + # parameters + if keep_size: + raise NotImplementedError(f'Do not support keep_size=True in {self.__class__.__name__}') + if len(indices) != len(times): + raise ValueError(f'The length of "indices" and "times" must be the same. 
' + f'However, we got {len(indices)} != {len(times)}.') + self.num_times = len(times) + self.spk_type = get_spk_type(spk_type, self.mode) + + # data about times and indices + self.times = bm.asarray(times) + self.indices = bm.asarray(indices, dtype=bm.int_) + if need_sort: + sort_idx = bm.argsort(self.times) + self.indices.value = self.indices[sort_idx] + self.times.value = self.times[sort_idx] + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None, **kwargs): + self.i = bm.Variable(bm.asarray(0)) + self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type), + self.varshape, + batch_size, + axis_names=self.sharding, + batch_axis_name=bm.sharding.BATCH_AXIS) + + def update(self): + # self.spike.value = bm.sharding.partition(bm.zeros_like(self.spike), self.spike.sharding) + self.spike.value = bm.zeros_like(self.spike) + bm.while_loop(self._body_fun, self._cond_fun, ()) + return self.spike.value + + def return_info(self): + return self.spike + + # functions + def _cond_fun(self): + i = self.i.value + return bm.logical_and(i < self.num_times, share['t'] >= self.times[i]) + + def _body_fun(self): + i = self.i.value + if isinstance(self.mode, bm.BatchingMode): + self.spike[:, self.indices[i]] = True + else: + self.spike[self.indices[i]] = True + self.i += 1 class PoissonGroup(NeuDyn): - """Poisson Neuron Group. 
- """ - - def __init__( - self, - size: Shape, - freqs: Union[int, float, jax.Array, bm.Array, Callable], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - spk_type: Optional[type] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - seed=None, - ): - super().__init__(size=size, - sharding=sharding, - name=name, - keep_size=keep_size, - mode=mode) - - if seed is not None: - warnings.warn('') - - # parameters - self.freqs = parameter(freqs, self.num, allow_none=False) - self.spk_type = get_spk_type(spk_type, self.mode) - - # variables - self.reset_state(self.mode) - - def update(self): - spikes = bm.random.rand_like(self.spike) <= (self.freqs * share['dt'] / 1000.) - spikes = bm.asarray(spikes, dtype=self.spk_type) - # spikes = bm.sharding.partition(spikes, self.spike.sharding) - self.spike.value = spikes - return spikes - - def return_info(self): - return self.spike - - def reset_state(self, batch_or_mode=None, **kwargs): - self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode) + """Poisson Neuron Group. + """ + + def __init__( + self, + size: Shape, + freqs: Union[int, float, jax.Array, bm.Array, Callable], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + spk_type: Optional[type] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + seed=None, + ): + super().__init__(size=size, + sharding=sharding, + name=name, + keep_size=keep_size, + mode=mode) + + if seed is not None: + warnings.warn('') + + # parameters + self.freqs = parameter(freqs, self.num, allow_none=False) + self.spk_type = get_spk_type(spk_type, self.mode) + + # variables + self.reset_state(self.mode) + + def update(self): + spikes = bm.random.rand_like(self.spike) <= (self.freqs * share['dt'] / 1000.) 
+ spikes = bm.asarray(spikes, dtype=self.spk_type) + # spikes = bm.sharding.partition(spikes, self.spike.sharding) + self.spike.value = spikes + return spikes + + def return_info(self): + return self.spike + + def reset_state(self, batch_or_mode=None, **kwargs): + self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode) diff --git a/brainpy/_src/dyn/others/noise.py b/brainpy/_src/dyn/others/noise.py index 5af246c10..2ebc180dd 100644 --- a/brainpy/_src/dyn/others/noise.py +++ b/brainpy/_src/dyn/others/noise.py @@ -10,74 +10,74 @@ from brainpy.types import Shape, ArrayType __all__ = [ - 'OUProcess', + 'OUProcess', ] class OUProcess(NeuDyn): - r"""The Ornstein–Uhlenbeck process. - - The Ornstein–Uhlenbeck process :math:`x_{t}` is defined by the following - stochastic differential equation: - - .. math:: - - \tau dx_{t}=-\theta \,x_{t}\,dt+\sigma \,dW_{t} - - where :math:`\theta >0` and :math:`\sigma >0` are parameters and :math:`W_{t}` - denotes the Wiener process. - - Parameters:: - - size: int, sequence of int - The model size. - mean: Parameter - The noise mean value. - sigma: Parameter - The noise amplitude. - tau: Parameter - The decay time constant. - method: str - The numerical integration method for stochastic differential equation. - name: str - The model name. 
- """ - - def __init__( - self, - size: Shape, - mean: Union[float, ArrayType, Callable] = 0., - sigma: Union[float, ArrayType, Callable] = 1., - tau: Union[float, ArrayType, Callable] = 10., - method: str = 'exp_euler', - keep_size: bool = False, - mode: bm.Mode = None, - name: str = None, - ): - super(OUProcess, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode) - - # parameters - self.mean = parameter(mean, self.varshape, allow_none=False) - self.sigma = parameter(sigma, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - - # variables - self.reset_state(self.mode) - - # integral functions - self.integral = sdeint(f=self.df, g=self.dg, method=method) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.x = variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, batch_or_mode) - - def df(self, x, t): - return (self.mean - x) / self.tau - - def dg(self, x, t): - return self.sigma - - def update(self): - t = share.load('t') - dt = share.load('dt') - self.x.value = self.integral(self.x.value, t, dt) - return self.x.value + r"""The Ornstein–Uhlenbeck process. + + The Ornstein–Uhlenbeck process :math:`x_{t}` is defined by the following + stochastic differential equation: + + .. math:: + + \tau dx_{t}=-\theta \,x_{t}\,dt+\sigma \,dW_{t} + + where :math:`\theta >0` and :math:`\sigma >0` are parameters and :math:`W_{t}` + denotes the Wiener process. + + Parameters:: + + size: int, sequence of int + The model size. + mean: Parameter + The noise mean value. + sigma: Parameter + The noise amplitude. + tau: Parameter + The decay time constant. + method: str + The numerical integration method for stochastic differential equation. + name: str + The model name. 
+ """ + + def __init__( + self, + size: Shape, + mean: Union[float, ArrayType, Callable] = 0., + sigma: Union[float, ArrayType, Callable] = 1., + tau: Union[float, ArrayType, Callable] = 10., + method: str = 'exp_euler', + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + ): + super(OUProcess, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode) + + # parameters + self.mean = parameter(mean, self.varshape, allow_none=False) + self.sigma = parameter(sigma, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + + # variables + self.reset_state(self.mode) + + # integral functions + self.integral = sdeint(f=self.df, g=self.dg, method=method) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.x = variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, batch_or_mode) + + def df(self, x, t): + return (self.mean - x) / self.tau + + def dg(self, x, t): + return self.sigma + + def update(self): + t = share.load('t') + dt = share.load('dt') + self.x.value = self.integral(self.x.value, t, dt) + return self.x.value diff --git a/brainpy/_src/dyn/others/tests/test_input.py b/brainpy/_src/dyn/others/tests/test_input.py index c1630c38d..bc0d45907 100644 --- a/brainpy/_src/dyn/others/tests/test_input.py +++ b/brainpy/_src/dyn/others/tests/test_input.py @@ -1,24 +1,25 @@ # -*- coding: utf-8 -*- -import brainpy as bp from absl.testing import parameterized + +import brainpy as bp from brainpy._src.dyn.others import input class Test_input(parameterized.TestCase): - def test_SpikeTimeGroup(self): - model = input.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) - runner = bp.DSRunner(model, - monitors=['spike'], - progress_bar=False) - runner.run(30.) 
- self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) + def test_SpikeTimeGroup(self): + model = input.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) + runner = bp.DSRunner(model, + monitors=['spike'], + progress_bar=False) + runner.run(30.) + self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) - def test_PoissonGroup(self): - model = input.PoissonGroup(size=2, freqs=1000, seed=0) - runner = bp.DSRunner(model, - monitors=['spike'], - progress_bar=False) - runner.run(30.) - self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) + def test_PoissonGroup(self): + model = input.PoissonGroup(size=2, freqs=1000, seed=0) + runner = bp.DSRunner(model, + monitors=['spike'], + progress_bar=False) + runner.run(30.) + self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) diff --git a/brainpy/_src/dyn/others/tests/test_input_groups.py b/brainpy/_src/dyn/others/tests/test_input_groups.py index 0164e2f12..364610e54 100644 --- a/brainpy/_src/dyn/others/tests/test_input_groups.py +++ b/brainpy/_src/dyn/others/tests/test_input_groups.py @@ -1,26 +1,27 @@ # -*- coding: utf-8 -*- -import brainpy as bp from absl.testing import parameterized + +import brainpy as bp from brainpy._src.dyn.others import input class Test_input_Group(parameterized.TestCase): - def test_SpikeTimeGroup(self): - bp.math.random.seed() - model = input.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) - runner = bp.DSRunner(model, - monitors=['spike'], - progress_bar=False) - runner.run(30.) - self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) + def test_SpikeTimeGroup(self): + bp.math.random.seed() + model = input.SpikeTimeGroup(size=2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) + runner = bp.DSRunner(model, + monitors=['spike'], + progress_bar=False) + runner.run(30.) 
+ self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) - def test_PoissonGroup(self): - bp.math.random.seed() - model = input.PoissonGroup(size=2, freqs=1000) - runner = bp.DSRunner(model, - monitors=['spike'], - progress_bar=False) - runner.run(30.) - self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) + def test_PoissonGroup(self): + bp.math.random.seed() + model = input.PoissonGroup(size=2, freqs=1000) + runner = bp.DSRunner(model, + monitors=['spike'], + progress_bar=False) + runner.run(30.) + self.assertTupleEqual(runner.mon['spike'].shape, (300, 2)) diff --git a/brainpy/_src/dyn/others/tests/test_noise_groups.py b/brainpy/_src/dyn/others/tests/test_noise_groups.py index 9050d383f..0aaecc3ea 100644 --- a/brainpy/_src/dyn/others/tests/test_noise_groups.py +++ b/brainpy/_src/dyn/others/tests/test_noise_groups.py @@ -1,24 +1,24 @@ # -*- coding: utf-8 -*- +import pytest +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized -import pytest pytest.skip("Skip the test due to the jax 0.5.0 version", allow_module_level=True) class Test_Noise_Group(parameterized.TestCase): - def test_OU(self): - bm.random.seed(1234) - model = bp.dyn.OUProcess(size=1, mean=0., sigma=0.1) - runner = bp.DSRunner(model, - monitors=['x'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['x'].shape, (100, 1)) - x = runner.mon['x'] - self.assertLessEqual(abs(x.mean()), 0.1) - self.assertLessEqual(abs(x.std() - 0.1), 0.1) - + def test_OU(self): + bm.random.seed(1234) + model = bp.dyn.OUProcess(size=1, mean=0., sigma=0.1) + runner = bp.DSRunner(model, + monitors=['x'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['x'].shape, (100, 1)) + x = runner.mon['x'] + self.assertLessEqual(abs(x.mean()), 0.1) + self.assertLessEqual(abs(x.std() - 0.1), 0.1) diff --git a/brainpy/_src/dyn/outs/base.py b/brainpy/_src/dyn/outs/base.py index 35936748f..2b2a46e05 100644 --- a/brainpy/_src/dyn/outs/base.py +++ b/brainpy/_src/dyn/outs/base.py @@ -5,52 +5,53 @@ from brainpy._src.mixin import ParamDesc, BindCondData __all__ = [ - 'SynOut' + 'SynOut' ] class SynOut(DynamicalSystem, ParamDesc, BindCondData): - """Base class for synaptic outputs. - - :py:class:`~.SynOut` is also subclass of :py:class:`~.ParamDesc` and :pu:class:`~.BindCondData`. - """ - def __init__(self, - name: Optional[str] = None, - scaling: Optional[bm.Scaling] = None): - super().__init__(name=name) - self._conductance = None - if scaling is None: - self.scaling = bm.get_membrane_scaling() - else: - self.scaling = scaling - - def __call__(self, *args, **kwargs): - if self._conductance is None: - raise ValueError(f'Please first pack conductance data at the current step using ' - f'".{BindCondData.bind_cond.__name__}(data)". {self}') - ret = self.update(self._conductance, *args, **kwargs) - return ret - - def reset_state(self, *args, **kwargs): - pass - - def offset_scaling(self, x, bias=None, scale=None): - s = self.scaling.offset_scaling(x, bias=bias, scale=scale) - if isinstance(x, bm.BaseArray): - x.value = s - return x - return s - - def std_scaling(self, x, scale=None): - s = self.scaling.std_scaling(x, scale=scale) - if isinstance(x, bm.BaseArray): - x.value = s - return x - return s - - def inv_scaling(self, x, scale=None): - s = self.scaling.inv_scaling(x, scale=scale) - if isinstance(x, bm.BaseArray): - x.value = s - return x - return s + """Base class for synaptic outputs. + + :py:class:`~.SynOut` is also subclass of :py:class:`~.ParamDesc` and :pu:class:`~.BindCondData`. 
+ """ + + def __init__(self, + name: Optional[str] = None, + scaling: Optional[bm.Scaling] = None): + super().__init__(name=name) + self._conductance = None + if scaling is None: + self.scaling = bm.get_membrane_scaling() + else: + self.scaling = scaling + + def __call__(self, *args, **kwargs): + if self._conductance is None: + raise ValueError(f'Please first pack conductance data at the current step using ' + f'".{BindCondData.bind_cond.__name__}(data)". {self}') + ret = self.update(self._conductance, *args, **kwargs) + return ret + + def reset_state(self, *args, **kwargs): + pass + + def offset_scaling(self, x, bias=None, scale=None): + s = self.scaling.offset_scaling(x, bias=bias, scale=scale) + if isinstance(x, bm.BaseArray): + x.value = s + return x + return s + + def std_scaling(self, x, scale=None): + s = self.scaling.std_scaling(x, scale=scale) + if isinstance(x, bm.BaseArray): + x.value = s + return x + return s + + def inv_scaling(self, x, scale=None): + s = self.scaling.inv_scaling(x, scale=scale) + if isinstance(x, bm.BaseArray): + x.value = s + return x + return s diff --git a/brainpy/_src/dyn/outs/outputs.py b/brainpy/_src/dyn/outs/outputs.py index aa81f0828..aace92463 100644 --- a/brainpy/_src/dyn/outs/outputs.py +++ b/brainpy/_src/dyn/outs/outputs.py @@ -7,136 +7,138 @@ from .base import SynOut __all__ = [ - 'COBA', - 'CUBA', - 'MgBlock' + 'COBA', + 'CUBA', + 'MgBlock' ] class COBA(SynOut): - r"""Conductance-based synaptic output. + r"""Conductance-based synaptic output. - Given the synaptic conductance, the model output the post-synaptic current with + Given the synaptic conductance, the model output the post-synaptic current with - .. math:: + .. math:: - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) - Parameters:: + Parameters:: - E: float, ArrayType, ndarray - The reversal potential. - sharding: sequence of str - The axis names for variable for parallelization. - name: str - The model name. 
- scaling: brainpy.Scaling - The scaling object. + E: float, ArrayType, ndarray + The reversal potential. + sharding: sequence of str + The axis names for variable for parallelization. + name: str + The model name. + scaling: brainpy.Scaling + The scaling object. - See Also:: + See Also:: - CUBA - """ + CUBA + """ - def __init__( - self, - E: Union[float, ArrayType], - sharding: Optional[Sequence[str]] = None, - name: Optional[str] = None, - scaling: Optional[bm.Scaling] = None, - ): - super().__init__(name=name, scaling=scaling) + def __init__( + self, + E: Union[float, ArrayType], + sharding: Optional[Sequence[str]] = None, + name: Optional[str] = None, + scaling: Optional[bm.Scaling] = None, + ): + super().__init__(name=name, scaling=scaling) - self.sharding = sharding - self.E = self.offset_scaling(init.parameter(E, np.shape(E), sharding=sharding)) + self.sharding = sharding + self.E = self.offset_scaling(init.parameter(E, np.shape(E), sharding=sharding)) - def update(self, conductance, potential): - return conductance * (self.E - potential) + def update(self, conductance, potential): + return conductance * (self.E - potential) class CUBA(SynOut): - r"""Current-based synaptic output. + r"""Current-based synaptic output. - Given the conductance, this model outputs the post-synaptic current with a identity function: + Given the conductance, this model outputs the post-synaptic current with a identity function: - .. math:: + .. math:: - I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) + I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) - Parameters:: + Parameters:: - name: str - The model name. - scaling: brainpy.Scaling - The scaling object. + name: str + The model name. + scaling: brainpy.Scaling + The scaling object. 
- See Also:: + See Also:: - COBA - """ - def __init__( - self, - name: Optional[str] = None, - scaling: Optional[bm.Scaling] = None, - ): - super().__init__(name=name, scaling=scaling) + COBA + """ - def update(self, conductance, potential=None): - return conductance + def __init__( + self, + name: Optional[str] = None, + scaling: Optional[bm.Scaling] = None, + ): + super().__init__(name=name, scaling=scaling) + + def update(self, conductance, potential=None): + return conductance class MgBlock(SynOut): - r"""Synaptic output based on Magnesium blocking. - - Given the synaptic conductance, the model output the post-synaptic current with - - .. math:: - - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) - - where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to - - .. math:: - - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. - - Parameters:: - - E: float, ArrayType - The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType - Binding constant. Default 0.062 - beta: float, ArrayType - Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType - Concentration of Magnesium ion. Default 1.2 [mM]. - sharding: sequence of str - The axis names for variable for parallelization. - name: str - The model name. 
- """ - def __init__( - self, - E: Union[float, ArrayType] = 0., - cc_Mg: Union[float, ArrayType] = 1.2, - alpha: Union[float, ArrayType] = 0.062, - beta: Union[float, ArrayType] = 3.57, - V_offset: Union[float, ArrayType] = 0., - sharding: Optional[Sequence[str]] = None, - name: Optional[str] = None, - scaling: Optional[bm.Scaling] = None, - ): - super().__init__(name=name, scaling=scaling) - - self.sharding = sharding - self.E = self.offset_scaling(init.parameter(E, np.shape(E), sharding=sharding)) - self.V_offset = self.offset_scaling(init.parameter(V_offset, np.shape(V_offset), sharding=sharding)) - self.cc_Mg = init.parameter(cc_Mg, np.shape(cc_Mg), sharding=sharding) - self.alpha = self.inv_scaling(init.parameter(alpha, np.shape(alpha), sharding=sharding)) - self.beta = init.parameter(beta, np.shape(beta), sharding=sharding) - - def update(self, conductance, potential): - norm = (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential))) - return conductance * (self.E - potential) / norm + r"""Synaptic output based on Magnesium blocking. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) + + where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. + + Parameters:: + + E: float, ArrayType + The reversal potential for the synaptic current. [mV] + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType + Unbinding constant. Default 3.57 + cc_Mg: float, ArrayType + Concentration of Magnesium ion. Default 1.2 [mM]. + sharding: sequence of str + The axis names for variable for parallelization. + name: str + The model name. 
+ """ + + def __init__( + self, + E: Union[float, ArrayType] = 0., + cc_Mg: Union[float, ArrayType] = 1.2, + alpha: Union[float, ArrayType] = 0.062, + beta: Union[float, ArrayType] = 3.57, + V_offset: Union[float, ArrayType] = 0., + sharding: Optional[Sequence[str]] = None, + name: Optional[str] = None, + scaling: Optional[bm.Scaling] = None, + ): + super().__init__(name=name, scaling=scaling) + + self.sharding = sharding + self.E = self.offset_scaling(init.parameter(E, np.shape(E), sharding=sharding)) + self.V_offset = self.offset_scaling(init.parameter(V_offset, np.shape(V_offset), sharding=sharding)) + self.cc_Mg = init.parameter(cc_Mg, np.shape(cc_Mg), sharding=sharding) + self.alpha = self.inv_scaling(init.parameter(alpha, np.shape(alpha), sharding=sharding)) + self.beta = init.parameter(beta, np.shape(beta), sharding=sharding) + + def update(self, conductance, potential): + norm = (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential))) + return conductance * (self.E - potential) / norm diff --git a/brainpy/_src/dyn/projections/align_post.py b/brainpy/_src/dyn/projections/align_post.py index 9bd280f81..4ded7e027 100644 --- a/brainpy/_src/dyn/projections/align_post.py +++ b/brainpy/_src/dyn/projections/align_post.py @@ -7,310 +7,84 @@ from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost) __all__ = [ - 'HalfProjAlignPostMg', 'FullProjAlignPostMg', - 'HalfProjAlignPost', 'FullProjAlignPost', + 'HalfProjAlignPostMg', 'FullProjAlignPostMg', + 'HalfProjAlignPost', 'FullProjAlignPost', ] def get_post_repr(out_label, syn, out): - return f'{out_label} // {syn.identifier} // {out.identifier}' + return f'{out_label} // {syn.identifier} // {out.identifier}' def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name): - # synapse and output initialization - _post_repr = get_post_repr(out_label, syn_desc, out_desc) - if not post.has_bef_update(_post_repr): - syn_cls = syn_desc() - 
out_cls = out_desc() - # synapse and output initialization - post.add_inp_fun(proj_name, out_cls, label=out_label) - post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls)) - syn = post.get_bef_update(_post_repr).syn - out = post.get_bef_update(_post_repr).out - return syn, out - - -class _AlignPost(DynamicalSystem): - def __init__(self, - syn: Callable, - out: JointType[DynamicalSystem, BindCondData]): - super().__init__() - self.syn = syn - self.out = out - - def update(self, *args, **kwargs): - self.out.bind_cond(self.syn(*args, **kwargs)) - - def reset_state(self, *args, **kwargs): - pass - - -class HalfProjAlignPostMg(Projection): - r"""Defining the half part of synaptic projection with the align-post reduction and the automatic synapse merging. - - The ``half-part`` means that the model only needs to provide half information needed for a projection, - including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs - the manual providing of the spiking input. - - The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. - - The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same - parameters (such like time constants) will also share the same synaptic variables. + _post_repr = get_post_repr(out_label, syn_desc, out_desc) + if not post.has_bef_update(_post_repr): + syn_cls = syn_desc() + out_cls = out_desc() - All align-post projection models prefer to use the event-driven computation mode. This means that the - ``comm`` model should be the event-driven model. + # synapse and output initialization + post.add_inp_fun(proj_name, out_cls, label=out_label) + post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls)) + syn = post.get_bef_update(_post_repr).syn + out = post.get_bef_update(_post_repr).out + return syn, out - **Code Examples** - To define an E/I balanced network model. - - .. 
code-block:: python - - import brainpy as bp - import brainpy.math as bm - - class EINet(bp.DynSysGroup): - def __init__(self): +class _AlignPost(DynamicalSystem): + def __init__(self, + syn: Callable, + out: JointType[DynamicalSystem, BindCondData]): super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=4000, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.N) - self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=4000, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:3200]) - self.I(spk[3200:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - out_label: str. The prefix of the output function. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - comm: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], - out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and output initialization - syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) - - # references - self.refs = dict(post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = syn - self.refs['out'] = out - self.refs['comm'] = comm # unify the access + self.syn = syn + self.out = out - def update(self, x): - current = self.comm(x) - self.refs['syn'].add_current(current) # synapse post current - return current + def update(self, *args, **kwargs): + self.out.bind_cond(self.syn(*args, **kwargs)) - syn = property(lambda self: self.refs['syn']) - out = property(lambda self: self.refs['out']) - post = property(lambda self: self.refs['post']) + def reset_state(self, *args, **kwargs): + pass -class FullProjAlignPostMg(Projection): - """Full-chain synaptic projection with the align-post reduction and the automatic synapse merging. - - The ``full-chain`` means that the model needs to provide all information needed for a projection, - including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. +class HalfProjAlignPostMg(Projection): + r"""Defining the half part of synaptic projection with the align-post reduction and the automatic synapse merging. 
- The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + The ``half-part`` means that the model only needs to provide half information needed for a projection, + including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs + the manual providing of the spiking input. - The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same - parameters (such like time constants) will also share the same synaptic variables. + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. - All align-post projection models prefer to use the event-driven computation mode. This means that the - ``comm`` model should be the event-driven model. + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. - Moreover, it's worth noting that ``FullProjAlignPostMg`` has a different updating order with all align-pre - projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. - While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. - **Code Examples** + **Code Examples** - To define an E/I balanced network model. + To define an E/I balanced network model. - .. code-block:: python + .. 
code-block:: python import brainpy as bp import brainpy.math as bm - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPostMg(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPostMg(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=ni, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPostMg(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=ne, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPostMg(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], - out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # synapse and output initialization - syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) - - # references - self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = syn # invisible to ``self.node()`` - self.refs['out'] = out # invisible to ``self.node()`` - # unify the access - self.refs['comm'] = comm - self.refs['delay'] = pre.get_aft_update(delay_identifier) - - def update(self): - x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) - current = self.comm(x) - self.refs['syn'].add_current(current) # synapse post current - return current - - syn = property(lambda self: self.refs['syn']) - out = property(lambda self: self.refs['out']) - delay = property(lambda self: self.refs['delay']) - pre = property(lambda self: self.refs['pre']) - post = property(lambda self: self.refs['post']) - - -class HalfProjAlignPost(Projection): - """Defining the half-part of synaptic projection with the align-post reduction. 
- - The ``half-part`` means that the model only needs to provide half information needed for a projection, - including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs - the manual providing of the spiking input. - - The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. - - All align-post projection models prefer to use the event-driven computation mode. This means that the - ``comm`` model should be the event-driven model. - - To simulate an E/I balanced network: - - .. code-block:: - class EINet(bp.DynSysGroup): def __init__(self): super().__init__() self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=4000, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=4000, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) + self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=4000, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=4000, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.N) def update(self, input): spk = self.delay.at('I') @@ -324,184 +98,410 @@ def update(self, input): spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) + Args: + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + out_label: str. 
The prefix of the output function. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + comm: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and output initialization + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + + # references + self.refs = dict(post=post) # invisible to ``self.nodes()`` + self.refs['syn'] = syn + self.refs['out'] = out + self.refs['comm'] = comm # unify the access + + def update(self, x): + current = self.comm(x) + self.refs['syn'].add_current(current) # synapse post current + return current + + syn = property(lambda self: self.refs['syn']) + out = property(lambda self: self.refs['out']) + post = property(lambda self: self.refs['post']) - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - comm: DynamicalSystem, - syn: JointType[DynamicalSystem, AlignPost], - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - self.out = out - # synapse and output initialization - post.add_inp_fun(self.name, out, label=out_label) +class FullProjAlignPostMg(Projection): + """Full-chain synaptic projection with the align-post reduction and the automatic synapse merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. + + Moreover, it's worth noting that ``FullProjAlignPostMg`` has a different updating order with all align-pre + projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. + While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. + + **Code Examples** + + To define an E/I balanced network model. + + .. 
code-block:: python + + import brainpy as bp + import brainpy.math as bm + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPostMg(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPostMg(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=ni, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPostMg(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=ne, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPostMg(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # synapse and output initialization + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + + # references + self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` + self.refs['syn'] = syn # invisible to ``self.node()`` + self.refs['out'] = out # invisible to ``self.node()`` + # unify the access + self.refs['comm'] = comm + self.refs['delay'] = pre.get_aft_update(delay_identifier) + + def update(self): + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) + current = self.comm(x) + self.refs['syn'].add_current(current) # synapse post current + return current + + syn = property(lambda self: self.refs['syn']) + out = property(lambda self: self.refs['out']) + delay = property(lambda self: self.refs['delay']) + pre = property(lambda self: self.refs['pre']) + post = property(lambda self: self.refs['post']) - # reference - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['post'] = post - self.refs['syn'] = syn - self.refs['out'] = out - # unify the access - 
self.refs['comm'] = comm - def update(self, x): - current = self.comm(x) - g = self.syn(self.comm(x)) - self.refs['out'].bind_cond(g) # synapse post current - return current +class HalfProjAlignPost(Projection): + """Defining the half-part of synaptic projection with the align-post reduction. - post = property(lambda self: self.refs['post']) + The ``half-part`` means that the model only needs to provide half information needed for a projection, + including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs + the manual providing of the spiking input. + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. -class FullProjAlignPost(Projection): - """Full-chain synaptic projection with the align-post reduction. + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. - The ``full-chain`` means that the model needs to provide all information needed for a projection, - including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. + To simulate an E/I balanced network: - The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + .. code-block:: - All align-post projection models prefer to use the event-driven computation mode. This means that the - ``comm`` model should be the event-driven model. 
+ class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=4000, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=4000, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + comm: DynamicalSystem, + syn: JointType[DynamicalSystem, AlignPost], + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + self.out = out + + # synapse and output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # reference + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['post'] = post + self.refs['syn'] = syn + self.refs['out'] = out + # unify the access + self.refs['comm'] = comm + + def update(self, x): + current = self.comm(x) + g = self.syn(self.comm(x)) + self.refs['out'].bind_cond(g) # synapse post current + return current + + post = property(lambda self: self.refs['post']) - Moreover, it's worth noting that ``FullProjAlignPost`` has a different updating order with all align-pre - projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. - While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. - To simulate and define an E/I balanced network model: +class FullProjAlignPost(Projection): + """Full-chain synaptic projection with the align-post reduction. - .. code-block:: python + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. 
- class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) + Moreover, it's worth noting that ``FullProjAlignPost`` has a different updating order with all align-pre + projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. 
+ While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. + To simulate and define an E/I balanced network model: - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - syn: JointType[DynamicalSystem, AlignPost], - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) + .. 
code-block:: python - # synapse and output initialization - post.add_inp_fun(self.name, out, label=out_label) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - # unify the access - self.refs['delay'] = delay_cls - self.refs['comm'] = comm - self.refs['syn'] = syn - - def update(self): - x = self.refs['delay'].at(self.name) - g = self.syn(self.comm(x)) - self.refs['out'].bind_cond(g) # synapse post current - return g - - delay = property(lambda self: self.refs['delay']) - pre = property(lambda self: self.refs['pre']) - post = property(lambda self: self.refs['post']) - out = property(lambda self: self.refs['out']) + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + 
self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + syn: JointType[DynamicalSystem, AlignPost], + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # synapse and output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + # unify the access + self.refs['delay'] = delay_cls + self.refs['comm'] = comm + self.refs['syn'] = syn + + def update(self): + x = self.refs['delay'].at(self.name) + g = self.syn(self.comm(x)) + self.refs['out'].bind_cond(g) # synapse post current + return g + + delay = property(lambda self: self.refs['delay']) + pre = property(lambda self: self.refs['pre']) + post = property(lambda 
self: self.refs['post']) + out = property(lambda self: self.refs['out']) diff --git a/brainpy/_src/dyn/projections/align_pre.py b/brainpy/_src/dyn/projections/align_pre.py index 6e5cd223a..915e3ac7b 100644 --- a/brainpy/_src/dyn/projections/align_pre.py +++ b/brainpy/_src/dyn/projections/align_pre.py @@ -7,600 +7,599 @@ from .utils import _get_return __all__ = [ - 'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg', - 'FullProjAlignPreSD', 'FullProjAlignPreDS', + 'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg', + 'FullProjAlignPreSD', 'FullProjAlignPreDS', ] def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None): - _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}' - if not delay_cls.has_bef_update(_syn_id): - # delay - delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name) - # synapse - syn_cls = syn_desc() - # add to "after_updates" - delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) - syn = delay_cls.get_bef_update(_syn_id).syn - return syn + _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}' + if not delay_cls.has_bef_update(_syn_id): + # delay + delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name) + # synapse + syn_cls = syn_desc() + # add to "after_updates" + delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) + syn = delay_cls.get_bef_update(_syn_id).syn + return syn class _AlignPreMg(DynamicalSystem): - def __init__(self, access, syn): - super().__init__() - self.access = access - self.syn = syn + def __init__(self, access, syn): + super().__init__() + self.access = access + self.syn = syn - def update(self, *args, **kwargs): - return self.syn(self.access()) + def update(self, *args, **kwargs): + return self.syn(self.access()) - def reset_state(self, *args, **kwargs): - pass + def reset_state(self, *args, **kwargs): + pass def align_pre1_add_bef_update(syn_desc, pre): - _syn_id = f'{syn_desc.identifier} // Delay' - if not pre.has_aft_update(_syn_id): - # "syn_cls" 
needs an instance of "ProjAutoDelay" - syn_cls: SupportAutoDelay = syn_desc() - delay_cls = init_delay_by_return(syn_cls.return_info()) - # add to "after_updates" - pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls)) - delay_cls: Delay = pre.get_aft_update(_syn_id).delay - syn = pre.get_aft_update(_syn_id).syn - return delay_cls, syn + _syn_id = f'{syn_desc.identifier} // Delay' + if not pre.has_aft_update(_syn_id): + # "syn_cls" needs an instance of "ProjAutoDelay" + syn_cls: SupportAutoDelay = syn_desc() + delay_cls = init_delay_by_return(syn_cls.return_info()) + # add to "after_updates" + pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls)) + delay_cls: Delay = pre.get_aft_update(_syn_id).delay + syn = pre.get_aft_update(_syn_id).syn + return delay_cls, syn class _AlignPre(DynamicalSystem): - def __init__(self, syn, delay=None): - super().__init__() - self.syn = syn - self.delay = delay + def __init__(self, syn, delay=None): + super().__init__() + self.syn = syn + self.delay = delay - def update(self, x): - if self.delay is None: - return x >> self.syn - else: - return x >> self.syn >> self.delay + def update(self, x): + if self.delay is None: + return x >> self.syn + else: + return x >> self.syn >> self.delay - def reset_state(self, *args, **kwargs): - pass + def reset_state(self, *args, **kwargs): + pass class FullProjAlignPreSDMg(Projection): - """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging. + """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. 
+ + The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the + synapse states to the delay model, and finally computes the synaptic current. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + + To simulate an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def 
update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + syn: The synaptic dynamics. + delay: The synaptic delay. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], + delay: Union[None, int, float], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and delay initialization + delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + self.refs['syn'] = syn_cls + # unify the access + self.refs['comm'] = comm + + def update(self, x=None): + if x is None: + x = self.refs['delay'].at(self.name) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + pre = property(lambda self: self.refs['pre']) + post = 
property(lambda self: self.refs['post']) + syn = property(lambda self: self.refs['syn']) + delay = property(lambda self: self.refs['delay']) + out = property(lambda self: self.refs['out']) - The ``full-chain`` means that the model needs to provide all information needed for a projection, - including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. - The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. +class FullProjAlignPreDSMg(Projection): + """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``. + Note here, compared to ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` are exchanged. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the + spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + + + To simulate an E/I balanced network model: + + .. 
code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + syn: The synaptic dynamics. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + syn: ParamDescriber[DynamicalSystem], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + + # synapse initialization + syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to `self.nodes()` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['syn'] = syn_cls + self.refs['out'] = out + # unify the access + self.refs['comm'] = comm + + def update(self): + x = _get_return(self.refs['syn'].return_info()) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + pre = property(lambda self: self.refs['pre']) + post = property(lambda self: self.refs['post']) + syn = property(lambda self: self.refs['syn']) + out = property(lambda self: self.refs['out']) + + +class FullProjAlignPreSD(Projection): + """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating. - The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the - synapse states to the delay model, and finally computes the synaptic current. 
+ The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. - The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same - parameters (such like time constants) will also share the same synaptic variables. + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. - Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates the event-driven computation. - This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather - than the spiking. To facilitate the event-driven computation, please use align post projections. + The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the + synapse states to the delay model, and finally computes the synaptic current. - To simulate an E/I balanced network model: + Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. - .. code-block:: python - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + To simulate an E/I balanced network model: + + .. 
code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) - self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) - self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) - self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), delay=0.1, comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.I) - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], - delay: Union[None, int, float], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and delay initialization - delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) - delay_cls.register_entry(self.name, delay) - - # output initialization - post.add_inp_fun(self.name, out, label=out_label) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - self.refs['syn'] = syn_cls - # unify the access - self.refs['comm'] = comm - - def update(self, x=None): - if x is None: - x = self.refs['delay'].at(self.name) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - pre = property(lambda self: self.refs['pre']) - post = property(lambda self: self.refs['post']) - syn = property(lambda self: self.refs['syn']) - delay = property(lambda self: self.refs['delay']) - out = property(lambda self: self.refs['out']) + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + syn: The synaptic dynamics. 
+ delay: The synaptic delay. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: DynamicalSystem, + syn: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and delay initialization + delay_cls = init_delay_by_return(syn.return_info()) + delay_cls.register_entry(self.name, delay) + pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + self.refs['syn'] = syn + # unify the access + self.refs['comm'] = comm + + def update(self, x=None): + if x is None: + x = self.refs['delay'].at(self.name) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + pre = property(lambda self: self.refs['pre']) + post = property(lambda self: self.refs['post']) + syn = property(lambda self: self.refs['syn']) + delay = property(lambda self: self.refs['delay']) + out = property(lambda self: self.refs['out']) -class FullProjAlignPreDSMg(Projection): - """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging. 
- - The ``full-chain`` means that the model needs to provide all information needed for a projection, - including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``. - Note here, compared to ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` are exchanged. +class FullProjAlignPreDS(Projection): + """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating. - The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. + Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged. - The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the - spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. - The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same - parameters (such like time constants) will also share the same synaptic variables. + The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the + spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. - Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates the event-driven computation. - This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather - than the spiking. To facilitate the event-driven computation, please use align post projections. + Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation. 
+ This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. - To simulate an E/I balanced network model: + To simulate an E/I balanced network model: - .. code-block:: python + .. code-block:: python - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) - self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E, delay=0.1, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) - self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.), comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) - self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I, delay=0.1, syn=bp.dyn.Expon.desc(size=ni, tau=10.), 
comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.I) - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - syn: ParamDescriber[DynamicalSystem], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, ParamDescriber[DynamicalSystem]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # delay initialization - delay_cls = register_delay_by_return(pre) - - # synapse initialization - syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) - - # output initialization - post.add_inp_fun(self.name, out, label=out_label) - - # references - self.refs = dict() - # invisible to `self.nodes()` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['syn'] = syn_cls - self.refs['out'] = out - # unify the access - self.refs['comm'] = comm - - def update(self): - x = _get_return(self.refs['syn'].return_info()) - current = self.comm(x) - 
self.refs['out'].bind_cond(current) - return current - - pre = property(lambda self: self.refs['pre']) - post = property(lambda self: self.refs['post']) - syn = property(lambda self: self.refs['syn']) - out = property(lambda self: self.refs['out']) - - -class FullProjAlignPreSD(Projection): - """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating. - - The ``full-chain`` means that the model needs to provide all information needed for a projection, - including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. - - The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. - - The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the - synapse states to the delay model, and finally computes the synaptic current. - - Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` facilitates the event-driven computation. - This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather - than the spiking. To facilitate the event-driven computation, please use align post projections. - - - To simulate an E/I balanced network model: - - .. 
code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - pre: DynamicalSystem, - syn: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and delay initialization - delay_cls = init_delay_by_return(syn.return_info()) - delay_cls.register_entry(self.name, delay) - pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) - - # output initialization - post.add_inp_fun(self.name, out, label=out_label) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - self.refs['syn'] = syn - # unify the access - self.refs['comm'] = comm - - def update(self, x=None): - if x is None: - x = self.refs['delay'].at(self.name) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - pre = property(lambda self: self.refs['pre']) - post = property(lambda self: self.refs['post']) - syn = property(lambda self: self.refs['syn']) - delay = property(lambda self: self.refs['delay']) - out = property(lambda self: self.refs['out']) - - -class FullProjAlignPreDS(Projection): - """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating. - - The ``full-chain`` means that the model needs to provide all information needed for a projection, - including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. 
- Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged. - - The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. - - The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the - spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. - - Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation. - This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather - than the spiking. To facilitate the event-driven computation, please use align post projections. - - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - 
post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - syn: DynamicalSystem, - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, DynamicalSystem) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # output initialization - post.add_inp_fun(self.name, out, label=out_label) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - # unify the access - self.refs['syn'] = syn - self.refs['comm'] = comm - - def update(self): - spk = self.refs['delay'].at(self.name) - g = self.comm(self.syn(spk)) - self.refs['out'].bind_cond(g) - return g - - pre = property(lambda self: self.refs['pre']) - post = property(lambda self: self.refs['post']) - delay = 
property(lambda self: self.refs['delay']) - out = property(lambda self: self.refs['out']) - + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + syn: The synaptic dynamics. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + syn: DynamicalSystem, + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(syn, DynamicalSystem) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + # unify the access + self.refs['syn'] = syn + self.refs['comm'] = comm + + def update(self): + spk = self.refs['delay'].at(self.name) + g = self.comm(self.syn(spk)) + self.refs['out'].bind_cond(g) + return g + + pre = property(lambda self: 
self.refs['pre']) + post = property(lambda self: self.refs['post']) + delay = property(lambda self: self.refs['delay']) + out = property(lambda self: self.refs['out']) diff --git a/brainpy/_src/dyn/projections/base.py b/brainpy/_src/dyn/projections/base.py index 44a2273a4..b7946fb37 100644 --- a/brainpy/_src/dyn/projections/base.py +++ b/brainpy/_src/dyn/projections/base.py @@ -3,10 +3,9 @@ def _get_return(return_info): - if isinstance(return_info, bm.Variable): - return return_info.value - elif isinstance(return_info, ReturnInfo): - return return_info.get_data() - else: - raise NotImplementedError - + if isinstance(return_info, bm.Variable): + return return_info.value + elif isinstance(return_info, ReturnInfo): + return return_info.get_data() + else: + raise NotImplementedError diff --git a/brainpy/_src/dyn/projections/conn.py b/brainpy/_src/dyn/projections/conn.py index b14f1e129..e622726c2 100644 --- a/brainpy/_src/dyn/projections/conn.py +++ b/brainpy/_src/dyn/projections/conn.py @@ -9,98 +9,97 @@ from brainpy.types import ArrayType __all__ = [ - 'SynConn', + 'SynConn', ] class SynConn(Projection): - """Base class to model two-end synaptic connections. - - Parameters:: - - pre : NeuGroup - Pre-synaptic neuron group. - post : NeuGroup - Post-synaptic neuron group. - conn : optional, ndarray, ArrayType, dict, TwoEndConnector - The connection method between pre- and post-synaptic groups. - name : str, optional - The name of the dynamic system. 
- """ - - def __init__( - self, - pre: DynamicalSystem, - post: DynamicalSystem, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # pre or post neuron group - # ------------------------ - if not isinstance(pre, DynamicalSystem): - raise TypeError('"pre" must be an instance of DynamicalSystem.') - if not isinstance(post, DynamicalSystem): - raise TypeError('"post" must be an instance of DynamicalSystem.') - self.pre = pre - self.post = post - - # connectivity - # ------------ - if isinstance(conn, TwoEndConnector): - self.conn = conn(pre.size, post.size) - elif isinstance(conn, (bm.BaseArray, np.ndarray, jax.Array)): - if (pre.num, post.num) != conn.shape: - raise ValueError(f'"conn" is provided as a matrix, and it is expected ' - f'to be an array with shape of (pre.num, post.num) = ' - f'{(pre.num, post.num)}, however we got {conn.shape}') - self.conn = MatConn(conn_mat=conn) - elif isinstance(conn, dict): - if not ('i' in conn and 'j' in conn): - raise ValueError(f'"conn" is provided as a dict, and it is expected to ' - f'be a dictionary with "i" and "j" specification, ' - f'however we got {conn}') - self.conn = IJConn(i=conn['i'], j=conn['j']) - elif isinstance(conn, str): - self.conn = conn - elif conn is None: - self.conn = None - else: - raise ValueError(f'Unknown "conn" type: {conn}') - - def __repr__(self): - names = self.__class__.__name__ - return (f'{names}(name={self.name}, mode={self.mode}, \n' - f'{" " * len(names)} pre={self.pre}, \n' - f'{" " * len(names)} post={self.post})') - - def check_pre_attrs(self, *attrs): - """Check whether pre group satisfies the requirement.""" - if not hasattr(self, 'pre'): - raise ValueError('Please call __init__ function first.') - for attr in attrs: - if not isinstance(attr, str): - raise TypeError(f'Must be string. 
But got {attr}.') - if not hasattr(self.pre, attr): - raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".') - - def check_post_attrs(self, *attrs): - """Check whether post group satisfies the requirement.""" - if not hasattr(self, 'post'): - raise ValueError('Please call __init__ function first.') - for attr in attrs: - if not isinstance(attr, str): - raise TypeError(f'Must be string. But got {attr}.') - if not hasattr(self.post, attr): - raise ValueError(f'{self} need "post" neuron group has attribute "{attr}".') - - def update(self, *args, **kwargs): - """The function to specify the updating rule. - - Assume any dynamical system depends on the shared variables (`sha`), - like time variable ``t``, the step precision ``dt``, and the time step `i`. + """Base class to model two-end synaptic connections. + + Parameters:: + + pre : NeuGroup + Pre-synaptic neuron group. + post : NeuGroup + Post-synaptic neuron group. + conn : optional, ndarray, ArrayType, dict, TwoEndConnector + The connection method between pre- and post-synaptic groups. + name : str, optional + The name of the dynamic system. 
""" - raise NotImplementedError('Must implement "update" function by subclass self.') + def __init__( + self, + pre: DynamicalSystem, + post: DynamicalSystem, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # pre or post neuron group + # ------------------------ + if not isinstance(pre, DynamicalSystem): + raise TypeError('"pre" must be an instance of DynamicalSystem.') + if not isinstance(post, DynamicalSystem): + raise TypeError('"post" must be an instance of DynamicalSystem.') + self.pre = pre + self.post = post + + # connectivity + # ------------ + if isinstance(conn, TwoEndConnector): + self.conn = conn(pre.size, post.size) + elif isinstance(conn, (bm.BaseArray, np.ndarray, jax.Array)): + if (pre.num, post.num) != conn.shape: + raise ValueError(f'"conn" is provided as a matrix, and it is expected ' + f'to be an array with shape of (pre.num, post.num) = ' + f'{(pre.num, post.num)}, however we got {conn.shape}') + self.conn = MatConn(conn_mat=conn) + elif isinstance(conn, dict): + if not ('i' in conn and 'j' in conn): + raise ValueError(f'"conn" is provided as a dict, and it is expected to ' + f'be a dictionary with "i" and "j" specification, ' + f'however we got {conn}') + self.conn = IJConn(i=conn['i'], j=conn['j']) + elif isinstance(conn, str): + self.conn = conn + elif conn is None: + self.conn = None + else: + raise ValueError(f'Unknown "conn" type: {conn}') + + def __repr__(self): + names = self.__class__.__name__ + return (f'{names}(name={self.name}, mode={self.mode}, \n' + f'{" " * len(names)} pre={self.pre}, \n' + f'{" " * len(names)} post={self.post})') + + def check_pre_attrs(self, *attrs): + """Check whether pre group satisfies the requirement.""" + if not hasattr(self, 'pre'): + raise ValueError('Please call __init__ function first.') + for attr in attrs: + if not isinstance(attr, str): + raise TypeError(f'Must be 
string. But got {attr}.') + if not hasattr(self.pre, attr): + raise ValueError(f'{self} need "pre" neuron group has attribute "{attr}".') + + def check_post_attrs(self, *attrs): + """Check whether post group satisfies the requirement.""" + if not hasattr(self, 'post'): + raise ValueError('Please call __init__ function first.') + for attr in attrs: + if not isinstance(attr, str): + raise TypeError(f'Must be string. But got {attr}.') + if not hasattr(self.post, attr): + raise ValueError(f'{self} need "post" neuron group has attribute "{attr}".') + + def update(self, *args, **kwargs): + """The function to specify the updating rule. + + Assume any dynamical system depends on the shared variables (`sha`), + like time variable ``t``, the step precision ``dt``, and the time step `i`. + """ + raise NotImplementedError('Must implement "update" function by subclass self.') diff --git a/brainpy/_src/dyn/projections/delta.py b/brainpy/_src/dyn/projections/delta.py index 19e4938cb..d9a47a459 100644 --- a/brainpy/_src/dyn/projections/delta.py +++ b/brainpy/_src/dyn/projections/delta.py @@ -6,205 +6,205 @@ from brainpy._src.mixin import (JointType, SupportAutoDelay) __all__ = [ - 'HalfProjDelta', 'FullProjDelta', + 'HalfProjDelta', 'FullProjDelta', ] class _Delta: - def __init__(self): - self._cond = None + def __init__(self): + self._cond = None - def bind_cond(self, cond): - self._cond = cond + def bind_cond(self, cond): + self._cond = cond - def __call__(self, *args, **kwargs): - r = self._cond - return r + def __call__(self, *args, **kwargs): + r = self._cond + return r class HalfProjDelta(Projection): - """Defining the half-part of the synaptic projection for the Delta synapse model. + """Defining the half-part of the synaptic projection for the Delta synapse model. - The synaptic projection requires the input is the spiking data, otherwise - the synapse is not the Delta synapse model. 
+ The synaptic projection requires the input is the spiking data, otherwise + the synapse is not the Delta synapse model. - The ``half-part`` means that the model only includes ``comm`` -> ``syn`` -> ``out`` -> ``post``. - Therefore, the model's ``update`` function needs the manual providing of the spiking input. + The ``half-part`` means that the model only includes ``comm`` -> ``syn`` -> ``out`` -> ``post``. + Therefore, the model's ``update`` function needs the manual providing of the spiking input. - **Model Descriptions** + **Model Descriptions** - .. math:: + .. math:: - I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) - where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, - :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, - :math:`C` the set of neurons connected to the post-synaptic neuron, - and :math:`D` the transmission delay of chemical synapses. - For simplicity, the rise and decay phases of post-synaptic currents are - omitted in this model. + where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + and :math:`D` the transmission delay of chemical synapses. + For simplicity, the rise and decay phases of post-synaptic currents are + omitted in this model. - **Code Examples** + **Code Examples** - .. code-block:: + .. code-block:: - import brainpy as bp - import brainpy.math as bm + import brainpy as bp + import brainpy.math as bm - class Net(bp.DynamicalSystem): - def __init__(self): - super().__init__() + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__() - self.pre = bp.dyn.PoissonGroup(10, 100.) - self.post = bp.dyn.LifRef(1) - self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + self.pre = bp.dyn.PoissonGroup(10, 100.) 
+ self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) - def update(self): - self.syn(self.pre()) - self.post() - return self.post.V.value + def update(self): + self.syn(self.pre()) + self.post() + return self.post.V.value - net = Net() - indices = bm.arange(1000).to_numpy() - vs = bm.for_loop(net.step_run, indices, progress_bar=True) - bp.visualize.line_plot(indices, vs, show=True) + net = Net() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) - Args: - comm: DynamicalSystem. The synaptic communication. - post: DynamicalSystem. The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ + Args: + comm: DynamicalSystem. The synaptic communication. + post: DynamicalSystem. The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ - def __init__( - self, - comm: DynamicalSystem, - post: DynamicalSystem, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) + def __init__( + self, + comm: DynamicalSystem, + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(post, DynamicalSystem) - self.comm = comm + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(post, DynamicalSystem) + self.comm = comm - # output initialization - out = _Delta() - post.add_inp_fun(self.name, out, category='delta') + # output initialization + out = _Delta() + post.add_inp_fun(self.name, out, category='delta') - # references - self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` - self.refs['comm'] = comm # unify the access + # references + self.refs = dict(post=post, out=out) # 
invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access - def update(self, x): - # call the communication - current = self.comm(x) - # bind the output - self.refs['out'].bind_cond(current) - # return the current, if needed - return current + def update(self, x): + # call the communication + current = self.comm(x) + # bind the output + self.refs['out'].bind_cond(current) + # return the current, if needed + return current class FullProjDelta(Projection): - """Full-chain of the synaptic projection for the Delta synapse model. + """Full-chain of the synaptic projection for the Delta synapse model. - The synaptic projection requires the input is the spiking data, otherwise - the synapse is not the Delta synapse model. + The synaptic projection requires the input is the spiking data, otherwise + the synapse is not the Delta synapse model. - The ``full-chain`` means that the model needs to provide all information needed for a projection, - including ``pre`` -> ``delay`` -> ``comm`` -> ``post``. + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``post``. - **Model Descriptions** + **Model Descriptions** - .. math:: + .. math:: - I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) - where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, - :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, - :math:`C` the set of neurons connected to the post-synaptic neuron, - and :math:`D` the transmission delay of chemical synapses. - For simplicity, the rise and decay phases of post-synaptic currents are - omitted in this model. 
+ where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + and :math:`D` the transmission delay of chemical synapses. + For simplicity, the rise and decay phases of post-synaptic currents are + omitted in this model. - **Code Examples** + **Code Examples** - .. code-block:: + .. code-block:: - import brainpy as bp - import brainpy.math as bm + import brainpy as bp + import brainpy.math as bm - class Net(bp.DynamicalSystem): - def __init__(self): - super().__init__() + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__() - self.pre = bp.dyn.PoissonGroup(10, 100.) - self.post = bp.dyn.LifRef(1) - self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + self.pre = bp.dyn.PoissonGroup(10, 100.) + self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) - def update(self): - self.syn() - self.pre() - self.post() - return self.post.V.value + def update(self): + self.syn() + self.pre() + self.post() + return self.post.V.value - net = Net() - indices = bm.arange(1000).to_numpy() - vs = bm.for_loop(net.step_run, indices, progress_bar=True) - bp.visualize.line_plot(indices, vs, show=True) + net = Net() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: DynamicalSystem. The synaptic communication. - post: DynamicalSystem. The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: DynamicalSystem. The synaptic communication. + post: DynamicalSystem. The post-synaptic neuron group. 
+ name: str. The projection name. + mode: Mode. The computing mode. + """ - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - post: DynamicalSystem, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(post, DynamicalSystem) - self.comm = comm + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(post, DynamicalSystem) + self.comm = comm - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) - # output initialization - out = _Delta() - post.add_inp_fun(self.name, out, category='delta') + # output initialization + out = _Delta() + post.add_inp_fun(self.name, out, category='delta') - # references - self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` - self.refs['comm'] = comm # unify the access - self.refs['delay'] = pre.get_aft_update(delay_identifier) + # references + self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + self.refs['delay'] = pre.get_aft_update(delay_identifier) - def update(self): - # get delay - x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) - # call the communication - current = self.comm(x) - # bind the 
output - self.refs['out'].bind_cond(current) - # return the current, if needed - return current + def update(self): + # get delay + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) + # call the communication + current = self.comm(x) + # bind the output + self.refs['out'].bind_cond(current) + # return the current, if needed + return current diff --git a/brainpy/_src/dyn/projections/inputs.py b/brainpy/_src/dyn/projections/inputs.py index dd1e1e3df..f04f02583 100644 --- a/brainpy/_src/dyn/projections/inputs.py +++ b/brainpy/_src/dyn/projections/inputs.py @@ -10,158 +10,158 @@ from brainpy.types import Shape __all__ = [ - 'InputVar', - 'PoissonInput', + 'InputVar', + 'PoissonInput', ] class InputVar(Dynamic, SupportAutoDelay): - """Define an input variable. + """Define an input variable. - Example:: + Example:: - import brainpy as bp + import brainpy as bp - class Exponential(bp.Projection): - def __init__(self, pre, post, prob, g_max, tau, E=0.): - super().__init__() - self.proj = bp.dyn.ProjAlignPostMg2( - pre=pre, - delay=None, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - syn=bp.dyn.Expon.desc(post.num, tau=tau), - out=bp.dyn.COBA.desc(E=E), - post=post, - ) + class Exponential(bp.Projection): + def __init__(self, pre, post, prob, g_max, tau, E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg2( + pre=pre, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + syn=bp.dyn.Expon.desc(post.num, tau=tau), + out=bp.dyn.COBA.desc(E=E), + post=post, + ) - class EINet(bp.DynSysGroup): - def __init__(self, num_exc, num_inh, method='exp_auto'): - super(EINet, self).__init__() + class EINet(bp.DynSysGroup): + def __init__(self, num_exc, num_inh, method='exp_auto'): + super(EINet, self).__init__() - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.), method=method) - self.E = 
bp.dyn.LifRef(num_exc, **pars) - self.I = bp.dyn.LifRef(num_inh, **pars) + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), method=method) + self.E = bp.dyn.LifRef(num_exc, **pars) + self.I = bp.dyn.LifRef(num_inh, **pars) - # synapses - w_e = 0.6 # excitatory synaptic weight - w_i = 6.7 # inhibitory synaptic weight + # synapses + w_e = 0.6 # excitatory synaptic weight + w_i = 6.7 # inhibitory synaptic weight - # Neurons connect to each other randomly with a connection probability of 2% - self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.) - self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.) - self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.) - self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.) + # Neurons connect to each other randomly with a connection probability of 2% + self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.) + self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.) + self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.) + self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.) - # define input variables given to E/I populations - self.Ein = bp.dyn.InputVar(self.E.varshape) - self.Iin = bp.dyn.InputVar(self.I.varshape) - self.E.add_inp_fun('', self.Ein) - self.I.add_inp_fun('', self.Iin) + # define input variables given to E/I populations + self.Ein = bp.dyn.InputVar(self.E.varshape) + self.Iin = bp.dyn.InputVar(self.I.varshape) + self.E.add_inp_fun('', self.Ein) + self.I.add_inp_fun('', self.Iin) - net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method - runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)]) - runner.run(100.) 
+ net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method + runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)]) + runner.run(100.) - # visualization - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], - title='Spikes of Excitatory Neurons', show=True) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], - title='Spikes of Inhibitory Neurons', show=True) + # visualization + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], + title='Spikes of Excitatory Neurons', show=True) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], + title='Spikes of Inhibitory Neurons', show=True) - """ + """ - def __init__( - self, - size: Shape, - keep_size: bool = False, - sharding: Optional[Any] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - method: str = 'exp_auto' - ): - super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method) + def __init__( + self, + size: Shape, + keep_size: bool = False, + sharding: Optional[Any] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + method: str = 'exp_auto' + ): + super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method) - self.reset_state(self.mode) + self.reset_state(self.mode) - def reset_state(self, batch_or_mode=None, **kwargs): - self.input = self.init_variable(bm.zeros, batch_or_mode) + def reset_state(self, batch_or_mode=None, **kwargs): + self.input = self.init_variable(bm.zeros, batch_or_mode) - def update(self, *args, **kwargs): - return self.input.value + def update(self, *args, **kwargs): + return self.input.value - def return_info(self): - return self.input + def return_info(self): + return self.input - def clear_input(self, *args, **kwargs): - self.reset_state(self.mode) + def clear_input(self, *args, **kwargs): + self.reset_state(self.mode) class 
PoissonInput(Projection): - """Poisson Input to the given :py:class:`~.Variable`. - - Adds independent Poisson input to a target variable. For large - numbers of inputs, this is much more efficient than creating a - `PoissonGroup`. The synaptic events are generated randomly during the - simulation and are not preloaded and stored in memory. All the inputs must - target the same variable, have the same frequency and same synaptic weight. - All neurons in the target variable receive independent realizations of - Poisson spike trains. - - Args: - target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. - num_input: The number of inputs. - freq: The frequency of each of the inputs. Must be a scalar. - weight: The synaptic weight. Must be a scalar. - name: The target name. - mode: The computing mode. - """ - - def __init__( - self, - target_var: bm.Variable, - num_input: int, - freq: Union[int, float], - weight: Union[int, float], - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - if not isinstance(target_var, bm.Variable): - raise TypeError(f'"target_var" must be an instance of Variable. 
' - f'But we got {type(target_var)}: {target_var}') - self.target_var = target_var - self.num_input = check.is_integer(num_input, min_bound=1) - self.freq = check.is_float(freq, min_bound=0., allow_int=True) - self.weight = check.is_float(weight, allow_int=True) - - def reset_state(self, *args, **kwargs): - pass - - def update(self): - p = self.freq * share['dt'] / 1e3 - a = self.num_input * p - b = self.num_input * (1 - p) - - if isinstance(share['dt'], numbers.Number): # dt is not traced - if (a > 5) and (b > 5): - inp = bm.random.normal(a, b * p, self.target_var.shape) - else: - inp = bm.random.binomial(self.num_input, p, self.target_var.shape) - - else: # dt is traced - inp = bm.cond((a > 5) * (b > 5), - lambda: bm.random.normal(a, b * p, self.target_var.shape), - lambda: bm.random.binomial(self.num_input, p, self.target_var.shape)) - - # inp = bm.sharding.partition(inp, self.target_var.sharding) - self.target_var += inp * self.weight - - def __repr__(self): - return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})' + """Poisson Input to the given :py:class:`~.Variable`. + + Adds independent Poisson input to a target variable. For large + numbers of inputs, this is much more efficient than creating a + `PoissonGroup`. The synaptic events are generated randomly during the + simulation and are not preloaded and stored in memory. All the inputs must + target the same variable, have the same frequency and same synaptic weight. + All neurons in the target variable receive independent realizations of + Poisson spike trains. + + Args: + target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. + num_input: The number of inputs. + freq: The frequency of each of the inputs. Must be a scalar. + weight: The synaptic weight. Must be a scalar. + name: The target name. + mode: The computing mode. 
+ """ + + def __init__( + self, + target_var: bm.Variable, + num_input: int, + freq: Union[int, float], + weight: Union[int, float], + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + if not isinstance(target_var, bm.Variable): + raise TypeError(f'"target_var" must be an instance of Variable. ' + f'But we got {type(target_var)}: {target_var}') + self.target_var = target_var + self.num_input = check.is_integer(num_input, min_bound=1) + self.freq = check.is_float(freq, min_bound=0., allow_int=True) + self.weight = check.is_float(weight, allow_int=True) + + def reset_state(self, *args, **kwargs): + pass + + def update(self): + p = self.freq * share['dt'] / 1e3 + a = self.num_input * p + b = self.num_input * (1 - p) + + if isinstance(share['dt'], numbers.Number): # dt is not traced + if (a > 5) and (b > 5): + inp = bm.random.normal(a, b * p, self.target_var.shape) + else: + inp = bm.random.binomial(self.num_input, p, self.target_var.shape) + + else: # dt is traced + inp = bm.cond((a > 5) * (b > 5), + lambda: bm.random.normal(a, b * p, self.target_var.shape), + lambda: bm.random.binomial(self.num_input, p, self.target_var.shape)) + + # inp = bm.sharding.partition(inp, self.target_var.sharding) + self.target_var += inp * self.weight + + def __repr__(self): + return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})' diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 439b6eb6c..47a1d6a0c 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -12,7 +12,7 @@ from .utils import (_get_return, ) __all__ = [ - 'STDP_Song2000', + 'STDP_Song2000', ] @@ -21,211 +21,212 @@ def _init_trace_by_align_pre2( delay: Union[None, int, float], syn: ParamDescriber[DynamicalSystem], ): - """Calculate the trace of the target by reusing the existing connections.""" - check.is_instance(target, 
DynamicalSystem) - check.is_instance(syn, ParamDescriber[DynamicalSystem]) - # delay initialization - delay_cls = register_delay_by_return(target) - # synapse initialization - syn = align_pre2_add_bef_update(syn, delay, delay_cls) - return syn - - -class STDP_Song2000(Projection): - r"""Spike-time-dependent plasticity proposed by (Song, et. al, 2000). - - This model filters the synaptic currents according to the variables: :math:`w`. - - .. math:: - - I_{syn}^+(t) = I_{syn}^-(t) * w - - where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before - and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse, - the conductance of the synapse will increase w. - - The dynamics of :math:`w` is governed by the following equation: - - .. math:: - - \begin{aligned} - \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\ - \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\ - \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\ - \end{aligned} - - where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment - of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike. 
- - Here is an example of the usage of this class:: - - import brainpy as bp - import brainpy.math as bm - - class STDPNet(bp.DynamicalSystem): - def __init__(self, num_pre, num_post): - super().__init__() - self.pre = bp.dyn.LifRef(num_pre) - self.post = bp.dyn.LifRef(num_post) - self.syn = bp.dyn.STDP_Song2000( - pre=self.pre, - delay=1., - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(max_val=0.1)), - syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.post, - tau_s=16.8, - tau_t=33.7, - A1=0.96, - A2=0.53, - ) - - def update(self, I_pre, I_post): - self.syn() - self.pre(I_pre) - self.post(I_post) - conductance = self.syn.refs['syn'].g - Apre = self.syn.refs['pre_trace'].g - Apost = self.syn.refs['post_trace'].g - current = self.post.sum_inputs(self.post.V) - return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight - - duration = 300. - I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) - I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) - - net = STDPNet(1, 1) - def run(i, I_pre, I_post): - pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) - return pre_spike, post_spike, g, Apre, Apost, current, W - - indices = bm.arange(0, duration, bm.dt) - pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) - - Args: - tau_s: float. The time constant of :math:`A_{pre}`. - tau_t: float. The time constant of :math:`A_{post}`. - A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. - A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value. - W_max: float. The maximum weight. - W_min: float. 
The minimum weight. - pre: DynamicalSystem. The pre-synaptic neuron group. - delay: int, float. The pre spike delay length. (ms) - syn: DynamicalSystem. The synapse model. - comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. - out: DynamicalSystem. The synaptic current output models. - post: DynamicalSystem. The post-synaptic neuron group. - out_label: str. The output label. - name: str. The model name. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - syn: ParamDescriber[DynamicalSystem], - comm: JointType[DynamicalSystem, SupportSTDP], - out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], - post: DynamicalSystem, - # synapse parameters - tau_s: Union[float, ArrayType, Callable] = 16.8, - tau_t: Union[float, ArrayType, Callable] = 33.7, - A1: Union[float, ArrayType, Callable] = 0.96, - A2: Union[float, ArrayType, Callable] = 0.53, - W_max: Optional[float] = None, - W_min: Optional[float] = None, - # others - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) + """Calculate the trace of the target by reusing the existing connections.""" + check.is_instance(target, DynamicalSystem) check.is_instance(syn, ParamDescriber[DynamicalSystem]) - check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, DynamicalSystem) - self.pre_num = pre.num - self.post_num = post.num - self.comm = comm - self._is_align_post = issubclass(syn.cls, AlignPost) - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # synapse and output initialization - if self._is_align_post: - syn_cls, out_cls = 
align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, - proj_name=self.name) - else: - syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre') - out_cls = out() - post.add_inp_fun(self.name, out_cls, label=out_label) - - # references - self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` - self.refs['delay'] = delay_cls - self.refs['syn'] = syn_cls # invisible to ``self.node()`` - self.refs['out'] = out_cls # invisible to ``self.node()`` - self.refs['comm'] = comm - - # tracing pre-synaptic spikes using Exponential model - self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s)) - - # tracing post-synaptic spikes using Exponential model - self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) - - # synapse parameters - self.W_max = W_max - self.W_min = W_min - self.tau_s = tau_s - self.tau_t = tau_t - self.A1 = A1 - self.A2 = A2 - - pre = property(lambda self: self.refs['pre']) - post = property(lambda self: self.refs['post']) - syn = property(lambda self: self.refs['syn']) - delay = property(lambda self: self.refs['delay']) - out = property(lambda self: self.refs['out']) - - def update(self): - # pre-synaptic spikes - pre_spike = self.refs['delay'].at(self.name) # spike - # pre-synaptic variables - if self._is_align_post: - # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance - x = pre_spike - else: - # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance - x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable - - # post spikes - if not hasattr(self.refs['post'], 'spike'): - raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.') - post_spike = self.refs['post'].spike - - # weight updates - Apost = self.refs['post_trace'].g - self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * 
self.A2}, w_min=self.W_min, w_max=self.W_max) - Apre = self.refs['pre_trace'].g - self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, w_max=self.W_max) - - # synaptic currents - current = self.comm(x) - if self._is_align_post: - self.refs['syn'].add_current(current) # synapse post current - else: - self.refs['out'].bind_cond(current) # align pre - return current + delay_cls = register_delay_by_return(target) + # synapse initialization + syn = align_pre2_add_bef_update(syn, delay, delay_cls) + return syn +class STDP_Song2000(Projection): + r"""Spike-time-dependent plasticity proposed by (Song, et. al, 2000). + + This model filters the synaptic currents according to the variables: :math:`w`. + + .. math:: + + I_{syn}^+(t) = I_{syn}^-(t) * w + + where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse, + the conductance of the synapse will increase w. + + The dynamics of :math:`w` is governed by the following equation: + + .. math:: + + \begin{aligned} + \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\ + \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\ + \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\ + \end{aligned} + + where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment + of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike. 
+ + Here is an example of the usage of this class:: + + import brainpy as bp + import brainpy.math as bm + + class STDPNet(bp.DynamicalSystem): + def __init__(self, num_pre, num_post): + super().__init__() + self.pre = bp.dyn.LifRef(num_pre) + self.post = bp.dyn.LifRef(num_post) + self.syn = bp.dyn.STDP_Song2000( + pre=self.pre, + delay=1., + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(max_val=0.1)), + syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.post, + tau_s=16.8, + tau_t=33.7, + A1=0.96, + A2=0.53, + ) + + def update(self, I_pre, I_post): + self.syn() + self.pre(I_pre) + self.post(I_post) + conductance = self.syn.refs['syn'].g + Apre = self.syn.refs['pre_trace'].g + Apost = self.syn.refs['post_trace'].g + current = self.post.sum_inputs(self.post.V) + return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight + + duration = 300. + I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) + I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) + + net = STDPNet(1, 1) + def run(i, I_pre, I_post): + pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) + return pre_spike, post_spike, g, Apre, Apost, current, W + + indices = bm.arange(0, duration, bm.dt) + pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) + + Args: + tau_s: float. The time constant of :math:`A_{pre}`. + tau_t: float. The time constant of :math:`A_{post}`. + A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. + A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value. + W_max: float. The maximum weight. + W_min: float. 
The minimum weight. + pre: DynamicalSystem. The pre-synaptic neuron group. + delay: int, float. The pre spike delay length. (ms) + syn: DynamicalSystem. The synapse model. + comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. + out: DynamicalSystem. The synaptic current output models. + post: DynamicalSystem. The post-synaptic neuron group. + out_label: str. The output label. + name: str. The model name. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + syn: ParamDescriber[DynamicalSystem], + comm: JointType[DynamicalSystem, SupportSTDP], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], + post: DynamicalSystem, + # synapse parameters + tau_s: Union[float, ArrayType, Callable] = 16.8, + tau_t: Union[float, ArrayType, Callable] = 33.7, + A1: Union[float, ArrayType, Callable] = 0.96, + A2: Union[float, ArrayType, Callable] = 0.53, + W_max: Optional[float] = None, + W_min: Optional[float] = None, + # others + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(post, DynamicalSystem) + self.pre_num = pre.num + self.post_num = post.num + self.comm = comm + self._is_align_post = issubclass(syn.cls, AlignPost) + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # synapse and output initialization + if self._is_align_post: + syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, + proj_name=self.name) + else: + syn_cls = 
align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre') + out_cls = out() + post.add_inp_fun(self.name, out_cls, label=out_label) + + # references + self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` + self.refs['delay'] = delay_cls + self.refs['syn'] = syn_cls # invisible to ``self.node()`` + self.refs['out'] = out_cls # invisible to ``self.node()`` + self.refs['comm'] = comm + + # tracing pre-synaptic spikes using Exponential model + self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s)) + + # tracing post-synaptic spikes using Exponential model + self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) + + # synapse parameters + self.W_max = W_max + self.W_min = W_min + self.tau_s = tau_s + self.tau_t = tau_t + self.A1 = A1 + self.A2 = A2 + + pre = property(lambda self: self.refs['pre']) + post = property(lambda self: self.refs['post']) + syn = property(lambda self: self.refs['syn']) + delay = property(lambda self: self.refs['delay']) + out = property(lambda self: self.refs['out']) + + def update(self): + # pre-synaptic spikes + pre_spike = self.refs['delay'].at(self.name) # spike + # pre-synaptic variables + if self._is_align_post: + # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance + x = pre_spike + else: + # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance + x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable + + # post spikes + if not hasattr(self.refs['post'], 'spike'): + raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.') + post_spike = self.refs['post'].spike + + # weight updates + Apost = self.refs['post_trace'].g + self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, + w_max=self.W_max) + Apre = self.refs['pre_trace'].g + 
self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, + w_max=self.W_max) + + # synaptic currents + current = self.comm(x) + if self._is_align_post: + self.refs['syn'].add_current(current) # synapse post current + else: + self.refs['out'].bind_cond(current) # align pre + return current + # class PairedSTDP(Projection): # r"""Paired spike-time-dependent plasticity model. # @@ -414,5 +415,3 @@ def update(self): # else: # self.refs['out'].bind_cond(current) # align pre # return current - - diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index 1ee009bca..e1146ecab 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import numpy as np -import pytest from absl.testing import parameterized import brainpy as bp @@ -17,112 +16,110 @@ class Test_STDP(parameterized.TestCase): - @parameterized.product( - comm_method=['csr', 'dense', 'masked_linear', 'all2all', 'one2one'], - delay=[None, 0., 2.], - syn_model=['exp', 'dual_exp', 'ampa'], - out_model=['cuba', 'coba', 'mg'] - ) - def test_STDP(self, comm_method, delay, syn_model, out_model): - bm.random.seed() - - class STDPNet(bp.DynamicalSystem): - def __init__(self, num_pre, num_post): - super().__init__() - self.pre = bp.dyn.LifRef(num_pre) - self.post = bp.dyn.LifRef(num_post) - - if comm_method == 'all2all': - comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'csr': - if syn_model == 'exp': - comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - else: - comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'masked_linear': - comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, 
post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'dense': - comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'one2one': - comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) - else: - raise ValueError - - if syn_model == 'exp': - syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) - elif syn_model == 'dual_exp': - syn = bp.dyn.DualExpon.desc(self.post.varshape) - elif syn_model == 'dual_exp_v2': - syn = bp.dyn.DualExponV2.desc(self.post.varshape) - elif syn_model == 'ampa': - syn = bp.dyn.AMPA.desc(self.post.varshape) - else: - raise ValueError - - if out_model == 'cuba': - out = bp.dyn.CUBA.desc() - elif out_model == 'coba': - out = bp.dyn.COBA.desc(E=0.) - elif out_model == 'mg': - out = bp.dyn.MgBlock.desc(E=0.) - else: - raise ValueError - - self.syn = bp.dyn.STDP_Song2000( - pre=self.pre, - delay=delay, - comm=comm, - syn=syn, - out=out, - post=self.post, - tau_s=16.8, - tau_t=33.7, - A1=0.96, - A2=0.53, - W_min=0., - W_max=1. - ) - - def update(self, I_pre, I_post): - self.syn() - self.pre(I_pre) - self.post(I_post) - conductance = self.syn.refs['syn'].g - Apre = self.syn.refs['pre_trace'].g - Apost = self.syn.refs['post_trace'].g - current = self.post.sum_current_inputs(self.post.V) - if comm_method == 'dense': - w = self.syn.comm.W.flatten() - else: - w = self.syn.comm.weight.flatten() - return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w - - duration = 300. 
- I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, - duration - 255]) - I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, - duration - 250]) - - net = STDPNet(1, 1) - - def run(i, I_pre, I_post): - pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) - return pre_spike, post_spike, g, Apre, Apost, current, W - - indices = np.arange(int(duration / bm.dt)) - pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) - - # import matplotlib.pyplot as plt - # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) - # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) - # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) - # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) - # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) - # plt.show() - - + @parameterized.product( + comm_method=['csr', 'dense', 'masked_linear', 'all2all', 'one2one'], + delay=[None, 0., 2.], + syn_model=['exp', 'dual_exp', 'ampa'], + out_model=['cuba', 'coba', 'mg'] + ) + def test_STDP(self, comm_method, delay, syn_model, out_model): + bm.random.seed() + + class STDPNet(bp.DynamicalSystem): + def __init__(self, num_pre, num_post): + super().__init__() + self.pre = bp.dyn.LifRef(num_pre) + self.post = bp.dyn.LifRef(num_post) + + if comm_method == 'all2all': + comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'csr': + if syn_model == 'exp': + comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + else: + comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'masked_linear': + comm = 
bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'dense': + comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'one2one': + comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) + else: + raise ValueError + + if syn_model == 'exp': + syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) + elif syn_model == 'dual_exp': + syn = bp.dyn.DualExpon.desc(self.post.varshape) + elif syn_model == 'dual_exp_v2': + syn = bp.dyn.DualExponV2.desc(self.post.varshape) + elif syn_model == 'ampa': + syn = bp.dyn.AMPA.desc(self.post.varshape) + else: + raise ValueError + + if out_model == 'cuba': + out = bp.dyn.CUBA.desc() + elif out_model == 'coba': + out = bp.dyn.COBA.desc(E=0.) + elif out_model == 'mg': + out = bp.dyn.MgBlock.desc(E=0.) + else: + raise ValueError + + self.syn = bp.dyn.STDP_Song2000( + pre=self.pre, + delay=delay, + comm=comm, + syn=syn, + out=out, + post=self.post, + tau_s=16.8, + tau_t=33.7, + A1=0.96, + A2=0.53, + W_min=0., + W_max=1. + ) + + def update(self, I_pre, I_post): + self.syn() + self.pre(I_pre) + self.post(I_post) + conductance = self.syn.refs['syn'].g + Apre = self.syn.refs['pre_trace'].g + Apost = self.syn.refs['post_trace'].g + current = self.post.sum_current_inputs(self.post.V) + if comm_method == 'dense': + w = self.syn.comm.W.flatten() + else: + w = self.syn.comm.weight.flatten() + return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w + + duration = 300. 
+ I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, + duration - 255]) + I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, + duration - 250]) + + net = STDPNet(1, 1) + + def run(i, I_pre, I_post): + pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) + return pre_spike, post_spike, g, Apre, Apost, current, W + + indices = np.arange(int(duration / bm.dt)) + pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) + + # import matplotlib.pyplot as plt + # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) + # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) + # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) + # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) + # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) + # plt.show() diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index 56e8f32d1..05ca22838 100644 --- a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -1,4 +1,3 @@ -import pytest import matplotlib.pyplot as plt import numpy as np @@ -12,430 +11,426 @@ def test_ProjAlignPreMg1(): - class EICOBA_PreAlign(bp.DynamicalSystem): - def __init__(self, scale=1., inp=20., delay=None): - super().__init__() - - self.inp = inp - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.I, - ) - self.E2E = 
bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.E, - ) - self.I2E = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PreAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=show) - - net = EICOBA_PreAlign(0.5, delay=1.) 
- indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=show) - - plt.close() - + class EICOBA_PreAlign(bp.DynamicalSystem): + def __init__(self, scale=1., inp=20., delay=None): + super().__init__() + + self.inp = inp + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.I, + ) + self.E2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.E, + ) + self.I2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PreAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=show) + + net = EICOBA_PreAlign(0.5, delay=1.) 
+ indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=show) + + plt.close() def test_ProjAlignPostMg2(): - class EICOBA_PostAlign(bp.DynamicalSystem): - def __init__(self, scale, inp=20., ltc=True, delay=None): - super().__init__() - self.inp = inp - - if ltc: - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - else: - self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2E = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.E, - ) - self.E2I = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.I, - ) - self.I2E = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PostAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=show) - - net = EICOBA_PostAlign(0.5, 
delay=1.) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=show) - - net = EICOBA_PostAlign(0.5, ltc=False) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=show) - - plt.close() - + class EICOBA_PostAlign(bp.DynamicalSystem): + def __init__(self, scale, inp=20., ltc=True, delay=None): + super().__init__() + self.inp = inp + + if ltc: + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + else: + self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2E = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E, + ) + self.E2I = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I, + ) + self.I2E = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PostAlign(0.5) + indices = 
np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=show) + + net = EICOBA_PostAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=show) + + net = EICOBA_PostAlign(0.5, ltc=False) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=show) + + plt.close() def test_ProjAlignPost1(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.num_exc = int(3200 * scale) - self.num_inh = num - self.num_exc - prob = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), - syn=bp.dyn.Expon(size=num, tau=5.), + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.num_exc = int(3200 * scale) + self.num_inh = num - self.num_exc + prob = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPost( + comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), + syn=bp.dyn.Expon(size=num, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost( + comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), + syn=bp.dyn.Expon(size=num, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:self.num_exc]) + self.I(spk[self.num_exc:]) + self.delay(self.N(input)) + return 
self.N.spike.value + + model = EINet(0.5) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) + + plt.close() + + +def test_ProjAlignPost2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale, delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (ne + ni) + + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(0.5, delay=1.) 
+ indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) + + model = EINet(0.5, delay=None) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) + + plt.close() + + +def test_VanillaProj(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=0.5): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.N) - self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), - syn=bp.dyn.Expon(size=num, tau=10.), + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.N) - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:self.num_exc]) - self.I(spk[self.num_exc:]) - self.delay(self.N(input)) - return self.N.spike.value + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value - model = EINet(0.5) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - plt.close() + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) - -def test_ProjAlignPost2(): - class 
EINet(bp.DynSysGroup): - def __init__(self, scale, delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (ne + ni) - - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(0.5, delay=1.) 
- indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - model = EINet(0.5, delay=None) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - - plt.close() - - -def test_VanillaProj(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=0.5): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) - self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - plt.close() + plt.close() def test_ProjAlignPreMg1_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = 
bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - model = EINet(delay=1.) 
- indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - - plt.close() + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) + + model = EINet(delay=1.) 
+ indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) + + plt.close() def test_ProjAlignPreMg2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(scale=0.2, delay=None) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - model = EINet(scale=0.2, delay=1.) 
- indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=show) - - - plt.close() + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(scale=0.2, delay=None) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) + + model = EINet(scale=0.2, delay=1.) 
+ indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=show) + + plt.close() def test_vanalla_proj_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 1.)) - self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) - self.E = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N - ) - self.I = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N - ) - - def update(self, input): - spk = self.delay.at('delay') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) - bp.visualize.raster_plot(indices, spks, show=show) - plt.close() - + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 1.)) + self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) 
+ self.E = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N + ) + self.I = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N + ) + + def update(self, input): + spk = self.delay.at('delay') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) + bp.visualize.raster_plot(indices, spks, show=show) + plt.close() diff --git a/brainpy/_src/dyn/projections/tests/test_delta.py b/brainpy/_src/dyn/projections/tests/test_delta.py index f4d21b643..90da7eb43 100644 --- a/brainpy/_src/dyn/projections/tests/test_delta.py +++ b/brainpy/_src/dyn/projections/tests/test_delta.py @@ -5,47 +5,45 @@ class NetForHalfProj(bp.DynamicalSystem): - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - self.pre = bp.dyn.PoissonGroup(10, 100.) - self.post = bp.dyn.LifRef(1) - self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + self.pre = bp.dyn.PoissonGroup(10, 100.) 
+ self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) - def update(self): - self.syn(self.pre()) - self.post() - return self.post.V.value + def update(self): + self.syn(self.pre()) + self.post() + return self.post.V.value def test1(): - net = NetForHalfProj() - indices = bm.arange(1000).to_numpy() - vs = bm.for_loop(net.step_run, indices, progress_bar=True) - bp.visualize.line_plot(indices, vs, show=False) - plt.close('all') + net = NetForHalfProj() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=False) + plt.close('all') class NetForFullProj(bp.DynamicalSystem): - def __init__(self): - super().__init__() + def __init__(self): + super().__init__() - self.pre = bp.dyn.PoissonGroup(10, 100.) - self.post = bp.dyn.LifRef(1) - self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + self.pre = bp.dyn.PoissonGroup(10, 100.) 
+ self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) - def update(self): - self.syn() - self.pre() - self.post() - return self.post.V.value + def update(self): + self.syn() + self.pre() + self.post() + return self.post.V.value def test2(): - net = NetForFullProj() - indices = bm.arange(1000).to_numpy() - vs = bm.for_loop(net.step_run, indices, progress_bar=True) - bp.visualize.line_plot(indices, vs, show=False) - plt.close('all') - - + net = NetForFullProj() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=False) + plt.close('all') diff --git a/brainpy/_src/dyn/projections/utils.py b/brainpy/_src/dyn/projections/utils.py index 44a2273a4..b7946fb37 100644 --- a/brainpy/_src/dyn/projections/utils.py +++ b/brainpy/_src/dyn/projections/utils.py @@ -3,10 +3,9 @@ def _get_return(return_info): - if isinstance(return_info, bm.Variable): - return return_info.value - elif isinstance(return_info, ReturnInfo): - return return_info.get_data() - else: - raise NotImplementedError - + if isinstance(return_info, bm.Variable): + return return_info.value + elif isinstance(return_info, ReturnInfo): + return return_info.get_data() + else: + raise NotImplementedError diff --git a/brainpy/_src/dyn/projections/vanilla.py b/brainpy/_src/dyn/projections/vanilla.py index 15773d231..a49147b72 100644 --- a/brainpy/_src/dyn/projections/vanilla.py +++ b/brainpy/_src/dyn/projections/vanilla.py @@ -5,79 +5,79 @@ from brainpy._src.mixin import (JointType, BindCondData) __all__ = [ - 'VanillaProj', + 'VanillaProj', ] class VanillaProj(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group. - - **Code Examples** - - To simulate an E/I balanced network model: - - .. 
code-block:: - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=3200, tau=5.) - self.syn2 = bp.dyn.Expon(size=800, tau=10.) - self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:3200])) - self.I(self.syn2(spk[3200:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. 
- """ - - def __init__( - self, - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # output initialization - post.add_inp_fun(self.name, out) - - # references - self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` - self.refs['comm'] = comm # unify the access - - def update(self, x): - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current + """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group. + + **Code Examples** + + To simulate an E/I balanced network model: + + .. code-block:: + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=3200, tau=5.) + self.syn2 = bp.dyn.Expon(size=800, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:3200])) + self.I(self.syn2(spk[3200:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + comm: The synaptic communication. + out: The synaptic output. 
+ post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # output initialization + post.add_inp_fun(self.name, out) + + # references + self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + + def update(self, x): + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current diff --git a/brainpy/_src/dyn/rates/__init__.py b/brainpy/_src/dyn/rates/__init__.py index 76f828172..3509093b4 100644 --- a/brainpy/_src/dyn/rates/__init__.py +++ b/brainpy/_src/dyn/rates/__init__.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- +from .nvar import * from .populations import * from .reservoir import * -from .nvar import * from .rnncells import * - - diff --git a/brainpy/_src/dyn/rates/nvar.py b/brainpy/_src/dyn/rates/nvar.py index d5840f73d..dd75eb5b6 100644 --- a/brainpy/_src/dyn/rates/nvar.py +++ b/brainpy/_src/dyn/rates/nvar.py @@ -11,190 +11,190 @@ from brainpy._src.dnn.base import Layer __all__ = [ - 'NVAR' + 'NVAR' ] def _comb(N, k): - r"""The number of combinations of N things taken k at a time. + r"""The number of combinations of N things taken k at a time. - .. math:: + .. math:: - \frac{N!}{(N-k)! k!} + \frac{N!}{(N-k)! 
k!} - """ - if N > k: - val = 1 - for j in range(min(k, N - k)): - val = (val * (N - j)) // (j + 1) - return val - elif N == k: - return 1 - else: - return 0 + """ + if N > k: + val = 1 + for j in range(min(k, N - k)): + val = (val * (N - j)) // (j + 1) + return val + elif N == k: + return 1 + else: + return 0 class NVAR(Layer): - """Nonlinear vector auto-regression (NVAR) node. - - This class has the following features: - - - it supports batch size, - - it supports multiple orders, - - Parameters:: - - delay: int - The number of delay step. - order: int, sequence of int - The nonlinear order. - stride: int - The stride to sample linear part vector in the delays. - constant: optional, float - The constant value. - - References:: - - .. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation - reservoir computing. Nat Commun 12, 5564 (2021). - https://doi.org/10.1038/s41467-021-25801-2 - - """ - - def __init__( - self, - num_in: int, - delay: int, - order: Optional[Union[int, Sequence[int]]] = None, - stride: int = 1, - constant: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(NVAR, self).__init__(mode=mode, name=name) - - # parameters - order = tuple() if order is None else order - if not isinstance(order, (tuple, list)): - order = (order,) - self.order = tuple(order) - check.is_sequence(order, 'order', allow_none=False) - for o in order: - check.is_integer(o, 'order', allow_none=False, min_bound=2) - check.is_integer(delay, 'delay', allow_none=False, min_bound=1) - check.is_integer(stride, 'stride', allow_none=False, min_bound=1) - assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.' 
- self.delay = delay - self.stride = stride - self.constant = constant - self.num_delay = 1 + (self.delay - 1) * self.stride - self.num_in = num_in - - # delay variables - self.idx = bm.Variable(jnp.asarray([0])) - if isinstance(self.mode, bm.BatchingMode): - batch_size = 1 # first initialize the state with batch size = 1 - self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1) - else: - self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in))) - - # linear dimension - self.linear_dim = self.delay * num_in - # For each monomial created in the non-linear part, indices - # of the n components involved, n being the order of the - # monomials. Precompute them to improve efficiency. - self.comb_ids = [] - for order in self.order: - assert order >= 2, f'"order" must be a integer >= 2, while we got {order}.' - idx = np.array(list(combinations_with_replacement(np.arange(self.linear_dim), order))) - self.comb_ids.append(jnp.asarray(idx)) - # number of non-linear components is (d + n - 1)! / (d - 1)! n! - # i.e. number of all unique monomials of order n made from the - # linear components. - self.nonlinear_dim = sum([len(ids) for ids in self.comb_ids]) - # output dimension - self.num_out = int(self.linear_dim + self.nonlinear_dim) - if self.constant: - self.num_out += 1 - - def reset_state(self, batch_or_mode=None, **kwargs): - """Reset the node state which depends on batch size.""" - self.idx[0] = 0 - # To store the last inputs. - # Note, the batch axis is not in the first dimension, so we - # manually handle the state of NVAR, rather return it. - if batch_or_mode is None: - self.store.value = jnp.zeros((self.num_delay, self.num_in)) - else: - self.store.value = jnp.zeros((self.num_delay, batch_or_mode, self.num_in)) - - def update(self, x): - all_parts = [] - select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay - # 1. 
Store the current input - self.store[self.idx[0]] = x - - if isinstance(self.mode, bm.BatchingMode): - # 2. Linear part: - # select all previous inputs, including the current, with strides - linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) - linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1)) - # 3. constant - if self.constant: - constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype) - all_parts.append(constant) - all_parts.append(linear_parts) - # 3. Nonlinear part: - # select monomial terms and compute them - for ids in self.comb_ids: - all_parts.append(jnp.prod(linear_parts[:, ids], axis=2)) + """Nonlinear vector auto-regression (NVAR) node. - else: - # 2. Linear part: - # select all previous inputs, including the current, with strides - linear_parts = self.store[select_ids].flatten() # (num_time x num_feature,) - # 3. constant - if self.constant: - constant = jnp.ones((1,), dtype=x.dtype) - all_parts.append(constant) - all_parts.append(linear_parts) - # 3. Nonlinear part: - # select monomial terms and compute them - for ids in self.comb_ids: - all_parts.append(jnp.prod(linear_parts[ids], axis=1)) - - # 4. Finally - self.idx.value = (self.idx + 1) % self.num_delay - return jnp.concatenate(all_parts, axis=-1) - - def get_feature_names(self, for_plot=False) -> List[str]: - """Get output feature names for transformation. + This class has the following features: + + - it supports batch size, + - it supports multiple orders, Parameters:: - for_plot: bool - Use the feature names for plotting or not? (Default False) + delay: int + The number of delay step. + order: int, sequence of int + The nonlinear order. + stride: int + The stride to sample linear part vector in the delays. + constant: optional, float + The constant value. + + References:: + + .. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation + reservoir computing. Nat Commun 12, 5564 (2021). 
+ https://doi.org/10.1038/s41467-021-25801-2 + """ - if for_plot: - linear_names = [f'x{i}_t' for i in range(self.num_in)] - else: - linear_names = [f'x{i}(t)' for i in range(self.num_in)] - for di in range(1, self.delay): - linear_names.extend([((f'x{i}_' + r'{t-%d}' % (di * self.stride)) - if for_plot else f'x{i}(t-{di * self.stride})') - for i in range(self.num_in)]) - nonlinear_names = [] - for ids in self.comb_ids: - for id_ in np.asarray(ids): - uniques, counts = np.unique(id_, return_counts=True) - nonlinear_names.append(" ".join( - "%s^%d" % (linear_names[ind], exp) if (exp != 1) else linear_names[ind] - for ind, exp in zip(uniques, counts) - )) - if for_plot: - all_names = [f'${n}$' for n in linear_names] + [f'${n}$' for n in nonlinear_names] - else: - all_names = linear_names + nonlinear_names - if self.constant: - all_names = ['1'] + all_names - return all_names + + def __init__( + self, + num_in: int, + delay: int, + order: Optional[Union[int, Sequence[int]]] = None, + stride: int = 1, + constant: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super(NVAR, self).__init__(mode=mode, name=name) + + # parameters + order = tuple() if order is None else order + if not isinstance(order, (tuple, list)): + order = (order,) + self.order = tuple(order) + check.is_sequence(order, 'order', allow_none=False) + for o in order: + check.is_integer(o, 'order', allow_none=False, min_bound=2) + check.is_integer(delay, 'delay', allow_none=False, min_bound=1) + check.is_integer(stride, 'stride', allow_none=False, min_bound=1) + assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.' 
+ self.delay = delay + self.stride = stride + self.constant = constant + self.num_delay = 1 + (self.delay - 1) * self.stride + self.num_in = num_in + + # delay variables + self.idx = bm.Variable(jnp.asarray([0])) + if isinstance(self.mode, bm.BatchingMode): + batch_size = 1 # first initialize the state with batch size = 1 + self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1) + else: + self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in))) + + # linear dimension + self.linear_dim = self.delay * num_in + # For each monomial created in the non-linear part, indices + # of the n components involved, n being the order of the + # monomials. Precompute them to improve efficiency. + self.comb_ids = [] + for order in self.order: + assert order >= 2, f'"order" must be a integer >= 2, while we got {order}.' + idx = np.array(list(combinations_with_replacement(np.arange(self.linear_dim), order))) + self.comb_ids.append(jnp.asarray(idx)) + # number of non-linear components is (d + n - 1)! / (d - 1)! n! + # i.e. number of all unique monomials of order n made from the + # linear components. + self.nonlinear_dim = sum([len(ids) for ids in self.comb_ids]) + # output dimension + self.num_out = int(self.linear_dim + self.nonlinear_dim) + if self.constant: + self.num_out += 1 + + def reset_state(self, batch_or_mode=None, **kwargs): + """Reset the node state which depends on batch size.""" + self.idx[0] = 0 + # To store the last inputs. + # Note, the batch axis is not in the first dimension, so we + # manually handle the state of NVAR, rather return it. + if batch_or_mode is None: + self.store.value = jnp.zeros((self.num_delay, self.num_in)) + else: + self.store.value = jnp.zeros((self.num_delay, batch_or_mode, self.num_in)) + + def update(self, x): + all_parts = [] + select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay + # 1. 
Store the current input + self.store[self.idx[0]] = x + + if isinstance(self.mode, bm.BatchingMode): + # 2. Linear part: + # select all previous inputs, including the current, with strides + linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) + linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1)) + # 3. constant + if self.constant: + constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype) + all_parts.append(constant) + all_parts.append(linear_parts) + # 3. Nonlinear part: + # select monomial terms and compute them + for ids in self.comb_ids: + all_parts.append(jnp.prod(linear_parts[:, ids], axis=2)) + + else: + # 2. Linear part: + # select all previous inputs, including the current, with strides + linear_parts = self.store[select_ids].flatten() # (num_time x num_feature,) + # 3. constant + if self.constant: + constant = jnp.ones((1,), dtype=x.dtype) + all_parts.append(constant) + all_parts.append(linear_parts) + # 3. Nonlinear part: + # select monomial terms and compute them + for ids in self.comb_ids: + all_parts.append(jnp.prod(linear_parts[ids], axis=1)) + + # 4. Finally + self.idx.value = (self.idx + 1) % self.num_delay + return jnp.concatenate(all_parts, axis=-1) + + def get_feature_names(self, for_plot=False) -> List[str]: + """Get output feature names for transformation. + + Parameters:: + + for_plot: bool + Use the feature names for plotting or not? 
(Default False) + """ + if for_plot: + linear_names = [f'x{i}_t' for i in range(self.num_in)] + else: + linear_names = [f'x{i}(t)' for i in range(self.num_in)] + for di in range(1, self.delay): + linear_names.extend([((f'x{i}_' + r'{t-%d}' % (di * self.stride)) + if for_plot else f'x{i}(t-{di * self.stride})') + for i in range(self.num_in)]) + nonlinear_names = [] + for ids in self.comb_ids: + for id_ in np.asarray(ids): + uniques, counts = np.unique(id_, return_counts=True) + nonlinear_names.append(" ".join( + "%s^%d" % (linear_names[ind], exp) if (exp != 1) else linear_names[ind] + for ind, exp in zip(uniques, counts) + )) + if for_plot: + all_names = [f'${n}$' for n in linear_names] + [f'${n}$' for n in nonlinear_names] + else: + all_names = linear_names + nonlinear_names + if self.constant: + all_names = ['1'] + all_names + return all_names diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py index 01c88b10d..edbeadfdb 100644 --- a/brainpy/_src/dyn/rates/populations.py +++ b/brainpy/_src/dyn/rates/populations.py @@ -6,8 +6,8 @@ from brainpy import math as bm from brainpy._src.context import share -from brainpy._src.dyn.others.noise import OUProcess from brainpy._src.dyn.base import NeuDyn +from brainpy._src.dyn.others.noise import OUProcess from brainpy._src.initialize import (Initializer, Uniform, parameter, @@ -20,1036 +20,1036 @@ from brainpy.types import Shape, ArrayType __all__ = [ - 'FHN', - 'FeedbackFHN', - 'QIF', - 'StuartLandauOscillator', - 'WilsonCowanModel', - 'ThresholdLinearModel', + 'FHN', + 'FeedbackFHN', + 'QIF', + 'StuartLandauOscillator', + 'WilsonCowanModel', + 'ThresholdLinearModel', ] class RateModel(NeuDyn): - pass + pass class FHN(RateModel): - r"""FitzHugh-Nagumo system used in [1]_. - - .. math:: - - \frac{dx}{dt} = -\alpha V^3 + \beta V^2 + \gamma V - w + I_{ext}\\ - \tau \frac{dy}{dt} = (V - \delta - \epsilon w) - - Parameters:: - - size: Shape - The model size. 
- x_ou_mean: Parameter - The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter - The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter - The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter - The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - - - References:: - - .. [1] Kostova, T., Ravindran, R., & Schonbek, M. (2004). FitzHugh–Nagumo - revisited: Types of bifurcations, periodical forcing and stability - regions by a Lyapunov functional. International journal of - bifurcation and chaos, 14(03), 913-925. - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # fhn parameters - alpha: Union[float, ArrayType, Initializer, Callable] = 3.0, - beta: Union[float, ArrayType, Initializer, Callable] = 4.0, - gamma: Union[float, ArrayType, Initializer, Callable] = -1.5, - delta: Union[float, ArrayType, Initializer, Callable] = 0.0, - epsilon: Union[float, ArrayType, Initializer, Callable] = 0.5, - tau: Union[float, ArrayType, Initializer, Callable] = 20.0, - - # noise parameters - x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - - # other parameters - x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), - y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), - method: str = 'exp_auto', - name: str = None, - - # parameter for training - mode: bm.Mode = None, - 
input_var: bool = True, - ): - super().__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode) - - # model parameters - self.alpha = parameter(alpha, self.varshape, allow_none=False) - self.beta = parameter(beta, self.varshape, allow_none=False) - self.gamma = parameter(gamma, self.varshape, allow_none=False) - self.delta = parameter(delta, self.varshape, allow_none=False) - self.epsilon = parameter(epsilon, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - - # noise parameters - self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process - self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process - self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity - self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity - self.x_ou_tau = parameter(x_ou_tau, self.varshape, - allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process - self.y_ou_tau = parameter(y_ou_tau, self.varshape, - allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process - self.input_var = input_var - - # initializers - is_initializer(x_initializer, 'x_initializer') - is_initializer(y_initializer, 'y_initializer') - self._x_initializer = x_initializer - self._y_initializer = y_initializer - - # variables - self.x = variable_(self._x_initializer, self.varshape, self.mode) - self.y = variable_(self._y_initializer, self.varshape, self.mode) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, self.mode) - self.input_y = variable_(bm.zeros, self.varshape, self.mode) - - # noise variables - self.x_ou = self.y_ou = None - if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.varshape, - self.x_ou_mean, - self.x_ou_sigma, - self.x_ou_tau, - method=method) - if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.varshape, - self.y_ou_mean, - self.y_ou_sigma, - self.y_ou_tau, - method=method) - - # integral functions - self.integral = odeint(f=JointEq(self.dx, self.dy), method=method) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) - self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) - if self.input_var: - self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) - self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) - - def dx(self, x, t, y, x_ext): - return - self.alpha * x ** 3 + self.beta * x ** 2 + self.gamma * x - y + x_ext - - def dy(self, y, t, x, y_ext=0.): - return (x - self.delta - self.epsilon * y) / self.tau + y_ext - - def update(self, inp_x=None, inp_y=None): - t = share.load('t') - dt = share.load('dt') - - # input - if self.input_var: - if inp_x is not None: - self.input += inp_x - if self.x_ou is not None: - self.input += self.x_ou() - if inp_y is not None: - self.input_y += inp_y - if self.y_ou is not None: - self.input_y += self.y_ou() - input_x = self.input.value - input_y = self.input_y.value - else: - input_x = inp_x if (inp_x is not None) else 0. - if self.x_ou is not None: input_x += self.x_ou() - input_y = inp_y if (inp_y is not None) else 0. - if self.y_ou is not None: input_y += self.y_ou() - - # integral - x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) - self.x.value = x - self.y.value = y - return x - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + r"""FitzHugh-Nagumo system used in [1]_. + + .. math:: + + \frac{dx}{dt} = -\alpha V^3 + \beta V^2 + \gamma V - w + I_{ext}\\ + \tau \frac{dy}{dt} = (V - \delta - \epsilon w) + + Parameters:: + + size: Shape + The model size. 
+ x_ou_mean: Parameter + The noise mean of the :math:`x` variable, [mV/ms] + y_ou_mean: Parameter + The noise mean of the :math:`y` variable, [mV/ms]. + x_ou_sigma: Parameter + The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. + y_ou_sigma: Parameter + The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. + x_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. + y_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. + + + References:: + + .. [1] Kostova, T., Ravindran, R., & Schonbek, M. (2004). FitzHugh–Nagumo + revisited: Types of bifurcations, periodical forcing and stability + regions by a Lyapunov functional. International journal of + bifurcation and chaos, 14(03), 913-925. + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # fhn parameters + alpha: Union[float, ArrayType, Initializer, Callable] = 3.0, + beta: Union[float, ArrayType, Initializer, Callable] = 4.0, + gamma: Union[float, ArrayType, Initializer, Callable] = -1.5, + delta: Union[float, ArrayType, Initializer, Callable] = 0.0, + epsilon: Union[float, ArrayType, Initializer, Callable] = 0.5, + tau: Union[float, ArrayType, Initializer, Callable] = 20.0, + + # noise parameters + x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + + # other parameters + x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), + y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), + method: str = 'exp_auto', + name: str = None, + + # parameter for training + mode: bm.Mode = None, + 
input_var: bool = True, + ): + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode) + + # model parameters + self.alpha = parameter(alpha, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.gamma = parameter(gamma, self.varshape, allow_none=False) + self.delta = parameter(delta, self.varshape, allow_none=False) + self.epsilon = parameter(epsilon, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + + # noise parameters + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) # mV/ms, OU process + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) # mV/ms/sqrt(ms), noise intensity + self.x_ou_tau = parameter(x_ou_tau, self.varshape, + allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process + self.y_ou_tau = parameter(y_ou_tau, self.varshape, + allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process + self.input_var = input_var + + # initializers + is_initializer(x_initializer, 'x_initializer') + is_initializer(y_initializer, 'y_initializer') + self._x_initializer = x_initializer + self._y_initializer = y_initializer + + # variables + self.x = variable_(self._x_initializer, self.varshape, self.mode) + self.y = variable_(self._y_initializer, self.varshape, self.mode) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, self.mode) + self.input_y = variable_(bm.zeros, self.varshape, self.mode) + + # noise variables + self.x_ou = self.y_ou = None + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) + if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) + + # integral functions + self.integral = odeint(f=JointEq(self.dx, self.dy), method=method) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) + self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) + self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) + + def dx(self, x, t, y, x_ext): + return - self.alpha * x ** 3 + self.beta * x ** 2 + self.gamma * x - y + x_ext + + def dy(self, y, t, x, y_ext=0.): + return (x - self.delta - self.epsilon * y) / self.tau + y_ext + + def update(self, inp_x=None, inp_y=None): + t = share.load('t') + dt = share.load('dt') + + # input + if self.input_var: + if inp_x is not None: + self.input += inp_x + if self.x_ou is not None: + self.input += self.x_ou() + if inp_y is not None: + self.input_y += inp_y + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = inp_x if (inp_x is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = inp_y if (inp_y is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + # integral + x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) + self.x.value = x + self.y.value = y + return x + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class FeedbackFHN(RateModel): - r"""FitzHugh-Nagumo model with recurrent neural feedback. - - The equation of the feedback FitzHugh-Nagumo model [4]_ is given by - - .. 
math:: - - \begin{aligned} - \frac{dx}{dt} &= x(t) - \frac{x^3(t)}{3} - y(t) + \mu[x(t-\mathrm{delay}) - x_0] \\ - \frac{dy}{dt} &= [x(t) + a - b y(t)] / \tau - \end{aligned} - - - **Model Examples** - - >>> import brainpy as bp - >>> fhn = bp.rates.FeedbackFHN(1, delay=10.) - >>> runner = bp.DSRunner(fhn, inputs=('input', 1.), monitors=['x', 'y']) - >>> runner.run(100.) - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y') - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True) - - - **Model Parameters** - - ============= ============== ======== ======================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------ - a 1 \ Positive constant - b 1 \ Positive constant - tau 12.5 ms Membrane time constant. - delay 10 ms Synaptic delay time constant. - V_th 1.8 mV Threshold potential of spike. - v0 -1 mV Resting potential. - mu 1.8 \ The feedback strength. When positive, it is a excitatory feedback; - when negative, it is a inhibitory feedback. - ============= ============== ======== ======================== - - Parameters:: - - x_ou_mean: Parameter - The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter - The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter - The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter - The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - - References:: - - .. [4] Plant, Richard E. (1981). *A FitzHugh Differential-Difference - Equation Modeling Recurrent Neural Feedback. 
SIAM Journal on - Applied Mathematics, 40(1), 150–162.* doi:10.1137/0140012 - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # model parameters - a: Union[float, ArrayType, Initializer, Callable] = 0.7, - b: Union[float, ArrayType, Initializer, Callable] = 0.8, - delay: Union[float, ArrayType, Initializer, Callable] = 10., - tau: Union[float, ArrayType, Initializer, Callable] = 12.5, - mu: Union[float, ArrayType, Initializer, Callable] = 1.6886, - v0: Union[float, ArrayType, Initializer, Callable] = -1, - - # noise parameters - x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - - # other parameters - x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), - y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), - method: str = 'exp_auto', - name: str = None, - - # parameter for training - mode: bm.Mode = None, - input_var: bool = True, - ): - super(FeedbackFHN, self).__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode) - - # parameters - self.a = parameter(a, self.varshape, allow_none=False) - self.b = parameter(b, self.varshape, allow_none=False) - self.delay = parameter(delay, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.mu = parameter(mu, self.varshape, allow_none=False) # feedback strength - self.v0 = parameter(v0, self.varshape, allow_none=False) # resting potential - - # noise parameters - self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) - self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) - self.x_ou_sigma = parameter(x_ou_sigma, 
self.varshape, allow_none=False) - self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) - self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) - self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) - self.input_var = input_var - - # initializers - is_initializer(x_initializer, 'x_initializer') - is_initializer(y_initializer, 'y_initializer') - self._x_initializer = x_initializer - self._y_initializer = y_initializer - - # variables - self.x = variable(x_initializer, self.mode, self.varshape) - self.y = variable(y_initializer, self.mode, self.varshape) - self.x_delay = bm.TimeDelay(self.x, self.delay, dt=bm.dt, interp_method='round') - if self.input_var: - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) - - # noise variables - self.x_ou = self.y_ou = None - if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.varshape, - self.x_ou_mean, - self.x_ou_sigma, - self.x_ou_tau, - method=method) - if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.varshape, - self.y_ou_mean, - self.y_ou_sigma, - self.y_ou_tau, - method=method) - - # integral - self.integral = odeint(method=method, - f=JointEq([self.dx, self.dy]), - state_delays={'x': self.x_delay}) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) - self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) - self.x_delay.reset(self.x, self.delay) - if self.input_var: - self.input = variable(bm.zeros, batch_or_mode, self.varshape) - self.input_y = variable(bm.zeros, batch_or_mode, self.varshape) - - def dx(self, x, t, y, x_ext): - return x - x * x * x / 3 - y + x_ext + self.mu * (self.x_delay(t - self.delay) - self.v0) - - def dy(self, y, t, x, y_ext): - return (x + self.a - self.b * y + y_ext) / self.tau - - def update(self, inp_x=None, inp_y=None): - t = share.load('t') - dt = share.load('dt') - - # input - if self.input_var: - if inp_x is not None: - self.input += inp_x - if self.x_ou is not None: - self.input += self.x_ou() - if inp_y is not None: - self.input_y += inp_y - if self.y_ou is not None: - self.input_y += self.y_ou() - input_x = self.input.value - input_y = self.input_y.value - else: - input_x = inp_x if (inp_x is not None) else 0. - if self.x_ou is not None: input_x += self.x_ou() - input_y = inp_y if (inp_y is not None) else 0. - if self.y_ou is not None: input_y += self.y_ou() - - x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) - self.x.value = x - self.y.value = y - return x - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + r"""FitzHugh-Nagumo model with recurrent neural feedback. + + The equation of the feedback FitzHugh-Nagumo model [4]_ is given by + + .. 
math:: + + \begin{aligned} + \frac{dx}{dt} &= x(t) - \frac{x^3(t)}{3} - y(t) + \mu[x(t-\mathrm{delay}) - x_0] \\ + \frac{dy}{dt} &= [x(t) + a - b y(t)] / \tau + \end{aligned} + + + **Model Examples** + + >>> import brainpy as bp + >>> fhn = bp.rates.FeedbackFHN(1, delay=10.) + >>> runner = bp.DSRunner(fhn, inputs=('input', 1.), monitors=['x', 'y']) + >>> runner.run(100.) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y') + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True) + + + **Model Parameters** + + ============= ============== ======== ======================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------ + a 1 \ Positive constant + b 1 \ Positive constant + tau 12.5 ms Membrane time constant. + delay 10 ms Synaptic delay time constant. + V_th 1.8 mV Threshold potential of spike. + v0 -1 mV Resting potential. + mu 1.8 \ The feedback strength. When positive, it is a excitatory feedback; + when negative, it is a inhibitory feedback. + ============= ============== ======== ======================== + + Parameters:: + + x_ou_mean: Parameter + The noise mean of the :math:`x` variable, [mV/ms] + y_ou_mean: Parameter + The noise mean of the :math:`y` variable, [mV/ms]. + x_ou_sigma: Parameter + The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. + y_ou_sigma: Parameter + The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. + x_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. + y_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. + + References:: + + .. [4] Plant, Richard E. (1981). *A FitzHugh Differential-Difference + Equation Modeling Recurrent Neural Feedback. 
SIAM Journal on + Applied Mathematics, 40(1), 150–162.* doi:10.1137/0140012 + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # model parameters + a: Union[float, ArrayType, Initializer, Callable] = 0.7, + b: Union[float, ArrayType, Initializer, Callable] = 0.8, + delay: Union[float, ArrayType, Initializer, Callable] = 10., + tau: Union[float, ArrayType, Initializer, Callable] = 12.5, + mu: Union[float, ArrayType, Initializer, Callable] = 1.6886, + v0: Union[float, ArrayType, Initializer, Callable] = -1, + + # noise parameters + x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + + # other parameters + x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), + y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), + method: str = 'exp_auto', + name: str = None, + + # parameter for training + mode: bm.Mode = None, + input_var: bool = True, + ): + super(FeedbackFHN, self).__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode) + + # parameters + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.delay = parameter(delay, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.mu = parameter(mu, self.varshape, allow_none=False) # feedback strength + self.v0 = parameter(v0, self.varshape, allow_none=False) # resting potential + + # noise parameters + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, 
self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) + self.input_var = input_var + + # initializers + is_initializer(x_initializer, 'x_initializer') + is_initializer(y_initializer, 'y_initializer') + self._x_initializer = x_initializer + self._y_initializer = y_initializer + + # variables + self.x = variable(x_initializer, self.mode, self.varshape) + self.y = variable(y_initializer, self.mode, self.varshape) + self.x_delay = bm.TimeDelay(self.x, self.delay, dt=bm.dt, interp_method='round') + if self.input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) + + # noise variables + self.x_ou = self.y_ou = None + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) + if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) + + # integral + self.integral = odeint(method=method, + f=JointEq([self.dx, self.dy]), + state_delays={'x': self.x_delay}) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) + self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) + self.x_delay.reset(self.x, self.delay) + if self.input_var: + self.input = variable(bm.zeros, batch_or_mode, self.varshape) + self.input_y = variable(bm.zeros, batch_or_mode, self.varshape) + + def dx(self, x, t, y, x_ext): + return x - x * x * x / 3 - y + x_ext + self.mu * (self.x_delay(t - self.delay) - self.v0) + + def dy(self, y, t, x, y_ext): + return (x + self.a - self.b * y + y_ext) / self.tau + + def update(self, inp_x=None, inp_y=None): + t = share.load('t') + dt = share.load('dt') + + # input + if self.input_var: + if inp_x is not None: + self.input += inp_x + if self.x_ou is not None: + self.input += self.x_ou() + if inp_y is not None: + self.input_y += inp_y + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = inp_x if (inp_x is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = inp_y if (inp_y is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) + self.x.value = x + self.y.value = y + return x + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class QIF(RateModel): - r"""A mean-field model of a quadratic integrate-and-fire neuron population. 
- - **Model Descriptions** - - The QIF population mean-field model, which has been derived from a - population of all-to-all coupled QIF neurons in [5]_. - The model equations are given by: - - .. math:: - - \begin{aligned} - \tau \dot{r} &=\frac{\Delta}{\pi \tau}+2 r v \\ - \tau \dot{v} &=v^{2}+\bar{\eta}+I(t)+J r \tau-(\pi r \tau)^{2} - \end{aligned} - - where :math:`r` is the average firing rate and :math:`v` is the - average membrane potential of the QIF population [5]_. - - This mean-field model is an exact representation of the macroscopic - firing rate and membrane potential dynamics of a spiking neural network - consisting of QIF neurons with Lorentzian distributed background - excitability. While the mean-field derivation is mathematically - only valid for all-to-all coupled populations of infinite size, it - has been shown that there is a close correspondence between the - mean-field model and neural populations with sparse coupling and - population sizes of a few thousand neurons [6]_. - - **Model Parameters** - - ============= ============== ======== ======================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------ - tau 1 ms the population time constant - eta -5. \ the mean of a Lorenzian distribution over the neural excitability in the population - delta 1.0 \ the half-width at half maximum of the Lorenzian distribution over the neural excitability - J 15 \ the strength of the recurrent coupling inside the population - ============= ============== ======== ======================== - - Parameters:: - - x_ou_mean: Parameter - The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter - The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter - The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter - The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. 
- x_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - - - References:: - - .. [5] E. Montbrió, D. Pazó, A. Roxin (2015) Macroscopic description for - networks of spiking neurons. Physical Review X, 5:021028, - https://doi.org/10.1103/PhysRevX.5.021028. - .. [6] R. Gast, H. Schmidt, T.R. Knösche (2020) A Mean-Field Description - of Bursting Dynamics in Spiking Neural Networks with Short-Term - Adaptation. Neural Computation 32.9 (2020): 1615-1634. - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # model parameters - tau: Union[float, ArrayType, Initializer, Callable] = 1., - eta: Union[float, ArrayType, Initializer, Callable] = -5.0, - delta: Union[float, ArrayType, Initializer, Callable] = 1.0, - J: Union[float, ArrayType, Initializer, Callable] = 15., - - # noise parameters - x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - - # other parameters - x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), - y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), - method: str = 'exp_auto', - name: str = None, - input_var: bool = True, - - # parameter for training - mode: bm.Mode = None, - ): - super(QIF, self).__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode) - - # parameters - self.tau = parameter(tau, self.varshape, allow_none=False) - # the mean of a Lorenzian distribution over the neural excitability in the population - self.eta = parameter(eta, self.varshape, 
allow_none=False) - # the half-width at half maximum of the Lorenzian distribution over the neural excitability - self.delta = parameter(delta, self.varshape, allow_none=False) - # the strength of the recurrent coupling inside the population - self.J = parameter(J, self.varshape, allow_none=False) - - # noise parameters - self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) - self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) - self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) - self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) - self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) - self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) - self.input_var = input_var - - # initializers - is_initializer(x_initializer, 'x_initializer') - is_initializer(y_initializer, 'y_initializer') - self._x_initializer = x_initializer - self._y_initializer = y_initializer - - # variables - self.x = variable(x_initializer, self.mode, self.varshape) - self.y = variable(y_initializer, self.mode, self.varshape) - if self.input_var: - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) - - # noise variables - self.x_ou = self.y_ou = None - if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.varshape, - self.x_ou_mean, - self.x_ou_sigma, - self.x_ou_tau, - method=method) - if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.varshape, - self.y_ou_mean, - self.y_ou_sigma, - self.y_ou_tau, - method=method) - - # functions - self.integral = odeint(JointEq([self.dx, self.dy]), method=method) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) - self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) - if self.input_var: - self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) - self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) - - def dy(self, y, t, x, y_ext): - return (self.delta / (bm.pi * self.tau) + 2. * x * y + y_ext) / self.tau - - def dx(self, x, t, y, x_ext): - return (x ** 2 + self.eta + x_ext + self.J * y * self.tau - - (bm.pi * y * self.tau) ** 2) / self.tau - - def update(self, inp_x=None, inp_y=None): - t = share.load('t') - dt = share.load('dt') - - # input - if self.input_var: - if inp_x is not None: - self.input += inp_x - if self.x_ou is not None: - self.input += self.x_ou() - if inp_y is not None: - self.input_y += inp_y - if self.y_ou is not None: - self.input_y += self.y_ou() - input_x = self.input.value - input_y = self.input_y.value - else: - input_x = inp_x if (inp_x is not None) else 0. - if self.x_ou is not None: input_x += self.x_ou() - input_y = inp_y if (inp_y is not None) else 0. - if self.y_ou is not None: input_y += self.y_ou() - - x, y = self.integral(self.x.value, self.y.value, t=t, x_ext=input_x, y_ext=input_y, dt=dt) - self.x.value = x - self.y.value = y - return x - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + r"""A mean-field model of a quadratic integrate-and-fire neuron population. + + **Model Descriptions** + + The QIF population mean-field model, which has been derived from a + population of all-to-all coupled QIF neurons in [5]_. 
+ The model equations are given by: + + .. math:: + + \begin{aligned} + \tau \dot{r} &=\frac{\Delta}{\pi \tau}+2 r v \\ + \tau \dot{v} &=v^{2}+\bar{\eta}+I(t)+J r \tau-(\pi r \tau)^{2} + \end{aligned} + + where :math:`r` is the average firing rate and :math:`v` is the + average membrane potential of the QIF population [5]_. + + This mean-field model is an exact representation of the macroscopic + firing rate and membrane potential dynamics of a spiking neural network + consisting of QIF neurons with Lorentzian distributed background + excitability. While the mean-field derivation is mathematically + only valid for all-to-all coupled populations of infinite size, it + has been shown that there is a close correspondence between the + mean-field model and neural populations with sparse coupling and + population sizes of a few thousand neurons [6]_. + + **Model Parameters** + + ============= ============== ======== ======================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------ + tau 1 ms the population time constant + eta -5. \ the mean of a Lorenzian distribution over the neural excitability in the population + delta 1.0 \ the half-width at half maximum of the Lorenzian distribution over the neural excitability + J 15 \ the strength of the recurrent coupling inside the population + ============= ============== ======== ======================== + + Parameters:: + + x_ou_mean: Parameter + The noise mean of the :math:`x` variable, [mV/ms] + y_ou_mean: Parameter + The noise mean of the :math:`y` variable, [mV/ms]. + x_ou_sigma: Parameter + The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. + y_ou_sigma: Parameter + The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. + x_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. 
+ y_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. + + + References:: + + .. [5] E. Montbrió, D. Pazó, A. Roxin (2015) Macroscopic description for + networks of spiking neurons. Physical Review X, 5:021028, + https://doi.org/10.1103/PhysRevX.5.021028. + .. [6] R. Gast, H. Schmidt, T.R. Knösche (2020) A Mean-Field Description + of Bursting Dynamics in Spiking Neural Networks with Short-Term + Adaptation. Neural Computation 32.9 (2020): 1615-1634. + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # model parameters + tau: Union[float, ArrayType, Initializer, Callable] = 1., + eta: Union[float, ArrayType, Initializer, Callable] = -5.0, + delta: Union[float, ArrayType, Initializer, Callable] = 1.0, + J: Union[float, ArrayType, Initializer, Callable] = 15., + + # noise parameters + x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + + # other parameters + x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), + y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), + method: str = 'exp_auto', + name: str = None, + input_var: bool = True, + + # parameter for training + mode: bm.Mode = None, + ): + super(QIF, self).__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode) + + # parameters + self.tau = parameter(tau, self.varshape, allow_none=False) + # the mean of a Lorenzian distribution over the neural excitability in the population + self.eta = parameter(eta, self.varshape, allow_none=False) + # the half-width at half maximum of the Lorenzian distribution over the neural excitability + 
self.delta = parameter(delta, self.varshape, allow_none=False) + # the strength of the recurrent coupling inside the population + self.J = parameter(J, self.varshape, allow_none=False) + + # noise parameters + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) + self.input_var = input_var + + # initializers + is_initializer(x_initializer, 'x_initializer') + is_initializer(y_initializer, 'y_initializer') + self._x_initializer = x_initializer + self._y_initializer = y_initializer + + # variables + self.x = variable(x_initializer, self.mode, self.varshape) + self.y = variable(y_initializer, self.mode, self.varshape) + if self.input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) + + # noise variables + self.x_ou = self.y_ou = None + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) + if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) + + # functions + self.integral = odeint(JointEq([self.dx, self.dy]), method=method) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) + self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) + self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) + + def dy(self, y, t, x, y_ext): + return (self.delta / (bm.pi * self.tau) + 2. * x * y + y_ext) / self.tau + + def dx(self, x, t, y, x_ext): + return (x ** 2 + self.eta + x_ext + self.J * y * self.tau - + (bm.pi * y * self.tau) ** 2) / self.tau + + def update(self, inp_x=None, inp_y=None): + t = share.load('t') + dt = share.load('dt') + + # input + if self.input_var: + if inp_x is not None: + self.input += inp_x + if self.x_ou is not None: + self.input += self.x_ou() + if inp_y is not None: + self.input_y += inp_y + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = inp_x if (inp_x is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = inp_y if (inp_y is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + x, y = self.integral(self.x.value, self.y.value, t=t, x_ext=input_x, y_ext=input_y, dt=dt) + self.x.value = x + self.y.value = y + return x + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class StuartLandauOscillator(RateModel): - r""" - Stuart-Landau model with Hopf bifurcation. - - .. 
math:: - - \frac{dx}{dt} = (a - x^2 - y^2) * x - w*y + I^x_{ext} \\ - \frac{dy}{dt} = (a - x^2 - y^2) * y + w*x + I^y_{ext} - - Parameters:: - - x_ou_mean: Parameter - The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter - The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter - The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter - The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # model parameters - a: Union[float, ArrayType, Initializer, Callable] = 0.25, - w: Union[float, ArrayType, Initializer, Callable] = 0.2, - - # noise parameters - x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - - # other parameters - x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.5), - y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.5), - method: str = 'exp_auto', - name: str = None, - - # parameter for training - mode: bm.Mode = None, - input_var: bool = True, - ): - super(StuartLandauOscillator, self).__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode) - - # model parameters - self.a = parameter(a, self.varshape, allow_none=False) - self.w = parameter(w, self.varshape, allow_none=False) - - # noise parameters - self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) - self.y_ou_mean = 
parameter(y_ou_mean, self.varshape, allow_none=False) - self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) - self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) - self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) - self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) - self.input_var = input_var - - # initializers - is_initializer(x_initializer, 'x_initializer') - is_initializer(y_initializer, 'y_initializer') - self._x_initializer = x_initializer - self._y_initializer = y_initializer - - # variables - self.x = variable(x_initializer, self.mode, self.varshape) - self.y = variable(y_initializer, self.mode, self.varshape) - if input_var: - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) - - # noise variables - self.x_ou = self.y_ou = None - if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.varshape, - self.x_ou_mean, - self.x_ou_sigma, - self.x_ou_tau, - method=method) - if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.varshape, - self.y_ou_mean, - self.y_ou_sigma, - self.y_ou_tau, - method=method) - - # integral functions - self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) - self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) - if self.input_var: - self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) - self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) - - def dx(self, x, t, y, x_ext, a, w): - return (a - x * x - y * y) * x - w * y + x_ext - - def dy(self, y, t, x, y_ext, a, w): - return (a - x * x - y * y) * y - w * y + y_ext - - def update(self, inp_x=None, inp_y=None): - t = share.load('t') - dt = share.load('dt') - - # input - if self.input_var: - if inp_x is not None: - self.input += inp_x - if self.x_ou is not None: - self.input += self.x_ou() - if inp_y is not None: - self.input_y += inp_y - if self.y_ou is not None: - self.input_y += self.y_ou() - input_x = self.input.value - input_y = self.input_y.value - else: - input_x = inp_x if (inp_x is not None) else 0. - if self.x_ou is not None: input_x += self.x_ou() - input_y = inp_y if (inp_y is not None) else 0. - if self.y_ou is not None: input_y += self.y_ou() - - x, y = self.integral(self.x.value, - self.y.value, - t=t, - x_ext=input_x, - y_ext=input_y, - a=self.a, - w=self.w, - dt=dt) - self.x.value = x - self.y.value = y - return x - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + r""" + Stuart-Landau model with Hopf bifurcation. + + .. 
math:: + + \frac{dx}{dt} = (a - x^2 - y^2) * x - w*y + I^x_{ext} \\ + \frac{dy}{dt} = (a - x^2 - y^2) * y + w*x + I^y_{ext} + + Parameters:: + + x_ou_mean: Parameter + The noise mean of the :math:`x` variable, [mV/ms] + y_ou_mean: Parameter + The noise mean of the :math:`y` variable, [mV/ms]. + x_ou_sigma: Parameter + The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. + y_ou_sigma: Parameter + The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. + x_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. + y_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # model parameters + a: Union[float, ArrayType, Initializer, Callable] = 0.25, + w: Union[float, ArrayType, Initializer, Callable] = 0.2, + + # noise parameters + x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + + # other parameters + x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.5), + y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.5), + method: str = 'exp_auto', + name: str = None, + + # parameter for training + mode: bm.Mode = None, + input_var: bool = True, + ): + super(StuartLandauOscillator, self).__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode) + + # model parameters + self.a = parameter(a, self.varshape, allow_none=False) + self.w = parameter(w, self.varshape, allow_none=False) + + # noise parameters + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = 
parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) + self.input_var = input_var + + # initializers + is_initializer(x_initializer, 'x_initializer') + is_initializer(y_initializer, 'y_initializer') + self._x_initializer = x_initializer + self._y_initializer = y_initializer + + # variables + self.x = variable(x_initializer, self.mode, self.varshape) + self.y = variable(y_initializer, self.mode, self.varshape) + if input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) + + # noise variables + self.x_ou = self.y_ou = None + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) + if bm.any(self.y_ou_mean > 0.) 
or bm.any(self.y_ou_sigma > 0.): + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) + + # integral functions + self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) + self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) + self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) + + def dx(self, x, t, y, x_ext, a, w): + return (a - x * x - y * y) * x - w * y + x_ext + + def dy(self, y, t, x, y_ext, a, w): + return (a - x * x - y * y) * y - w * y + y_ext + + def update(self, inp_x=None, inp_y=None): + t = share.load('t') + dt = share.load('dt') + + # input + if self.input_var: + if inp_x is not None: + self.input += inp_x + if self.x_ou is not None: + self.input += self.x_ou() + if inp_y is not None: + self.input_y += inp_y + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = inp_x if (inp_x is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = inp_y if (inp_y is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + x, y = self.integral(self.x.value, + self.y.value, + t=t, + x_ext=input_x, + y_ext=input_y, + a=self.a, + w=self.w, + dt=dt) + self.x.value = x + self.y.value = y + return x + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class WilsonCowanModel(RateModel): - """Wilson-Cowan population model. - - - Parameters:: - - x_ou_mean: Parameter - The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter - The noise mean of the :math:`y` variable, [mV/ms]. 
- x_ou_sigma: Parameter - The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter - The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter - The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - - - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # Excitatory parameters - E_tau: Union[float, ArrayType, Initializer, Callable] = 1., # excitatory time constant - E_a: Union[float, ArrayType, Initializer, Callable] = 1.2, # excitatory gain - E_theta: Union[float, ArrayType, Initializer, Callable] = 2.8, # excitatory firing threshold - - # Inhibitory parameters - I_tau: Union[float, ArrayType, Initializer, Callable] = 1., # inhibitory time constant - I_a: Union[float, ArrayType, Initializer, Callable] = 1., # inhibitory gain - I_theta: Union[float, ArrayType, Initializer, Callable] = 4.0, # inhibitory firing threshold - - # connection parameters - wEE: Union[float, ArrayType, Initializer, Callable] = 12., # local E-E coupling - wIE: Union[float, ArrayType, Initializer, Callable] = 4., # local E-I coupling - wEI: Union[float, ArrayType, Initializer, Callable] = 13., # local I-E coupling - wII: Union[float, ArrayType, Initializer, Callable] = 11., # local I-I coupling - - # Refractory parameter - r: Union[float, ArrayType, Initializer, Callable] = 1., - - # noise parameters - x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, - y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, - - # state initializer - x_initializer: Union[Initializer, Callable, 
ArrayType] = Uniform(max_val=0.05), - y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(max_val=0.05), - - # other parameters - method: str = 'exp_euler_auto', - name: str = None, - - # parameter for training - mode: bm.Mode = None, - input_var: bool = True, - ): - super(WilsonCowanModel, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode) - - # model parameters - self.E_a = parameter(E_a, self.varshape, allow_none=False) - self.I_a = parameter(I_a, self.varshape, allow_none=False) - self.E_tau = parameter(E_tau, self.varshape, allow_none=False) - self.I_tau = parameter(I_tau, self.varshape, allow_none=False) - self.E_theta = parameter(E_theta, self.varshape, allow_none=False) - self.I_theta = parameter(I_theta, self.varshape, allow_none=False) - self.wEE = parameter(wEE, self.varshape, allow_none=False) - self.wIE = parameter(wIE, self.varshape, allow_none=False) - self.wEI = parameter(wEI, self.varshape, allow_none=False) - self.wII = parameter(wII, self.varshape, allow_none=False) - self.r = parameter(r, self.varshape, allow_none=False) - self.input_var = input_var - - # noise parameters - self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) - self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) - self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) - self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) - self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) - self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) - - # initializers - is_initializer(x_initializer, 'x_initializer') - is_initializer(y_initializer, 'y_initializer') - self._x_initializer = x_initializer - self._y_initializer = y_initializer - - # variables - self.x = variable(x_initializer, self.mode, self.varshape) - self.y = variable(y_initializer, self.mode, self.varshape) - if self.input_var: - self.input = variable(bm.zeros, self.mode, self.varshape) - 
self.input_y = variable(bm.zeros, self.mode, self.varshape) - - # noise variables - self.x_ou = self.y_ou = None - if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): - self.x_ou = OUProcess(self.varshape, - self.x_ou_mean, - self.x_ou_sigma, - self.x_ou_tau, - method=method) - if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): - self.y_ou = OUProcess(self.varshape, - self.y_ou_mean, - self.y_ou_sigma, - self.y_ou_tau, - method=method) - - # functions - self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) - self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) - if self.input_var: - self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) - self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) - - def F(self, x, a, theta): - return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta)) - - def dx(self, x, t, y, x_ext): - xx = self.wEE * x - self.wIE * y + x_ext - return (-x + (1 - self.r * x) * self.F(xx, self.E_a, self.E_theta)) / self.E_tau - - def dy(self, y, t, x, y_ext): - xx = self.wEI * x - self.wII * y + y_ext - return (-y + (1 - self.r * y) * self.F(xx, self.I_a, self.I_theta)) / self.I_tau - - def update(self, inp_x=None, inp_y=None): - t = share.load('t') - dt = share.load('dt') - - # input - if self.input_var: - if inp_x is not None: - self.input += inp_x - if self.x_ou is not None: - self.input += self.x_ou() - if inp_y is not None: - self.input_y += inp_y - if self.y_ou is not None: - self.input_y += self.y_ou() - input_x = self.input.value - input_y = self.input_y.value - else: - input_x = inp_x if (inp_x is not None) else 0. - if self.x_ou is not None: input_x += self.x_ou() - input_y = inp_y if (inp_y is not None) else 0. 
- if self.y_ou is not None: input_y += self.y_ou() - - x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) - self.x.value = x - self.y.value = y - return x - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + """Wilson-Cowan population model. + + + Parameters:: + + x_ou_mean: Parameter + The noise mean of the :math:`x` variable, [mV/ms] + y_ou_mean: Parameter + The noise mean of the :math:`y` variable, [mV/ms]. + x_ou_sigma: Parameter + The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. + y_ou_sigma: Parameter + The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. + x_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. + y_ou_tau: Parameter + The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. + + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # Excitatory parameters + E_tau: Union[float, ArrayType, Initializer, Callable] = 1., # excitatory time constant + E_a: Union[float, ArrayType, Initializer, Callable] = 1.2, # excitatory gain + E_theta: Union[float, ArrayType, Initializer, Callable] = 2.8, # excitatory firing threshold + + # Inhibitory parameters + I_tau: Union[float, ArrayType, Initializer, Callable] = 1., # inhibitory time constant + I_a: Union[float, ArrayType, Initializer, Callable] = 1., # inhibitory gain + I_theta: Union[float, ArrayType, Initializer, Callable] = 4.0, # inhibitory firing threshold + + # connection parameters + wEE: Union[float, ArrayType, Initializer, Callable] = 12., # local E-E coupling + wIE: Union[float, ArrayType, Initializer, Callable] = 4., # local E-I coupling + wEI: Union[float, ArrayType, Initializer, Callable] = 13., # local I-E coupling + wII: Union[float, ArrayType, Initializer, Callable] = 11., # local I-I coupling + + # Refractory parameter + r: 
Union[float, ArrayType, Initializer, Callable] = 1., + + # noise parameters + x_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + x_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + y_ou_mean: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_sigma: Union[float, ArrayType, Initializer, Callable] = 0.0, + y_ou_tau: Union[float, ArrayType, Initializer, Callable] = 5.0, + + # state initializer + x_initializer: Union[Initializer, Callable, ArrayType] = Uniform(max_val=0.05), + y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(max_val=0.05), + + # other parameters + method: str = 'exp_euler_auto', + name: str = None, + + # parameter for training + mode: bm.Mode = None, + input_var: bool = True, + ): + super(WilsonCowanModel, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode) + + # model parameters + self.E_a = parameter(E_a, self.varshape, allow_none=False) + self.I_a = parameter(I_a, self.varshape, allow_none=False) + self.E_tau = parameter(E_tau, self.varshape, allow_none=False) + self.I_tau = parameter(I_tau, self.varshape, allow_none=False) + self.E_theta = parameter(E_theta, self.varshape, allow_none=False) + self.I_theta = parameter(I_theta, self.varshape, allow_none=False) + self.wEE = parameter(wEE, self.varshape, allow_none=False) + self.wIE = parameter(wIE, self.varshape, allow_none=False) + self.wEI = parameter(wEI, self.varshape, allow_none=False) + self.wII = parameter(wII, self.varshape, allow_none=False) + self.r = parameter(r, self.varshape, allow_none=False) + self.input_var = input_var + + # noise parameters + self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) + self.y_ou_mean = parameter(y_ou_mean, self.varshape, allow_none=False) + self.x_ou_sigma = parameter(x_ou_sigma, self.varshape, allow_none=False) + self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) + self.x_ou_tau 
= parameter(x_ou_tau, self.varshape, allow_none=False) + self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) + + # initializers + is_initializer(x_initializer, 'x_initializer') + is_initializer(y_initializer, 'y_initializer') + self._x_initializer = x_initializer + self._y_initializer = y_initializer + + # variables + self.x = variable(x_initializer, self.mode, self.varshape) + self.y = variable(y_initializer, self.mode, self.varshape) + if self.input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) + + # noise variables + self.x_ou = self.y_ou = None + if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): + self.x_ou = OUProcess(self.varshape, + self.x_ou_mean, + self.x_ou_sigma, + self.x_ou_tau, + method=method) + if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): + self.y_ou = OUProcess(self.varshape, + self.y_ou_mean, + self.y_ou_sigma, + self.y_ou_tau, + method=method) + + # functions + self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.x.value = variable(self._x_initializer, batch_or_mode, self.varshape) + self.y.value = variable(self._y_initializer, batch_or_mode, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_or_mode, self.varshape) + self.input_y.value = variable(bm.zeros, batch_or_mode, self.varshape) + + def F(self, x, a, theta): + return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta)) + + def dx(self, x, t, y, x_ext): + xx = self.wEE * x - self.wIE * y + x_ext + return (-x + (1 - self.r * x) * self.F(xx, self.E_a, self.E_theta)) / self.E_tau + + def dy(self, y, t, x, y_ext): + xx = self.wEI * x - self.wII * y + y_ext + return (-y + (1 - self.r * y) * self.F(xx, self.I_a, self.I_theta)) / self.I_tau + + def update(self, inp_x=None, inp_y=None): + t = share.load('t') + dt = share.load('dt') + + # input + if 
self.input_var: + if inp_x is not None: + self.input += inp_x + if self.x_ou is not None: + self.input += self.x_ou() + if inp_y is not None: + self.input_y += inp_y + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = inp_x if (inp_x is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = inp_y if (inp_y is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) + self.x.value = x + self.y.value = y + return x + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class JansenRitModel(RateModel): - pass + pass class KuramotoOscillator(RateModel): - pass + pass class ThetaNeuron(RateModel): - pass + pass class RateQIFWithSFA(RateModel): - pass + pass class VanDerPolOscillator(RateModel): - pass + pass class ThresholdLinearModel(RateModel): - r"""A threshold linear rate model. - - The threshold linear rate model is given by [1]_ - - .. math:: - - \begin{aligned} - &\tau_{E} \frac{d \nu_{E}}{d t}=-\nu_{E}+\beta_{E}\left[I_{E}\right]_{+} \\ - &\tau_{I} \frac{d \nu_{I}}{d t}=-\nu_{I}+\beta_{I}\left[I_{I}\right]_{+} - \end{aligned} - - where :math:`\left[I_{E}\right]_{+}=\max \left(I_{E}, 0\right)`. - :math:`v_E` and :math:`v_I` denote the firing rates of the excitatory and inhibitory - populations respectively, :math:`\tau_E` and :math:`\tau_I` are the corresponding - intrinsic time constants. - - - Reference - --------- - .. [1] Chaudhuri, Rishidev, et al. "A large-scale circuit mechanism - for hierarchical dynamical processing in the primate cortex." - Neuron 88.2 (2015): 419-431. 
- - """ - - def __init__( - self, - size: Shape, - tau_e: Union[float, Callable, Initializer, ArrayType] = 2e-2, - tau_i: Union[float, Callable, Initializer, ArrayType] = 1e-2, - beta_e: Union[float, Callable, Initializer, ArrayType] = .066, - beta_i: Union[float, Callable, Initializer, ArrayType] = .351, - noise_e: Union[float, Callable, Initializer, ArrayType] = 0., - noise_i: Union[float, Callable, Initializer, ArrayType] = 0., - e_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - i_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - seed: int = None, - keep_size: bool = False, - name: str = None, - - # parameter for training - mode: bm.Mode = None, - input_var: bool = True, - ): - super(ThresholdLinearModel, self).__init__(size, - name=name, - keep_size=keep_size, - mode=mode) - - # parameters - self.seed = seed - self.tau_e = parameter(tau_e, self.varshape, False) - self.tau_i = parameter(tau_i, self.varshape, False) - self.beta_e = parameter(beta_e, self.varshape, False) - self.beta_i = parameter(beta_i, self.varshape, False) - self.noise_e = parameter(noise_e, self.varshape, False) - self.noise_i = parameter(noise_i, self.varshape, False) - self._e_initializer = e_initializer - self._i_initializer = i_initializer - self.input_var = input_var - - # variables - self.e = variable(e_initializer, self.mode, self.varshape) # Firing rate of excitatory population - self.i = variable(i_initializer, self.mode, self.varshape) # Firing rate of inhibitory population - if self.input_var: - self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitaory population - self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population - - def reset_state(self, batch_or_mode=None, **kwargs): - self.e.value = variable(self._e_initializer, batch_or_mode, self.varshape) - self.i.value = variable(self._i_initializer, batch_or_mode, self.varshape) - if self.input_var: - self.Ie.value = variable(bm.zeros, 
batch_or_mode, self.varshape) - self.Ii.value = variable(bm.zeros, batch_or_mode, self.varshape) - - def update(self, inp_e=None, inp_i=None): - dt = share.load('dt') - - # input - if self.input_var: - if inp_e is not None: - self.Ie += inp_e - if inp_i is not None: - self.Ii += inp_i - input_e = self.Ie.value - input_i = self.Ii.value - else: - input_e = inp_e if (inp_e is not None) else 0. - input_i = inp_i if (inp_i is not None) else 0. - - de = -self.e + self.beta_e * bm.maximum(input_e, 0.) - with jax.ensure_compile_time_eval(): - has_noise = bm.any(self.noise_e != 0.) - - if has_noise: - de += bm.random.randn(self.varshape) * self.noise_e - de = de / self.tau_e - self.e.value = bm.maximum(self.e + de * dt, 0.) - - di = -self.i + self.beta_i * bm.maximum(input_i, 0.) - with jax.ensure_compile_time_eval(): - has_noise = bm.any(self.noise_i != 0.) - - if has_noise: - di += bm.random.randn(self.varshape) * self.noise_i - di = di / self.tau_i - self.i.value = bm.maximum(self.i + di * dt, 0.) - return self.e.value - - def clear_input(self): - if self.input_var: - self.Ie.value = bm.zeros_like(self.Ie) - self.Ii.value = bm.zeros_like(self.Ii) + r"""A threshold linear rate model. + + The threshold linear rate model is given by [1]_ + + .. math:: + + \begin{aligned} + &\tau_{E} \frac{d \nu_{E}}{d t}=-\nu_{E}+\beta_{E}\left[I_{E}\right]_{+} \\ + &\tau_{I} \frac{d \nu_{I}}{d t}=-\nu_{I}+\beta_{I}\left[I_{I}\right]_{+} + \end{aligned} + + where :math:`\left[I_{E}\right]_{+}=\max \left(I_{E}, 0\right)`. + :math:`v_E` and :math:`v_I` denote the firing rates of the excitatory and inhibitory + populations respectively, :math:`\tau_E` and :math:`\tau_I` are the corresponding + intrinsic time constants. + + + Reference + --------- + .. [1] Chaudhuri, Rishidev, et al. "A large-scale circuit mechanism + for hierarchical dynamical processing in the primate cortex." + Neuron 88.2 (2015): 419-431. 
+ + """ + + def __init__( + self, + size: Shape, + tau_e: Union[float, Callable, Initializer, ArrayType] = 2e-2, + tau_i: Union[float, Callable, Initializer, ArrayType] = 1e-2, + beta_e: Union[float, Callable, Initializer, ArrayType] = .066, + beta_i: Union[float, Callable, Initializer, ArrayType] = .351, + noise_e: Union[float, Callable, Initializer, ArrayType] = 0., + noise_i: Union[float, Callable, Initializer, ArrayType] = 0., + e_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + i_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + seed: int = None, + keep_size: bool = False, + name: str = None, + + # parameter for training + mode: bm.Mode = None, + input_var: bool = True, + ): + super(ThresholdLinearModel, self).__init__(size, + name=name, + keep_size=keep_size, + mode=mode) + + # parameters + self.seed = seed + self.tau_e = parameter(tau_e, self.varshape, False) + self.tau_i = parameter(tau_i, self.varshape, False) + self.beta_e = parameter(beta_e, self.varshape, False) + self.beta_i = parameter(beta_i, self.varshape, False) + self.noise_e = parameter(noise_e, self.varshape, False) + self.noise_i = parameter(noise_i, self.varshape, False) + self._e_initializer = e_initializer + self._i_initializer = i_initializer + self.input_var = input_var + + # variables + self.e = variable(e_initializer, self.mode, self.varshape) # Firing rate of excitatory population + self.i = variable(i_initializer, self.mode, self.varshape) # Firing rate of inhibitory population + if self.input_var: + self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitaory population + self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population + + def reset_state(self, batch_or_mode=None, **kwargs): + self.e.value = variable(self._e_initializer, batch_or_mode, self.varshape) + self.i.value = variable(self._i_initializer, batch_or_mode, self.varshape) + if self.input_var: + self.Ie.value = variable(bm.zeros, 
batch_or_mode, self.varshape) + self.Ii.value = variable(bm.zeros, batch_or_mode, self.varshape) + + def update(self, inp_e=None, inp_i=None): + dt = share.load('dt') + + # input + if self.input_var: + if inp_e is not None: + self.Ie += inp_e + if inp_i is not None: + self.Ii += inp_i + input_e = self.Ie.value + input_i = self.Ii.value + else: + input_e = inp_e if (inp_e is not None) else 0. + input_i = inp_i if (inp_i is not None) else 0. + + de = -self.e + self.beta_e * bm.maximum(input_e, 0.) + with jax.ensure_compile_time_eval(): + has_noise = bm.any(self.noise_e != 0.) + + if has_noise: + de += bm.random.randn(self.varshape) * self.noise_e + de = de / self.tau_e + self.e.value = bm.maximum(self.e + de * dt, 0.) + + di = -self.i + self.beta_i * bm.maximum(input_i, 0.) + with jax.ensure_compile_time_eval(): + has_noise = bm.any(self.noise_i != 0.) + + if has_noise: + di += bm.random.randn(self.varshape) * self.noise_i + di = di / self.tau_i + self.i.value = bm.maximum(self.i + di * dt, 0.) + return self.e.value + + def clear_input(self): + if self.input_var: + self.Ie.value = bm.zeros_like(self.Ie) + self.Ii.value = bm.zeros_like(self.Ii) diff --git a/brainpy/_src/dyn/rates/reservoir.py b/brainpy/_src/dyn/rates/reservoir.py index f6a2247fd..33e43ca59 100644 --- a/brainpy/_src/dyn/rates/reservoir.py +++ b/brainpy/_src/dyn/rates/reservoir.py @@ -5,215 +5,215 @@ import jax.numpy as jnp import brainpy.math as bm -from brainpy._src.initialize import Normal, ZeroInit, Initializer, parameter, variable from brainpy import check +from brainpy._src.dnn.base import Layer +from brainpy._src.initialize import Normal, ZeroInit, Initializer, parameter, variable from brainpy.tools import to_size from brainpy.types import ArrayType -from brainpy._src.dnn.base import Layer __all__ = [ - 'Reservoir', + 'Reservoir', ] class Reservoir(Layer): - r"""Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_. 
- - Parameters:: - - input_shape: int, tuple of int - The input shape. - num_out: int - The number of reservoir nodes. - Win_initializer: Initializer - The initialization method for the feedforward connections. - Wrec_initializer: Initializer - The initialization method for the recurrent connections. - b_initializer: optional, ArrayType, Initializer - The initialization method for the bias. - leaky_rate: float - A float between 0 and 1. - activation : str, callable, optional - Reservoir activation function. - - - If a str, should be a :py:mod:`brainpy.math.activations` function name. - - If a callable, should be an element-wise operator. - activation_type : str - - If "internal" (default), then leaky integration happens on states transformed - by the activation function: - - .. math:: - - r[n+1] = (1 - \alpha) \cdot r[t] + - \alpha \cdot f(W_{ff} \cdot u[n] + W_{fb} \cdot b[n] + W_{rec} \cdot r[t]) - - - If "external", then leaky integration happens on internal states of - each neuron, stored in an ``internal_state`` parameter (:math:`x` in - the equation below). - A neuron internal state is the value of its state before applying - the activation function :math:`f`: + r"""Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_. + + Parameters:: + + input_shape: int, tuple of int + The input shape. + num_out: int + The number of reservoir nodes. + Win_initializer: Initializer + The initialization method for the feedforward connections. + Wrec_initializer: Initializer + The initialization method for the recurrent connections. + b_initializer: optional, ArrayType, Initializer + The initialization method for the bias. + leaky_rate: float + A float between 0 and 1. + activation : str, callable, optional + Reservoir activation function. + + - If a str, should be a :py:mod:`brainpy.math.activations` function name. + - If a callable, should be an element-wise operator. 
+ activation_type : str + - If "internal" (default), then leaky integration happens on states transformed + by the activation function: .. math:: - x[n+1] &= (1 - \alpha) \cdot x[t] + - \alpha \cdot f(W_{ff} \cdot u[n] + W_{rec} \cdot r[t] + W_{fb} \cdot b[n]) \\ - r[n+1] &= f(x[n+1]) - in_connectivity : float, optional - Connectivity of input neurons, i.e. ratio of input neurons connected - to reservoir neurons. Must be in [0, 1], by default 0.1 - rec_connectivity : float, optional - Connectivity of recurrent weights matrix, i.e. ratio of reservoir - neurons connected to other reservoir neurons, including themselves. - Must be in [0, 1], by default 0.1 - comp_type: str - The connectivity type, can be "dense" or "sparse", "jit". - - - ``"dense"`` means the connectivity matrix is a dense matrix. - - ``"sparse"`` means the connectivity matrix is a CSR sparse matrix. - spectral_radius : float, optional - Spectral radius of recurrent weight matrix, by default None. - noise_rec : float, optional - Gain of noise applied to reservoir internal states, by default 0.0 - noise_in : float, optional - Gain of noise applied to feedforward signals, by default 0.0 - noise_type : optional, str, callable - Distribution of noise. Must be a random variable generator - distribution (see :py:class:`brainpy.math.random.RandomState`), - by default "normal". - - References:: - - .. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks." - Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686. 
- """ - - def __init__( - self, - input_shape: Union[int, Tuple[int]], - num_out: int, - leaky_rate: float = 0.3, - activation: Union[str, Callable] = 'tanh', - activation_type: str = 'internal', - Win_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1), - Wrec_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1), - b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(), - in_connectivity: float = 0.1, - rec_connectivity: float = 0.1, - comp_type: str = 'dense', - spectral_radius: Optional[float] = None, - noise_in: float = 0., - noise_rec: float = 0., - noise_type: str = 'normal', - mode: Optional[bm.Mode] = None, - name: Optional[str] = None - ): - super(Reservoir, self).__init__(mode=mode, name=name) - - # parameters - input_shape = to_size(input_shape) - if input_shape[0] is None: - input_shape = input_shape[1:] - self.input_shape = input_shape - self.output_shape = input_shape[:-1] + (num_out,) - self.num_unit = num_out - assert num_out > 0, f'Must be a positive integer, but we got {num_out}' - self.leaky_rate = leaky_rate - check.is_float(leaky_rate, 'leaky_rate', 0., 1.) - self.activation = getattr(bm, activation) if isinstance(activation, str) else activation - check.is_callable(self.activation, allow_none=False) - self.activation_type = activation_type - check.is_string(activation_type, 'activation_type', ['internal', 'external']) - check.is_float(spectral_radius, 'spectral_radius', allow_none=True) - self.spectral_radius = spectral_radius - - # initializations - check.is_initializer(Win_initializer, 'ff_initializer', allow_none=False) - check.is_initializer(Wrec_initializer, 'rec_initializer', allow_none=False) - check.is_initializer(b_initializer, 'bias_initializer', allow_none=True) - self._Win_initializer = Win_initializer - self._Wrec_initializer = Wrec_initializer - self._b_initializer = b_initializer - - # connectivity - check.is_float(in_connectivity, 'ff_connectivity', 0., 1.) 
- check.is_float(rec_connectivity, 'rec_connectivity', 0., 1.) - self.ff_connectivity = in_connectivity - self.rec_connectivity = rec_connectivity - check.is_string(comp_type, 'conn_type', ['dense', 'sparse', 'jit']) - self.comp_type = comp_type - - # noises - check.is_float(noise_in, 'noise_ff') - check.is_float(noise_rec, 'noise_rec') - self.noise_ff = noise_in - self.noise_rec = noise_rec - self.noise_type = noise_type - check.is_string(noise_type, 'noise_type', ['normal', 'uniform']) - - # initialize feedforward weights - weight_shape = (input_shape[-1], self.num_unit) - self.Wff_shape = weight_shape - self.Win = parameter(self._Win_initializer, weight_shape) - if self.ff_connectivity < 1.: - conn_mat = bm.random.random(weight_shape) > self.ff_connectivity - self.Win[conn_mat] = 0. - if self.comp_type == 'sparse' and self.ff_connectivity < 1.: - self.ff_pres, self.ff_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat))) - self.Win = self.Win[self.ff_pres, self.ff_posts] - if isinstance(self.mode, bm.TrainingMode): - self.Win = bm.TrainVar(self.Win) - - # initialize recurrent weights - recurrent_shape = (self.num_unit, self.num_unit) - self.Wrec = parameter(self._Wrec_initializer, recurrent_shape) - if self.rec_connectivity < 1.: - conn_mat = bm.random.random(recurrent_shape) > self.rec_connectivity - self.Wrec[conn_mat] = 0. 
- if self.spectral_radius is not None: - current_sr = max(abs(jnp.linalg.eig(bm.as_jax(self.Wrec))[0])) - self.Wrec *= self.spectral_radius / current_sr - if self.comp_type == 'sparse' and self.rec_connectivity < 1.: - self.rec_pres, self.rec_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat))) - self.Wrec = self.Wrec[self.rec_pres, self.rec_posts] - self.bias = parameter(self._b_initializer, (self.num_unit,)) - if isinstance(self.mode, bm.TrainingMode): - self.Wrec = bm.TrainVar(self.Wrec) - self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) - - # initialize state - self.state = variable(jnp.zeros, self.mode, self.output_shape) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.state.value = variable(jnp.zeros, batch_or_mode, self.output_shape) - - def update(self, x): - """Feedforward output.""" - # inputs - x = bm.as_jax(x) - if self.noise_ff > 0: - x += self.noise_ff * bm.random.uniform(-1, 1, x.shape) - if self.comp_type == 'sparse' and self.ff_connectivity < 1.: - sparse = {'data': self.Win, - 'index': (self.ff_pres, self.ff_posts), - 'shape': self.Wff_shape} - hidden = bm.sparse.seg_matmul(x, sparse) - else: - hidden = x @ self.Win - # recurrent - if self.comp_type == 'sparse' and self.rec_connectivity < 1.: - sparse = {'data': self.Wrec, - 'index': (self.rec_pres, self.rec_posts), - 'shape': (self.num_unit, self.num_unit)} - hidden += bm.sparse.seg_matmul(self.state, sparse) - else: - hidden += self.state @ self.Wrec - if self.activation_type == 'internal': - hidden = self.activation(hidden) - if self.noise_rec > 0.: - hidden += self.noise_rec * bm.random.uniform(-1, -1, self.state.shape) - # new state/output - state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden - if self.activation_type == 'external': - state = self.activation(state) - self.state.value = state - return state + r[n+1] = (1 - \alpha) \cdot r[t] + + \alpha \cdot f(W_{ff} \cdot u[n] + W_{fb} \cdot b[n] + W_{rec} \cdot r[t]) + + - If 
"external", then leaky integration happens on internal states of + each neuron, stored in an ``internal_state`` parameter (:math:`x` in + the equation below). + A neuron internal state is the value of its state before applying + the activation function :math:`f`: + + .. math:: + + x[n+1] &= (1 - \alpha) \cdot x[t] + + \alpha \cdot f(W_{ff} \cdot u[n] + W_{rec} \cdot r[t] + W_{fb} \cdot b[n]) \\ + r[n+1] &= f(x[n+1]) + in_connectivity : float, optional + Connectivity of input neurons, i.e. ratio of input neurons connected + to reservoir neurons. Must be in [0, 1], by default 0.1 + rec_connectivity : float, optional + Connectivity of recurrent weights matrix, i.e. ratio of reservoir + neurons connected to other reservoir neurons, including themselves. + Must be in [0, 1], by default 0.1 + comp_type: str + The connectivity type, can be "dense" or "sparse", "jit". + + - ``"dense"`` means the connectivity matrix is a dense matrix. + - ``"sparse"`` means the connectivity matrix is a CSR sparse matrix. + spectral_radius : float, optional + Spectral radius of recurrent weight matrix, by default None. + noise_rec : float, optional + Gain of noise applied to reservoir internal states, by default 0.0 + noise_in : float, optional + Gain of noise applied to feedforward signals, by default 0.0 + noise_type : optional, str, callable + Distribution of noise. Must be a random variable generator + distribution (see :py:class:`brainpy.math.random.RandomState`), + by default "normal". + + References:: + + .. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks." + Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686. 
+ """ + + def __init__( + self, + input_shape: Union[int, Tuple[int]], + num_out: int, + leaky_rate: float = 0.3, + activation: Union[str, Callable] = 'tanh', + activation_type: str = 'internal', + Win_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1), + Wrec_initializer: Union[Initializer, Callable, ArrayType] = Normal(scale=0.1), + b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(), + in_connectivity: float = 0.1, + rec_connectivity: float = 0.1, + comp_type: str = 'dense', + spectral_radius: Optional[float] = None, + noise_in: float = 0., + noise_rec: float = 0., + noise_type: str = 'normal', + mode: Optional[bm.Mode] = None, + name: Optional[str] = None + ): + super(Reservoir, self).__init__(mode=mode, name=name) + + # parameters + input_shape = to_size(input_shape) + if input_shape[0] is None: + input_shape = input_shape[1:] + self.input_shape = input_shape + self.output_shape = input_shape[:-1] + (num_out,) + self.num_unit = num_out + assert num_out > 0, f'Must be a positive integer, but we got {num_out}' + self.leaky_rate = leaky_rate + check.is_float(leaky_rate, 'leaky_rate', 0., 1.) + self.activation = getattr(bm, activation) if isinstance(activation, str) else activation + check.is_callable(self.activation, allow_none=False) + self.activation_type = activation_type + check.is_string(activation_type, 'activation_type', ['internal', 'external']) + check.is_float(spectral_radius, 'spectral_radius', allow_none=True) + self.spectral_radius = spectral_radius + + # initializations + check.is_initializer(Win_initializer, 'ff_initializer', allow_none=False) + check.is_initializer(Wrec_initializer, 'rec_initializer', allow_none=False) + check.is_initializer(b_initializer, 'bias_initializer', allow_none=True) + self._Win_initializer = Win_initializer + self._Wrec_initializer = Wrec_initializer + self._b_initializer = b_initializer + + # connectivity + check.is_float(in_connectivity, 'ff_connectivity', 0., 1.) 
+ check.is_float(rec_connectivity, 'rec_connectivity', 0., 1.) + self.ff_connectivity = in_connectivity + self.rec_connectivity = rec_connectivity + check.is_string(comp_type, 'conn_type', ['dense', 'sparse', 'jit']) + self.comp_type = comp_type + + # noises + check.is_float(noise_in, 'noise_ff') + check.is_float(noise_rec, 'noise_rec') + self.noise_ff = noise_in + self.noise_rec = noise_rec + self.noise_type = noise_type + check.is_string(noise_type, 'noise_type', ['normal', 'uniform']) + + # initialize feedforward weights + weight_shape = (input_shape[-1], self.num_unit) + self.Wff_shape = weight_shape + self.Win = parameter(self._Win_initializer, weight_shape) + if self.ff_connectivity < 1.: + conn_mat = bm.random.random(weight_shape) > self.ff_connectivity + self.Win[conn_mat] = 0. + if self.comp_type == 'sparse' and self.ff_connectivity < 1.: + self.ff_pres, self.ff_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat))) + self.Win = self.Win[self.ff_pres, self.ff_posts] + if isinstance(self.mode, bm.TrainingMode): + self.Win = bm.TrainVar(self.Win) + + # initialize recurrent weights + recurrent_shape = (self.num_unit, self.num_unit) + self.Wrec = parameter(self._Wrec_initializer, recurrent_shape) + if self.rec_connectivity < 1.: + conn_mat = bm.random.random(recurrent_shape) > self.rec_connectivity + self.Wrec[conn_mat] = 0. 
+ if self.spectral_radius is not None: + current_sr = max(abs(jnp.linalg.eig(bm.as_jax(self.Wrec))[0])) + self.Wrec *= self.spectral_radius / current_sr + if self.comp_type == 'sparse' and self.rec_connectivity < 1.: + self.rec_pres, self.rec_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat))) + self.Wrec = self.Wrec[self.rec_pres, self.rec_posts] + self.bias = parameter(self._b_initializer, (self.num_unit,)) + if isinstance(self.mode, bm.TrainingMode): + self.Wrec = bm.TrainVar(self.Wrec) + self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) + + # initialize state + self.state = variable(jnp.zeros, self.mode, self.output_shape) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.state.value = variable(jnp.zeros, batch_or_mode, self.output_shape) + + def update(self, x): + """Feedforward output.""" + # inputs + x = bm.as_jax(x) + if self.noise_ff > 0: + x += self.noise_ff * bm.random.uniform(-1, 1, x.shape) + if self.comp_type == 'sparse' and self.ff_connectivity < 1.: + sparse = {'data': self.Win, + 'index': (self.ff_pres, self.ff_posts), + 'shape': self.Wff_shape} + hidden = bm.sparse.seg_matmul(x, sparse) + else: + hidden = x @ self.Win + # recurrent + if self.comp_type == 'sparse' and self.rec_connectivity < 1.: + sparse = {'data': self.Wrec, + 'index': (self.rec_pres, self.rec_posts), + 'shape': (self.num_unit, self.num_unit)} + hidden += bm.sparse.seg_matmul(self.state, sparse) + else: + hidden += self.state @ self.Wrec + if self.activation_type == 'internal': + hidden = self.activation(hidden) + if self.noise_rec > 0.: + hidden += self.noise_rec * bm.random.uniform(-1, -1, self.state.shape) + # new state/output + state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden + if self.activation_type == 'external': + state = self.activation(state) + self.state.value = state + return state diff --git a/brainpy/_src/dyn/rates/rnncells.py b/brainpy/_src/dyn/rates/rnncells.py index dfdd55576..54d95c490 100644 --- 
a/brainpy/_src/dyn/rates/rnncells.py +++ b/brainpy/_src/dyn/rates/rnncells.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- -import warnings from typing import Union, Callable, Sequence, Optional, Tuple import jax.numpy as jnp import brainpy.math as bm -from brainpy.math import activations from brainpy._src.dnn.base import Layer +from brainpy._src.dnn.conv import _GeneralConv from brainpy.check import (is_integer, is_initializer) from brainpy.initialize import (XavierNormal, @@ -17,701 +16,699 @@ variable, variable_, Initializer) +from brainpy.math import activations from brainpy.types import ArrayType -from brainpy._src.dnn.conv import _GeneralConv - __all__ = [ - 'RNNCell', 'GRUCell', 'LSTMCell', - 'Conv1dLSTMCell', 'Conv2dLSTMCell', 'Conv3dLSTMCell', + 'RNNCell', 'GRUCell', 'LSTMCell', + 'Conv1dLSTMCell', 'Conv2dLSTMCell', 'Conv3dLSTMCell', ] class RNNCell(Layer): - r"""Basic fully-connected RNN core. - - Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the - core computes - - .. math:: - - h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h) - - The output is equal to the new state, :math:`h_t`. - - - Parameters:: - - num_in: int - The dimension of the input vector - num_out: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. 
- - """ - - def __init__( - self, - num_in: int, - num_out: int, - state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - Wi_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), - Wh_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), - b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - activation: str = 'relu', - mode: bm.Mode = None, - train_state: bool = False, - name: str = None, - ): - super(RNNCell, self).__init__(mode=mode, name=name) - - # parameters - self._state_initializer = state_initializer - is_initializer(state_initializer, 'state_initializer', allow_none=False) - self.num_out = num_out - is_integer(num_out, 'num_out', min_bound=1, allow_none=False) - self.train_state = train_state - - # parameters - self.num_in = num_in - is_integer(num_in, 'num_in', min_bound=1, allow_none=False) - - # initializers - self._Wi_initializer = Wi_initializer - self._Wh_initializer = Wh_initializer - self._b_initializer = b_initializer - is_initializer(Wi_initializer, 'wi_initializer', allow_none=False) - is_initializer(Wh_initializer, 'wh_initializer', allow_none=False) - is_initializer(b_initializer, 'b_initializer', allow_none=True) - - # activation function - self.activation = getattr(activations, activation) - - # weights - self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out)) - self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out)) - self.b = parameter(self._b_initializer, (self.num_out,)) - if isinstance(self.mode, bm.TrainingMode): - self.Wi = bm.TrainVar(self.Wi) - self.Wh = bm.TrainVar(self.Wh) - self.b = None if (self.b is None) else bm.TrainVar(self.b) - - # state - self.state = variable(jnp.zeros, self.mode, self.num_out) - if train_state and isinstance(self.mode, bm.TrainingMode): - self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False)) - self.state[:] = self.state2train - - def reset_state(self, 
batch_or_mode=None, **kwargs): - self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out,), allow_none=False) - if self.train_state: - self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) - self.state[:] = self.state2train - - def update(self, x): - h = x @ self.Wi - h += self.state.value @ self.Wh - if self.b is not None: - h += self.b - self.state.value = self.activation(h) - return self.state.value + r"""Basic fully-connected RNN core. + + Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the + core computes + + .. math:: + + h_t = \mathrm{ReLU}(w_i x_t + b_i + w_h h_{t-1} + b_h) + + The output is equal to the new state, :math:`h_t`. + + + Parameters:: + + num_in: int + The dimension of the input vector + num_out: int + The number of hidden unit in the node. + state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The state initializer. + Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The input weight initializer. + Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The hidden weight initializer. + b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray + The bias weight initializer. + activation: str, callable + The activation function. It can be a string or a callable function. + See ``brainpy.math.activations`` for more details. 
+ + """ + + def __init__( + self, + num_in: int, + num_out: int, + state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + Wi_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), + Wh_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), + b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + activation: str = 'relu', + mode: bm.Mode = None, + train_state: bool = False, + name: str = None, + ): + super(RNNCell, self).__init__(mode=mode, name=name) + + # parameters + self._state_initializer = state_initializer + is_initializer(state_initializer, 'state_initializer', allow_none=False) + self.num_out = num_out + is_integer(num_out, 'num_out', min_bound=1, allow_none=False) + self.train_state = train_state + + # parameters + self.num_in = num_in + is_integer(num_in, 'num_in', min_bound=1, allow_none=False) + + # initializers + self._Wi_initializer = Wi_initializer + self._Wh_initializer = Wh_initializer + self._b_initializer = b_initializer + is_initializer(Wi_initializer, 'wi_initializer', allow_none=False) + is_initializer(Wh_initializer, 'wh_initializer', allow_none=False) + is_initializer(b_initializer, 'b_initializer', allow_none=True) + + # activation function + self.activation = getattr(activations, activation) + + # weights + self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out)) + self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out)) + self.b = parameter(self._b_initializer, (self.num_out,)) + if isinstance(self.mode, bm.TrainingMode): + self.Wi = bm.TrainVar(self.Wi) + self.Wh = bm.TrainVar(self.Wh) + self.b = None if (self.b is None) else bm.TrainVar(self.b) + + # state + self.state = variable(jnp.zeros, self.mode, self.num_out) + if train_state and isinstance(self.mode, bm.TrainingMode): + self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False)) + self.state[:] = self.state2train + + def reset_state(self, 
batch_or_mode=None, **kwargs): + self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out,), allow_none=False) + if self.train_state: + self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) + self.state[:] = self.state2train + + def update(self, x): + h = x @ self.Wi + h += self.state.value @ self.Wh + if self.b is not None: + h += self.b + self.state.value = self.activation(h) + return self.state.value class GRUCell(Layer): - r"""Gated Recurrent Unit. - - The implementation is based on (Chung, et al., 2014) [1]_ with biases. - - Given :math:`x_t` and the previous state :math:`h_{t-1}` the core computes - - .. math:: - - \begin{array}{ll} - z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ - r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ - a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ - h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t - \end{array} - - where :math:`z_t` and :math:`r_t` are reset and update gates. - - The output is equal to the new hidden state, :math:`h_t`. - - Warning: Backwards compatibility of GRU weights is currently unsupported. - - Parameters:: - - num_in: int - The dimension of the input vector - num_out: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. - - References:: - - .. [1] Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. 
Empirical - evaluation of gated recurrent neural networks on sequence modeling. - arXiv preprint arXiv:1412.3555. - """ - - def __init__( - self, - num_in: int, - num_out: int, - Wi_initializer: Union[ArrayType, Callable, Initializer] = Orthogonal(), - Wh_initializer: Union[ArrayType, Callable, Initializer] = Orthogonal(), - b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - activation: str = 'tanh', - mode: bm.Mode = None, - train_state: bool = False, - name: str = None, - ): - super(GRUCell, self).__init__(mode=mode, name=name) - - # parameters - self._state_initializer = state_initializer - is_initializer(state_initializer, 'state_initializer', allow_none=False) - self.num_out = num_out - is_integer(num_out, 'num_out', min_bound=1, allow_none=False) - self.train_state = train_state - self.num_in = num_in - is_integer(num_in, 'num_in', min_bound=1, allow_none=False) - - # initializers - self._Wi_initializer = Wi_initializer - self._Wh_initializer = Wh_initializer - self._b_initializer = b_initializer - is_initializer(Wi_initializer, 'Wi_initializer', allow_none=False) - is_initializer(Wh_initializer, 'Wh_initializer', allow_none=False) - is_initializer(b_initializer, 'b_initializer', allow_none=True) - - # activation function - self.activation = getattr(activations, activation) - - # weights - self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 3), allow_none=False) - self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 3), allow_none=False) - self.b = parameter(self._b_initializer, (self.num_out * 3,)) - if isinstance(self.mode, bm.TrainingMode): - self.Wi = bm.TrainVar(self.Wi) - self.Wh = bm.TrainVar(self.Wh) - self.b = bm.TrainVar(self.b) if (self.b is not None) else None - - # state - self.state = variable(jnp.zeros, self.mode, self.num_out) - if train_state and isinstance(self.mode, bm.TrainingMode): - self.state2train = 
bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False)) - self.state[:] = self.state2train - - def reset_state(self, batch_or_mode=None, **kwargs): - self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out), allow_none=False) - if self.train_state: - self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) - self.state[:] = self.state2train - - def update(self, x): - gates_x = jnp.matmul(x, bm.as_jax(self.Wi)) - zr_x, a_x = jnp.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1) - w_h_z, w_h_a = jnp.split(bm.as_jax(self.Wh), indices_or_sections=[2 * self.num_out], axis=-1) - zr_h = jnp.matmul(bm.as_jax(self.state), w_h_z) - zr = zr_x + zr_h - has_bias = (self.b is not None) - if has_bias: - b_z, b_a = jnp.split(bm.as_jax(self.b), indices_or_sections=[2 * self.num_out], axis=0) - zr += jnp.broadcast_to(b_z, zr_h.shape) - z, r = jnp.split(bm.sigmoid(zr), indices_or_sections=2, axis=-1) - a_h = jnp.matmul(r * self.state, w_h_a) - if has_bias: - a = self.activation(a_x + a_h + jnp.broadcast_to(b_a, a_h.shape)) - else: - a = self.activation(a_x + a_h) - next_state = (1 - z) * self.state + z * a - self.state.value = next_state - return self.state.value + r"""Gated Recurrent Unit. + + The implementation is based on (Chung, et al., 2014) [1]_ with biases. + + Given :math:`x_t` and the previous state :math:`h_{t-1}` the core computes + + .. math:: + + \begin{array}{ll} + z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ + r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ + a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t \bigodot h_{t-1}) + b_a) \\ + h_t &= (1 - z_t) \bigodot h_{t-1} + z_t \bigodot a_t + \end{array} + + where :math:`z_t` and :math:`r_t` are reset and update gates. + + The output is equal to the new hidden state, :math:`h_t`. + + Warning: Backwards compatibility of GRU weights is currently unsupported. 
+ + Parameters:: + + num_in: int + The dimension of the input vector + num_out: int + The number of hidden unit in the node. + state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The state initializer. + Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The input weight initializer. + Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The hidden weight initializer. + b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray + The bias weight initializer. + activation: str, callable + The activation function. It can be a string or a callable function. + See ``brainpy.math.activations`` for more details. + + References:: + + .. [1] Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. Empirical + evaluation of gated recurrent neural networks on sequence modeling. + arXiv preprint arXiv:1412.3555. + """ + + def __init__( + self, + num_in: int, + num_out: int, + Wi_initializer: Union[ArrayType, Callable, Initializer] = Orthogonal(), + Wh_initializer: Union[ArrayType, Callable, Initializer] = Orthogonal(), + b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + activation: str = 'tanh', + mode: bm.Mode = None, + train_state: bool = False, + name: str = None, + ): + super(GRUCell, self).__init__(mode=mode, name=name) + + # parameters + self._state_initializer = state_initializer + is_initializer(state_initializer, 'state_initializer', allow_none=False) + self.num_out = num_out + is_integer(num_out, 'num_out', min_bound=1, allow_none=False) + self.train_state = train_state + self.num_in = num_in + is_integer(num_in, 'num_in', min_bound=1, allow_none=False) + + # initializers + self._Wi_initializer = Wi_initializer + self._Wh_initializer = Wh_initializer + self._b_initializer = b_initializer + is_initializer(Wi_initializer, 'Wi_initializer', allow_none=False) + is_initializer(Wh_initializer, 
'Wh_initializer', allow_none=False) + is_initializer(b_initializer, 'b_initializer', allow_none=True) + + # activation function + self.activation = getattr(activations, activation) + + # weights + self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 3), allow_none=False) + self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 3), allow_none=False) + self.b = parameter(self._b_initializer, (self.num_out * 3,)) + if isinstance(self.mode, bm.TrainingMode): + self.Wi = bm.TrainVar(self.Wi) + self.Wh = bm.TrainVar(self.Wh) + self.b = bm.TrainVar(self.b) if (self.b is not None) else None + + # state + self.state = variable(jnp.zeros, self.mode, self.num_out) + if train_state and isinstance(self.mode, bm.TrainingMode): + self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out,), allow_none=False)) + self.state[:] = self.state2train + + def reset_state(self, batch_or_mode=None, **kwargs): + self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out), allow_none=False) + if self.train_state: + self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) + self.state[:] = self.state2train + + def update(self, x): + gates_x = jnp.matmul(x, bm.as_jax(self.Wi)) + zr_x, a_x = jnp.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1) + w_h_z, w_h_a = jnp.split(bm.as_jax(self.Wh), indices_or_sections=[2 * self.num_out], axis=-1) + zr_h = jnp.matmul(bm.as_jax(self.state), w_h_z) + zr = zr_x + zr_h + has_bias = (self.b is not None) + if has_bias: + b_z, b_a = jnp.split(bm.as_jax(self.b), indices_or_sections=[2 * self.num_out], axis=0) + zr += jnp.broadcast_to(b_z, zr_h.shape) + z, r = jnp.split(bm.sigmoid(zr), indices_or_sections=2, axis=-1) + a_h = jnp.matmul(r * self.state, w_h_a) + if has_bias: + a = self.activation(a_x + a_h + jnp.broadcast_to(b_a, a_h.shape)) + else: + a = self.activation(a_x + a_h) + next_state = (1 - z) * self.state + z * a + self.state.value = 
next_state + return self.state.value class LSTMCell(Layer): - r"""Long short-term memory (LSTM) RNN core. - - The implementation is based on (zaremba, et al., 2014) [1]_. Given - :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core - computes - - .. math:: - - \begin{array}{ll} - i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ - f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ - g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ - o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ - c_t = f_t c_{t-1} + i_t g_t \\ - h_t = o_t \tanh(c_t) - \end{array} - - where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and - output gate activations, and :math:`g_t` is a vector of cell updates. - - The output is equal to the new hidden, :math:`h_t`. - - Notes:: - - Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0 - to :math:`b_f` after initialization in order to reduce the scale of forgetting in - the beginning of the training. - - - Parameters:: - - num_in: int - The dimension of the input vector - num_out: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. - - References:: - - .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural - network regularization." arXiv preprint arXiv:1409.2329 (2014). - .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical - exploration of recurrent network architectures." 
In International conference - on machine learning, pp. 2342-2350. PMLR, 2015. - """ - - def __init__( - self, - num_in: int, - num_out: int, - Wi_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), - Wh_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), - b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - activation: str = 'tanh', - mode: bm.Mode = None, - train_state: bool = False, - name: str = None, - ): - super(LSTMCell, self).__init__(mode=mode, name=name) - - # parameters - self._state_initializer = state_initializer - is_initializer(state_initializer, 'state_initializer', allow_none=False) - self.num_out = num_out - is_integer(num_out, 'num_out', min_bound=1, allow_none=False) - self.train_state = train_state - self.num_in = num_in - is_integer(num_in, 'num_in', min_bound=1, allow_none=False) - - # initializers - self._state_initializer = state_initializer - self._Wi_initializer = Wi_initializer - self._Wh_initializer = Wh_initializer - self._b_initializer = b_initializer - is_initializer(Wi_initializer, 'wi_initializer', allow_none=False) - is_initializer(Wh_initializer, 'wh_initializer', allow_none=False) - is_initializer(b_initializer, 'b_initializer', allow_none=True) - is_initializer(state_initializer, 'state_initializer', allow_none=False) - - # activation function - self.activation = getattr(activations, activation) - - # weights - self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 4)) - self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 4)) - self.b = parameter(self._b_initializer, (self.num_out * 4,)) - if isinstance(self.mode, bm.TrainingMode): - self.Wi = bm.TrainVar(self.Wi) - self.Wh = bm.TrainVar(self.Wh) - self.b = None if (self.b is None) else bm.TrainVar(self.b) - - # state - self.state = variable(jnp.zeros, self.mode, self.num_out * 2) - if train_state and 
isinstance(self.mode, bm.TrainingMode): - self.state2train = bm.TrainVar(parameter(state_initializer, (self.num_out * 2,), allow_none=False)) - self.state[:] = self.state2train - - def reset_state(self, batch_or_mode=None, **kwargs): - self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out * 2), allow_none=False) - if self.train_state: - self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False) - self.state[:] = self.state2train - - def update(self, x): - h, c = bm.split(self.state.value, 2, axis=-1) - gated = x @ self.Wi - if self.b is not None: - gated += self.b - gated += h @ self.Wh - i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1) - c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * self.activation(g) - h = bm.sigmoid(o) * self.activation(c) - self.state.value = bm.concatenate([h, c], axis=-1) - return h - - @property - def h(self): - """Hidden state.""" - return jnp.split(self.state.value, 2, axis=-1)[0] - - @h.setter - def h(self, value): - if self.state is None: - raise ValueError('Cannot set "h" state. Because the state is not initialized.') - self.state[:self.state.shape[0] // 2, :] = value - - @property - def c(self): - """Memory cell.""" - return jnp.split(self.state.value, 2, axis=-1)[1] - - @c.setter - def c(self, value): - if self.state is None: - raise ValueError('Cannot set "c" state. Because the state is not initialized.') - self.state[self.state.shape[0] // 2:, :] = value + r"""Long short-term memory (LSTM) RNN core. + + The implementation is based on (zaremba, et al., 2014) [1]_. Given + :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core + computes + + .. 
math:: + + \begin{array}{ll} + i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ + f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ + g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ + o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ + c_t = f_t c_{t-1} + i_t g_t \\ + h_t = o_t \tanh(c_t) + \end{array} + + where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and + output gate activations, and :math:`g_t` is a vector of cell updates. + + The output is equal to the new hidden, :math:`h_t`. + + Notes:: + + Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0 + to :math:`b_f` after initialization in order to reduce the scale of forgetting in + the beginning of the training. + + + Parameters:: + + num_in: int + The dimension of the input vector + num_out: int + The number of hidden unit in the node. + state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The state initializer. + Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The input weight initializer. + Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray + The hidden weight initializer. + b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray + The bias weight initializer. + activation: str, callable + The activation function. It can be a string or a callable function. + See ``brainpy.math.activations`` for more details. + + References:: + + .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural + network regularization." arXiv preprint arXiv:1409.2329 (2014). + .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical + exploration of recurrent network architectures." In International conference + on machine learning, pp. 2342-2350. PMLR, 2015. 
+ """ + def __init__( + self, + num_in: int, + num_out: int, + Wi_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), + Wh_initializer: Union[ArrayType, Callable, Initializer] = XavierNormal(), + b_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + activation: str = 'tanh', + mode: bm.Mode = None, + train_state: bool = False, + name: str = None, + ): + super(LSTMCell, self).__init__(mode=mode, name=name) + + # parameters + self._state_initializer = state_initializer + is_initializer(state_initializer, 'state_initializer', allow_none=False) + self.num_out = num_out + is_integer(num_out, 'num_out', min_bound=1, allow_none=False) + self.train_state = train_state + self.num_in = num_in + is_integer(num_in, 'num_in', min_bound=1, allow_none=False) + + # initializers + self._state_initializer = state_initializer + self._Wi_initializer = Wi_initializer + self._Wh_initializer = Wh_initializer + self._b_initializer = b_initializer + is_initializer(Wi_initializer, 'wi_initializer', allow_none=False) + is_initializer(Wh_initializer, 'wh_initializer', allow_none=False) + is_initializer(b_initializer, 'b_initializer', allow_none=True) + is_initializer(state_initializer, 'state_initializer', allow_none=False) + + # activation function + self.activation = getattr(activations, activation) + + # weights + self.Wi = parameter(self._Wi_initializer, (num_in, self.num_out * 4)) + self.Wh = parameter(self._Wh_initializer, (self.num_out, self.num_out * 4)) + self.b = parameter(self._b_initializer, (self.num_out * 4,)) + if isinstance(self.mode, bm.TrainingMode): + self.Wi = bm.TrainVar(self.Wi) + self.Wh = bm.TrainVar(self.Wh) + self.b = None if (self.b is None) else bm.TrainVar(self.b) + + # state + self.state = variable(jnp.zeros, self.mode, self.num_out * 2) + if train_state and isinstance(self.mode, bm.TrainingMode): + self.state2train = 
bm.TrainVar(parameter(state_initializer, (self.num_out * 2,), allow_none=False)) + self.state[:] = self.state2train + + def reset_state(self, batch_or_mode=None, **kwargs): + self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out * 2), allow_none=False) + if self.train_state: + self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False) + self.state[:] = self.state2train + + def update(self, x): + h, c = bm.split(self.state.value, 2, axis=-1) + gated = x @ self.Wi + if self.b is not None: + gated += self.b + gated += h @ self.Wh + i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1) + c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * self.activation(g) + h = bm.sigmoid(o) * self.activation(c) + self.state.value = bm.concatenate([h, c], axis=-1) + return h + + @property + def h(self): + """Hidden state.""" + return jnp.split(self.state.value, 2, axis=-1)[0] + + @h.setter + def h(self, value): + if self.state is None: + raise ValueError('Cannot set "h" state. Because the state is not initialized.') + self.state[:self.state.shape[0] // 2, :] = value + + @property + def c(self): + """Memory cell.""" + return jnp.split(self.state.value, 2, axis=-1)[1] + + @c.setter + def c(self, value): + if self.state is None: + raise ValueError('Cannot set "c" state. Because the state is not initialized.') + self.state[self.state.shape[0] // 2:, :] = value class _ConvNDLSTMCell(Layer): - r"""``num_spatial_dims``-D convolutional LSTM. - - The implementation is based on :cite:`xingjian2015convolutional`. - Given :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` - the core computes - - .. 
math:: - - \begin{array}{ll} - i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ - f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ - g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ - o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ - c_t = f_t c_{t-1} + i_t g_t \\ - h_t = o_t \tanh(c_t) - \end{array} - - where :math:`*` denotes the convolution operator; :math:`i_t`, - :math:`f_t`, :math:`o_t` are input, forget and output gate activations, - and :math:`g_t` is a vector of cell updates. - - The output is equal to the new hidden state, :math:`h_t`. - - Notes: - Forget gate initialization: - Following :cite:`jozefowicz2015empirical` we add 1.0 to :math:`b_f` - after initialization in order to reduce the scale of forgetting in - the beginning of the training. - """ - - def __init__( - self, - input_shape: Tuple[int, ...], - - # convolution parameters - num_spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - - # recurrent parameters - state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - train_state: bool = False, - - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - """Constructs a convolutional LSTM. - - Args: - num_spatial_dims: Number of spatial dimensions of the input. - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length ``num_spatial_dims``), - or an int. ``kernel_shape`` will be expanded to define a kernel size in - all dimensions. 
- name: Name of the module. + r"""``num_spatial_dims``-D convolutional LSTM. + + The implementation is based on :cite:`xingjian2015convolutional`. + Given :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` + the core computes + + .. math:: + + \begin{array}{ll} + i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ + f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ + g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ + o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ + c_t = f_t c_{t-1} + i_t g_t \\ + h_t = o_t \tanh(c_t) + \end{array} + + where :math:`*` denotes the convolution operator; :math:`i_t`, + :math:`f_t`, :math:`o_t` are input, forget and output gate activations, + and :math:`g_t` is a vector of cell updates. + + The output is equal to the new hidden state, :math:`h_t`. + + Notes: + Forget gate initialization: + Following :cite:`jozefowicz2015empirical` we add 1.0 to :math:`b_f` + after initialization in order to reduce the scale of forgetting in + the beginning of the training. 
""" - super().__init__(name=name, mode=mode) - - # parameters - self._state_initializer = state_initializer - is_initializer(state_initializer, 'state_initializer', allow_none=False) - self.train_state = train_state - self.num_spatial_dims = num_spatial_dims - self.in_channels = in_channels - self.out_channels = out_channels - self.input_shape = tuple(input_shape) - self.input_to_hidden = _GeneralConv(num_spatial_dims=num_spatial_dims, - in_channels=in_channels, - out_channels=out_channels * 4, - kernel_size=kernel_size, - stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - mode=mode) - self.hidden_to_hidden = _GeneralConv(num_spatial_dims=num_spatial_dims, - in_channels=out_channels, - out_channels=out_channels * 4, - kernel_size=kernel_size, - stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - mode=mode) - self.reset_state() - - def reset_state(self, batch_or_mode: int = 1, **kwargs): - if self.mode.is_a(bm.NonBatchingMode): - shape = self.input_shape + (self.out_channels,) - self.h = variable_(self._state_initializer, shape) - self.c = variable_(self._state_initializer, shape) - else: - shape = self.input_shape + (self.out_channels,) - self.h = variable_(self._state_initializer, shape, batch_or_mode) - self.c = variable_(self._state_initializer, shape, batch_or_mode) - self.c = variable_(self.c, batch_axis=0) - if self.mode.is_a(bm.TrainingMode) and self.train_state: - h_to_train = parameter(self._state_initializer, shape, allow_none=False) - c_to_train = parameter(self._state_initializer, shape, allow_none=False) - self.h_to_train = bm.TrainVar(h_to_train) - self.c_to_train = bm.TrainVar(c_to_train) - self.h[:] = self.h_to_train - self.c[:] = self.c_to_train - - def update(self, x): - gates = self.input_to_hidden(x) + 
self.hidden_to_hidden(self.h) - i, g, f, o = bm.split(gates, indices_or_sections=4, axis=-1) - f = bm.sigmoid(f + 1) - c = f * self.c + bm.sigmoid(i) * bm.tanh(g) - h = bm.sigmoid(o) * bm.tanh(c) - self.h.value = h - self.c.value = c - return h + + def __init__( + self, + input_shape: Tuple[int, ...], + + # convolution parameters + num_spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + + # recurrent parameters + state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + train_state: bool = False, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + """Constructs a convolutional LSTM. + + Args: + num_spatial_dims: Number of spatial dimensions of the input. + input_shape: Shape of the inputs excluding batch size. + out_channels: Number of output channels. + kernel_size: Sequence of kernel sizes (of length ``num_spatial_dims``), + or an int. ``kernel_shape`` will be expanded to define a kernel size in + all dimensions. + name: Name of the module. 
+ """ + super().__init__(name=name, mode=mode) + + # parameters + self._state_initializer = state_initializer + is_initializer(state_initializer, 'state_initializer', allow_none=False) + self.train_state = train_state + self.num_spatial_dims = num_spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.input_shape = tuple(input_shape) + self.input_to_hidden = _GeneralConv(num_spatial_dims=num_spatial_dims, + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + mode=mode) + self.hidden_to_hidden = _GeneralConv(num_spatial_dims=num_spatial_dims, + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + mode=mode) + self.reset_state() + + def reset_state(self, batch_or_mode: int = 1, **kwargs): + if self.mode.is_a(bm.NonBatchingMode): + shape = self.input_shape + (self.out_channels,) + self.h = variable_(self._state_initializer, shape) + self.c = variable_(self._state_initializer, shape) + else: + shape = self.input_shape + (self.out_channels,) + self.h = variable_(self._state_initializer, shape, batch_or_mode) + self.c = variable_(self._state_initializer, shape, batch_or_mode) + self.c = variable_(self.c, batch_axis=0) + if self.mode.is_a(bm.TrainingMode) and self.train_state: + h_to_train = parameter(self._state_initializer, shape, allow_none=False) + c_to_train = parameter(self._state_initializer, shape, allow_none=False) + self.h_to_train = bm.TrainVar(h_to_train) + self.c_to_train = bm.TrainVar(c_to_train) + self.h[:] = self.h_to_train + self.c[:] = self.c_to_train + + def update(self, x): + gates = self.input_to_hidden(x) + 
self.hidden_to_hidden(self.h) + i, g, f, o = bm.split(gates, indices_or_sections=4, axis=-1) + f = bm.sigmoid(f + 1) + c = f * self.c + bm.sigmoid(i) * bm.tanh(g) + h = bm.sigmoid(o) * bm.tanh(c) + self.h.value = h + self.c.value = c + return h class Conv1dLSTMCell(_ConvNDLSTMCell): # pylint: disable=empty-docstring - __doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "1") - - def __init__( - self, - input_shape: Tuple[int, ...], - - # convolution parameters - in_channels: int, - out_channels: int, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - - # recurrent parameters - state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - train_state: bool = False, - - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - """Constructs a 1-D convolutional LSTM. - - Input: [Batch_Size, Input_Data_Size, Input_Channel_Size] - - Output: [Batch_Size, Output_Data_Size, Output_Channel_Size] - - Args: - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length 1), or an int. - ``kernel_shape`` will be expanded to define a kernel size in all - dimensions. - name: Name of the module. 
- """ - super().__init__( - num_spatial_dims=1, - input_shape=input_shape, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - state_initializer=state_initializer, - train_state=train_state, - mode=mode, - name=name - ) + __doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "1") + + def __init__( + self, + input_shape: Tuple[int, ...], + + # convolution parameters + in_channels: int, + out_channels: int, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + + # recurrent parameters + state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + train_state: bool = False, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + """Constructs a 1-D convolutional LSTM. + + Input: [Batch_Size, Input_Data_Size, Input_Channel_Size] + + Output: [Batch_Size, Output_Data_Size, Output_Channel_Size] + + Args: + input_shape: Shape of the inputs excluding batch size. + out_channels: Number of output channels. + kernel_size: Sequence of kernel sizes (of length 1), or an int. + ``kernel_shape`` will be expanded to define a kernel size in all + dimensions. + name: Name of the module. 
+ """ + super().__init__( + num_spatial_dims=1, + input_shape=input_shape, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + state_initializer=state_initializer, + train_state=train_state, + mode=mode, + name=name + ) class Conv2dLSTMCell(_ConvNDLSTMCell): # pylint: disable=empty-docstring - __doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "2") - - def __init__( - self, - input_shape: Tuple[int, ...], - - # convolution parameters - in_channels: int, - out_channels: int, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - - # recurrent parameters - state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - train_state: bool = False, - - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - """Constructs a 2-D convolutional LSTM. - - Input: [Batch_Size, Input_Data_Size_Dim1,Input_Data_Size_Dim2, Input_Channel_Size] - - Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2 , Output_Channel_Size] - - Args: - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length 2), or an int. - ``kernel_shape`` will be expanded to define a kernel size in all - dimensions. - name: Name of the module. 
- """ - super().__init__( - num_spatial_dims=2, - input_shape=input_shape, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - state_initializer=state_initializer, - train_state=train_state, - mode=mode, - name=name - ) + __doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "2") + + def __init__( + self, + input_shape: Tuple[int, ...], + + # convolution parameters + in_channels: int, + out_channels: int, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + + # recurrent parameters + state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + train_state: bool = False, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + """Constructs a 2-D convolutional LSTM. + + Input: [Batch_Size, Input_Data_Size_Dim1,Input_Data_Size_Dim2, Input_Channel_Size] + + Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2 , Output_Channel_Size] + + Args: + input_shape: Shape of the inputs excluding batch size. + out_channels: Number of output channels. + kernel_size: Sequence of kernel sizes (of length 2), or an int. + ``kernel_shape`` will be expanded to define a kernel size in all + dimensions. + name: Name of the module. 
+ """ + super().__init__( + num_spatial_dims=2, + input_shape=input_shape, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + state_initializer=state_initializer, + train_state=train_state, + mode=mode, + name=name + ) class Conv3dLSTMCell(_ConvNDLSTMCell): # pylint: disable=empty-docstring - __doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "3") - - def __init__( - self, - input_shape: Tuple[int, ...], - - # convolution parameters - in_channels: int, - out_channels: int, - kernel_size: Union[int, Sequence[int]], - stride: Union[int, Tuple[int, ...]] = 1, - padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', - lhs_dilation: Union[int, Tuple[int, ...]] = 1, - rhs_dilation: Union[int, Tuple[int, ...]] = 1, - groups: int = 1, - w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), - b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), - - # recurrent parameters - state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), - train_state: bool = False, - - # others - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - """Constructs a 3-D convolutional LSTM. - - Input: [Batch_Size, Input_Data_Size_Dim1,Input_Data_Size_Dim2,Input_Data_Size_Dim3 ,Input_Channel_Size] - - Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2,Output_Data_Size_Dim3,Output_Channel_Size] - - Args: - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length 3), or an int. - ``kernel_shape`` will be expanded to define a kernel size in all - dimensions. - name: Name of the module. 
- """ - super().__init__( - num_spatial_dims=3, - input_shape=input_shape, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - lhs_dilation=lhs_dilation, - rhs_dilation=rhs_dilation, - groups=groups, - w_initializer=w_initializer, - b_initializer=b_initializer, - state_initializer=state_initializer, - train_state=train_state, - mode=mode, - name=name - ) + __doc__ = _ConvNDLSTMCell.__doc__.replace("``num_spatial_dims``", "3") + + def __init__( + self, + input_shape: Tuple[int, ...], + + # convolution parameters + in_channels: int, + out_channels: int, + kernel_size: Union[int, Sequence[int]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[str, int, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME', + lhs_dilation: Union[int, Tuple[int, ...]] = 1, + rhs_dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + w_initializer: Union[Callable, ArrayType, Initializer] = XavierNormal(), + b_initializer: Optional[Union[Callable, ArrayType, Initializer]] = ZeroInit(), + + # recurrent parameters + state_initializer: Union[ArrayType, Callable, Initializer] = ZeroInit(), + train_state: bool = False, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + """Constructs a 3-D convolutional LSTM. + + Input: [Batch_Size, Input_Data_Size_Dim1,Input_Data_Size_Dim2,Input_Data_Size_Dim3 ,Input_Channel_Size] + + Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2,Output_Data_Size_Dim3,Output_Channel_Size] + + Args: + input_shape: Shape of the inputs excluding batch size. + out_channels: Number of output channels. + kernel_size: Sequence of kernel sizes (of length 3), or an int. + ``kernel_shape`` will be expanded to define a kernel size in all + dimensions. + name: Name of the module. 
+ """ + super().__init__( + num_spatial_dims=3, + input_shape=input_shape, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + groups=groups, + w_initializer=w_initializer, + b_initializer=b_initializer, + state_initializer=state_initializer, + train_state=train_state, + mode=mode, + name=name + ) diff --git a/brainpy/_src/dyn/rates/tests/test_nvar.py b/brainpy/_src/dyn/rates/tests/test_nvar.py index 24659815c..f15b8cb6e 100644 --- a/brainpy/_src/dyn/rates/tests/test_nvar.py +++ b/brainpy/_src/dyn/rates/tests/test_nvar.py @@ -1,24 +1,27 @@ -import brainpy.math as bm -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm + class Test_NVAR(parameterized.TestCase): @parameterized.product( mode=[bm.BatchingMode(), bm.NonBatchingMode()] ) - def test_NVAR(self,mode): + def test_NVAR(self, mode): bm.random.seed() - input=bm.random.randn(1,5) - layer=bp.dyn.NVAR(num_in=5, - delay=10, - mode=mode) + input = bm.random.randn(1, 5) + layer = bp.dyn.NVAR(num_in=5, + delay=10, + mode=mode) if mode in [bm.NonBatchingMode()]: for i in input: - output=layer(i) + output = layer(i) else: - output=layer(input) + output = layer(input) + if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() diff --git a/brainpy/_src/dyn/rates/tests/test_rates.py b/brainpy/_src/dyn/rates/tests/test_rates.py index ac394fb83..3c8d34514 100644 --- a/brainpy/_src/dyn/rates/tests/test_rates.py +++ b/brainpy/_src/dyn/rates/tests/test_rates.py @@ -1,105 +1,101 @@ # -*- coding: utf-8 -*- +from unittest import TestCase + +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized from brainpy._src.dyn.rates import populations -from unittest import TestCase class 
TestRate(TestCase): - def test_fhn(self): - bm.random.seed() - fhn = bp.rates.FHN(10) - self.assertTrue(fhn.tau is not None) + def test_fhn(self): + bm.random.seed() + fhn = bp.rates.FHN(10) + self.assertTrue(fhn.tau is not None) - def test_ffhn(self): - bm.random.seed() - ffhn = bp.rates.FeedbackFHN(size=1) - self.assertTrue(ffhn.tau is not None) + def test_ffhn(self): + bm.random.seed() + ffhn = bp.rates.FeedbackFHN(size=1) + self.assertTrue(ffhn.tau is not None) - def test_qif(self): - bm.random.seed() - qif = bp.rates.QIF(size=1) - self.assertTrue(qif.tau is not None) + def test_qif(self): + bm.random.seed() + qif = bp.rates.QIF(size=1) + self.assertTrue(qif.tau is not None) - def test_slo(self): - bm.random.seed() - slo = bp.rates.StuartLandauOscillator(size=1) - self.assertTrue(slo.x_ou_tau is not None) + def test_slo(self): + bm.random.seed() + slo = bp.rates.StuartLandauOscillator(size=1) + self.assertTrue(slo.x_ou_tau is not None) - def test_wcm(self): - bm.random.seed() - wcm = bp.rates.WilsonCowanModel(size=1) - self.assertTrue(wcm.x_ou_tau is not None) + def test_wcm(self): + bm.random.seed() + wcm = bp.rates.WilsonCowanModel(size=1) + self.assertTrue(wcm.x_ou_tau is not None) - def test_tlm(self): - bm.random.seed() - tlm = bp.rates.ThresholdLinearModel(size=1) - self.assertTrue(tlm.tau_e is not None) + def test_tlm(self): + bm.random.seed() + tlm = bp.rates.ThresholdLinearModel(size=1) + self.assertTrue(tlm.tau_e is not None) class TestPopulation(parameterized.TestCase): - @parameterized.named_parameters( - {'testcase_name': f'noise_of_{name}', 'neuron': name} - for name in populations.__all__ - ) - def test_runner(self, neuron): - bm.random.seed() - model = getattr(populations, neuron)(size=10) - runner = bp.DSRunner(model, progress_bar=False) - runner.run(10.) 
- + @parameterized.named_parameters( + {'testcase_name': f'noise_of_{name}', 'neuron': name} + for name in populations.__all__ + ) + def test_runner(self, neuron): + bm.random.seed() + model = getattr(populations, neuron)(size=10) + runner = bp.DSRunner(model, progress_bar=False) + runner.run(10.) -class TestShape(parameterized.TestCase): - def test_FHN_shape(self): - bm.random.seed() - model = getattr(populations, 'FHN')(size=10) - runner = bp.DSRunner(model, - monitors=['x'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon.x.shape, (100, 10)) - - - def test_FFHN_shape(self): - bm.random.seed() - model = getattr(populations, 'FeedbackFHN')(size=10) - runner = bp.DSRunner(model, - monitors=['x'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon.x.shape, (100, 10)) - - - def test_QIF_shape(self): - bm.random.seed() - model = getattr(populations, 'QIF')(size=10) - runner = bp.DSRunner(model, - monitors=['x'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon.x.shape, (100, 10)) - - - def test_SLO_shape(self): - bm.random.seed() - model = getattr(populations, 'StuartLandauOscillator')(size=10) - runner = bp.DSRunner(model, - monitors=['x'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon.x.shape, (100, 10)) - - - def test_TLM_shape(self): - bm.random.seed() - model = getattr(populations, 'ThresholdLinearModel')(size=10) - runner = bp.DSRunner(model, - monitors=['e'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon.e.shape, (100, 10)) - +class TestShape(parameterized.TestCase): + def test_FHN_shape(self): + bm.random.seed() + model = getattr(populations, 'FHN')(size=10) + runner = bp.DSRunner(model, + monitors=['x'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon.x.shape, (100, 10)) + + def test_FFHN_shape(self): + bm.random.seed() + model = getattr(populations, 'FeedbackFHN')(size=10) + runner = bp.DSRunner(model, + monitors=['x'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon.x.shape, (100, 10)) + + def test_QIF_shape(self): + bm.random.seed() + model = getattr(populations, 'QIF')(size=10) + runner = bp.DSRunner(model, + monitors=['x'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon.x.shape, (100, 10)) + + def test_SLO_shape(self): + bm.random.seed() + model = getattr(populations, 'StuartLandauOscillator')(size=10) + runner = bp.DSRunner(model, + monitors=['x'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon.x.shape, (100, 10)) + + def test_TLM_shape(self): + bm.random.seed() + model = getattr(populations, 'ThresholdLinearModel')(size=10) + runner = bp.DSRunner(model, + monitors=['e'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon.e.shape, (100, 10)) diff --git a/brainpy/_src/dyn/rates/tests/test_reservoir.py b/brainpy/_src/dyn/rates/tests/test_reservoir.py index 34d00c909..ca0df3691 100644 --- a/brainpy/_src/dyn/rates/tests/test_reservoir.py +++ b/brainpy/_src/dyn/rates/tests/test_reservoir.py @@ -1,7 +1,8 @@ -import brainpy.math as bm -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class Test_Reservoir(parameterized.TestCase): diff --git a/brainpy/_src/dyn/rates/tests/test_rnncells.py b/brainpy/_src/dyn/rates/tests/test_rnncells.py index 206899aeb..7b907ef2f 100644 --- a/brainpy/_src/dyn/rates/tests/test_rnncells.py +++ b/brainpy/_src/dyn/rates/tests/test_rnncells.py @@ -1,171 +1,160 @@ -import brainpy.math as bm -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp +import 
brainpy.math as bm class Test_Rnncells(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(20), - bm.BatchingMode(), - bm.BatchingMode(20) - ] - ) - def test_RNNCell(self, mode): - bm.random.seed() - input = bm.random.randn(20, 10) - layer = bp.dyn.RNNCell(num_in=10, - num_out=64, - mode=mode - ) - output = layer(input) - - - def test_RNNCell_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10) - layer = bp.dyn.RNNCell(num_in=10, - num_out=32, - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(50), - bm.BatchingMode(), - bm.BatchingMode(50) - ] - ) - def test_GRUCell(self, mode): - bm.random.seed() - input = bm.random.randn(50, 100) - layer = bp.dyn.GRUCell(num_in=100, - num_out=64, - mode=mode) - output = layer(input) - - - def test_GRUCell_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10) - layer = bp.dyn.GRUCell(num_in=10, - num_out=12, - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(50), - bm.BatchingMode(), - bm.BatchingMode(50) - ] - ) - def test_LSTMCell(self, mode): - bm.random.seed() - input = bm.random.randn(50, 100) - layer = bp.dyn.LSTMCell(num_in=100, - num_out=64, - mode=mode) - - output = layer(input) - - - def test_LSTMCell_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10) - layer = bp.dyn.LSTMCell(num_in=10, - num_out=5, - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(4), - bm.BatchingMode(), - bm.BatchingMode(4)] - ) - def test_Conv1dLSTMCell(self, mode): - bm.random.seed() - input = bm.random.randn(4, 100, 3) - layer = bp.dyn.Conv1dLSTMCell(input_shape=(100,), - in_channels=3, - out_channels=5, - kernel_size=4, - mode=mode) - output = layer(input) - - - def test_Conv1dLSTMCell_NonBatching(self): - bm.random.seed() - 
input = bm.random.randn(10, 3) - layer = bp.dyn.Conv1dLSTMCell(input_shape=(10,), - in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(4), - bm.BatchingMode(), - bm.BatchingMode(4)] - ) - def test_Conv2dLSTMCell(self, mode): - bm.random.seed() - input = bm.random.randn(4, 100, 100, 3) - layer = bp.dyn.Conv2dLSTMCell(input_shape=(100, 100), - in_channels=3, - out_channels=5, - kernel_size=(4, 4), - mode=mode) - output = layer(input) - - - def test_Conv2dLSTMCell_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dyn.Conv2dLSTMCell(input_shape=(10, 10), - in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(4), - bm.BatchingMode(), - bm.BatchingMode(4)] - ) - def test_Conv3dLSTMCell(self, mode): - bm.random.seed() - input = bm.random.randn(4, 100, 100, 100, 3) - layer = bp.dyn.Conv3dLSTMCell(input_shape=(100, 100, 100), - in_channels=3, - out_channels=5, - kernel_size=(4, 4, 4), - mode=mode) - output = layer(input) - - - def test_Conv3dLSTMCell_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dyn.Conv3dLSTMCell(input_shape=(10, 10, 10), - in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(20), + bm.BatchingMode(), + bm.BatchingMode(20) + ] + ) + def test_RNNCell(self, mode): + bm.random.seed() + input = bm.random.randn(20, 10) + layer = bp.dyn.RNNCell(num_in=10, + num_out=64, + mode=mode + ) + output = layer(input) + + def test_RNNCell_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10) + layer = bp.dyn.RNNCell(num_in=10, + num_out=32, + mode=bm.NonBatchingMode()) + output = layer(input) + + 
@parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(50), + bm.BatchingMode(), + bm.BatchingMode(50) + ] + ) + def test_GRUCell(self, mode): + bm.random.seed() + input = bm.random.randn(50, 100) + layer = bp.dyn.GRUCell(num_in=100, + num_out=64, + mode=mode) + output = layer(input) + + def test_GRUCell_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10) + layer = bp.dyn.GRUCell(num_in=10, + num_out=12, + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(50), + bm.BatchingMode(), + bm.BatchingMode(50) + ] + ) + def test_LSTMCell(self, mode): + bm.random.seed() + input = bm.random.randn(50, 100) + layer = bp.dyn.LSTMCell(num_in=100, + num_out=64, + mode=mode) + + output = layer(input) + + def test_LSTMCell_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10) + layer = bp.dyn.LSTMCell(num_in=10, + num_out=5, + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(4), + bm.BatchingMode(), + bm.BatchingMode(4)] + ) + def test_Conv1dLSTMCell(self, mode): + bm.random.seed() + input = bm.random.randn(4, 100, 3) + layer = bp.dyn.Conv1dLSTMCell(input_shape=(100,), + in_channels=3, + out_channels=5, + kernel_size=4, + mode=mode) + output = layer(input) + + def test_Conv1dLSTMCell_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 3) + layer = bp.dyn.Conv1dLSTMCell(input_shape=(10,), + in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(4), + bm.BatchingMode(), + bm.BatchingMode(4)] + ) + def test_Conv2dLSTMCell(self, mode): + bm.random.seed() + input = bm.random.randn(4, 100, 100, 3) + layer = bp.dyn.Conv2dLSTMCell(input_shape=(100, 100), + in_channels=3, + out_channels=5, + kernel_size=(4, 4), + mode=mode) + output = layer(input) + + def 
test_Conv2dLSTMCell_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dyn.Conv2dLSTMCell(input_shape=(10, 10), + in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(4), + bm.BatchingMode(), + bm.BatchingMode(4)] + ) + def test_Conv3dLSTMCell(self, mode): + bm.random.seed() + input = bm.random.randn(4, 100, 100, 100, 3) + layer = bp.dyn.Conv3dLSTMCell(input_shape=(100, 100, 100), + in_channels=3, + out_channels=5, + kernel_size=(4, 4, 4), + mode=mode) + output = layer(input) + + def test_Conv3dLSTMCell_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dyn.Conv3dLSTMCell(input_shape=(10, 10, 10), + in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/dyn/synapses/__init__.py b/brainpy/_src/dyn/synapses/__init__.py index 2a296acb5..8eebdd6a1 100644 --- a/brainpy/_src/dyn/synapses/__init__.py +++ b/brainpy/_src/dyn/synapses/__init__.py @@ -1,3 +1,2 @@ - from .abstract_models import * from .bio_models import * diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py index a2f800c8d..8bc29f31b 100644 --- a/brainpy/_src/dyn/synapses/abstract_models.py +++ b/brainpy/_src/dyn/synapses/abstract_models.py @@ -2,875 +2,875 @@ from brainpy import math as bm from brainpy._src.context import share -from brainpy._src.initialize import parameter from brainpy._src.dyn import _docs from brainpy._src.dyn.base import SynDyn +from brainpy._src.initialize import parameter from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.generic import odeint from brainpy._src.mixin import AlignPost, ReturnInfo from brainpy.types import ArrayType __all__ = [ - 'Expon', - 'DualExpon', 
- 'DualExponV2', - 'Alpha', - 'NMDA', - 'STD', - 'STP', + 'Expon', + 'DualExpon', + 'DualExponV2', + 'Alpha', + 'NMDA', + 'STD', + 'STP', ] class Expon(SynDyn, AlignPost): - r"""Exponential decay synapse model. - - %s - - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: - - .. code-block:: python + r"""Exponential decay synapse model. - import numpy as np - import brainpy as bp - import brainpy.math as bm - - import matplotlib.pyplot as plt + %s + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + + .. code-block:: python + + import numpy as np + import brainpy as bp + import brainpy.math as bm + + import matplotlib.pyplot as plt - class ExponSparseCOBA(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, tau, E): - super().__init__() - self.proj = bp.dyn.ProjAlignPreMg2( - pre=pre, - delay=delay, - syn=bp.dyn.Expon.desc(pre.num, tau=tau), - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - out=bp.dyn.COBA(E=E), - post=post, - ) + class ExponSparseCOBA(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, tau, E): + super().__init__() + self.proj = bp.dyn.ProjAlignPreMg2( + pre=pre, + delay=delay, + syn=bp.dyn.Expon.desc(pre.num, tau=tau), + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + out=bp.dyn.COBA(E=E), + post=post, + ) - class SimpleNet(bp.DynSysGroup): - def __init__(self, syn_cls, E=0.): - super().__init__() - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., tau=5., E=E) - def update(self): - self.pre() - self.syn() - self.post() - # monitor the following variables - conductance = 
self.syn.proj.refs['syn'].g - current = self.post.sum_inputs(self.post.V) - return conductance, current, self.post.V + class SimpleNet(bp.DynSysGroup): + def __init__(self, syn_cls, E=0.): + super().__init__() + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., tau=5., E=E) - Moreover, it can also be used with interface ``ProjAlignPostMg2``: + def update(self): + self.pre() + self.syn() + self.post() + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g + current = self.post.sum_inputs(self.post.V) + return conductance, current, self.post.V - .. code-block:: python + Moreover, it can also be used with interface ``ProjAlignPostMg2``: - class ExponSparseCOBAPost(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, tau, E): - super().__init__() + .. code-block:: python - self.proj = bp.dyn.ProjAlignPostMg2( - pre=pre, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - syn=bp.dyn.Expon.desc(post.num, tau=tau), - out=bp.dyn.COBA.desc(E=E), - post=post, - ) + class ExponSparseCOBAPost(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, tau, E): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg2( + pre=pre, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + syn=bp.dyn.Expon.desc(post.num, tau=tau), + out=bp.dyn.COBA.desc(E=E), + post=post, + ) - Args: - tau: float. The time constant of decay. [ms] - %s - """ - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, + Args: + tau: float. 
The time constant of decay. [ms] + %s + """ - # synapse parameters - tau: Union[float, ArrayType, Callable] = 8.0, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, - # parameters - self.tau = self.init_param(tau) + # synapse parameters + tau: Union[float, ArrayType, Callable] = 8.0, + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) - # function - self.integral = odeint(self.derivative, method=method) - self._current = None + # parameters + self.tau = self.init_param(tau) - self.reset_state(self.mode) + # function + self.integral = odeint(self.derivative, method=method) + self._current = None - def derivative(self, g, t): - return -g / self.tau + self.reset_state(self.mode) - def reset_state(self, batch_or_mode=None, **kwargs): - self.g = self.init_variable(bm.zeros, batch_or_mode) + def derivative(self, g, t): + return -g / self.tau - def update(self, x=None): - self.g.value = self.integral(self.g.value, share['t'], share['dt']) - if x is not None: - self.add_current(x) - return self.g.value + def reset_state(self, batch_or_mode=None, **kwargs): + self.g = self.init_variable(bm.zeros, batch_or_mode) - def add_current(self, x): - self.g.value += x - - def return_info(self): - return self.g + def update(self, x=None): + self.g.value = self.integral(self.g.value, share['t'], share['dt']) + if x is not None: + self.add_current(x) + return self.g.value + + def add_current(self, x): + self.g.value += x + + def return_info(self): + return self.g Expon.__doc__ = Expon.__doc__ % (_docs.exp_syn_doc, _docs.pneu_doc,) def _format_dual_exp_A(self, A): - A = parameter(A, sizes=self.varshape, allow_none=True, sharding=self.sharding) - if A is None: - A 
= (self.tau_decay / (self.tau_decay - self.tau_rise) * - bm.float_power(self.tau_rise / self.tau_decay, self.tau_rise / (self.tau_rise - self.tau_decay))) - return A + A = parameter(A, sizes=self.varshape, allow_none=True, sharding=self.sharding) + if A is None: + A = (self.tau_decay / (self.tau_decay - self.tau_rise) * + bm.float_power(self.tau_rise / self.tau_decay, self.tau_rise / (self.tau_rise - self.tau_decay))) + return A class DualExpon(SynDyn): - r"""Dual exponential synapse model. - - %s - - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: - - .. code-block:: python - - import numpy as np - import brainpy as bp - import brainpy.math as bm - - import matplotlib.pyplot as plt - - class DualExpSparseCOBA(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): - super().__init__() - self.proj = bp.dyn.ProjAlignPreMg2( - pre=pre, - delay=delay, - syn=bp.dyn.DualExpon.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise), - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - out=bp.dyn.COBA(E=E), - post=post, - ) - - class SimpleNet(bp.DynSysGroup): - def __init__(self, syn_cls, E=0.): - super().__init__() - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., - tau_decay=5., tau_rise=1., E=E) - - def update(self): - self.pre() - self.syn() - self.post() - # monitor the following variables - conductance = self.syn.proj.refs['syn'].g - current = self.post.sum_inputs(self.post.V) - return conductance, current, self.post.V - - - indices = np.arange(1000) # 100 ms, dt= 0.1 ms - net = SimpleNet(DualExpSparseCOBA, E=0.) 
- conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) - ts = indices * bm.get_dt() - fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) - fig.add_subplot(gs[0, 0]) - plt.plot(ts, conductances) - plt.title('Syn conductance') - fig.add_subplot(gs[0, 1]) - plt.plot(ts, currents) - plt.title('Syn current') - fig.add_subplot(gs[0, 2]) - plt.plot(ts, potentials) - plt.title('Post V') - plt.show() - - See Also: - DualExponV2 - - .. note:: - - The implementation of this model can only be used in ``AlignPre`` projections. - One the contrary, to seek the ``AlignPost`` projection, please use ``DualExponV2``. - - Args: - %s - %s - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - tau_decay: Union[float, ArrayType, Callable] = 10.0, - tau_rise: Union[float, ArrayType, Callable] = 1., - A: Optional[Union[float, ArrayType, Callable]] = None, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - # parameters - self.tau_rise = self.init_param(tau_rise) - self.tau_decay = self.init_param(tau_decay) - A = _format_dual_exp_A(self, A) - self.a = (self.tau_decay - self.tau_rise) / self.tau_rise / self.tau_decay * A - - # integrator - self.integral = odeint(JointEq(self.dg, self.dh), method=method) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.h = self.init_variable(bm.zeros, batch_or_mode) - self.g = self.init_variable(bm.zeros, batch_or_mode) - - def dh(self, h, t): - return -h / self.tau_rise - - def dg(self, g, t, h): - return -g / self.tau_decay + h - - def update(self, x): - # x: the pre-synaptic spikes + r"""Dual exponential synapse model. 
- # update synaptic variables - self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) - self.h += self.a * x - return self.g.value + %s - def return_info(self): - return self.g + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + + .. code-block:: python + + import numpy as np + import brainpy as bp + import brainpy.math as bm + + import matplotlib.pyplot as plt + + class DualExpSparseCOBA(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): + super().__init__() + self.proj = bp.dyn.ProjAlignPreMg2( + pre=pre, + delay=delay, + syn=bp.dyn.DualExpon.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise), + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + out=bp.dyn.COBA(E=E), + post=post, + ) + + class SimpleNet(bp.DynSysGroup): + def __init__(self, syn_cls, E=0.): + super().__init__() + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., + tau_decay=5., tau_rise=1., E=E) + + def update(self): + self.pre() + self.syn() + self.post() + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g + current = self.post.sum_inputs(self.post.V) + return conductance, current, self.post.V + + + indices = np.arange(1000) # 100 ms, dt= 0.1 ms + net = SimpleNet(DualExpSparseCOBA, E=0.) 
+ conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) + ts = indices * bm.get_dt() + fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) + fig.add_subplot(gs[0, 0]) + plt.plot(ts, conductances) + plt.title('Syn conductance') + fig.add_subplot(gs[0, 1]) + plt.plot(ts, currents) + plt.title('Syn current') + fig.add_subplot(gs[0, 2]) + plt.plot(ts, potentials) + plt.title('Post V') + plt.show() + + See Also: + DualExponV2 + + .. note:: + + The implementation of this model can only be used in ``AlignPre`` projections. + One the contrary, to seek the ``AlignPost`` projection, please use ``DualExponV2``. + + Args: + %s + %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + tau_decay: Union[float, ArrayType, Callable] = 10.0, + tau_rise: Union[float, ArrayType, Callable] = 1., + A: Optional[Union[float, ArrayType, Callable]] = None, + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) + + # parameters + self.tau_rise = self.init_param(tau_rise) + self.tau_decay = self.init_param(tau_decay) + A = _format_dual_exp_A(self, A) + self.a = (self.tau_decay - self.tau_rise) / self.tau_rise / self.tau_decay * A + + # integrator + self.integral = odeint(JointEq(self.dg, self.dh), method=method) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.h = self.init_variable(bm.zeros, batch_or_mode) + self.g = self.init_variable(bm.zeros, batch_or_mode) + + def dh(self, h, t): + return -h / self.tau_rise + + def dg(self, g, t, h): + return -g / self.tau_decay + h + + def update(self, x): + # x: the pre-synaptic spikes + + # update synaptic variables + self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) + self.h 
+= self.a * x + return self.g.value + + def return_info(self): + return self.g DualExpon.__doc__ = DualExpon.__doc__ % (_docs.dual_exp_syn_doc, _docs.pneu_doc, _docs.dual_exp_args) class DualExponV2(SynDyn, AlignPost): - r"""Dual exponential synapse model. - - %s - - .. note:: - - Different from ``DualExpon``, this model can be used in both modes of ``AlignPre`` and ``AlignPost`` projections. - - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: - - .. code-block:: python - - import numpy as np - import brainpy as bp - import brainpy.math as bm - - import matplotlib.pyplot as plt - - - class DualExponV2SparseCOBA(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): - super().__init__() - - self.proj = bp.dyn.ProjAlignPreMg2( - pre=pre, - delay=delay, - syn=bp.dyn.DualExponV2.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise), - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - out=bp.dyn.COBA(E=E), - post=post, - ) - - - class SimpleNet(bp.DynSysGroup): - def __init__(self, syn_cls, E=0.): - super().__init__() - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., tau_decay=5., tau_rise=1., E=E) - - def update(self): - self.pre() - self.syn() - self.post() - # monitor the following variables - conductance = self.syn.proj.refs['syn'].g_rise - current = self.post.sum_inputs(self.post.V) - return conductance, current, self.post.V - - indices = np.arange(1000) # 100 ms, dt= 0.1 ms - net = SimpleNet(DualExponV2SparseCOBAPost, E=0.) 
- conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) - ts = indices * bm.get_dt() - fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) - fig.add_subplot(gs[0, 0]) - plt.plot(ts, conductances) - plt.title('Syn conductance') - fig.add_subplot(gs[0, 1]) - plt.plot(ts, currents) - plt.title('Syn current') - fig.add_subplot(gs[0, 2]) - plt.plot(ts, potentials) - plt.title('Post V') - plt.show() - - Moreover, it can also be used with interface ``ProjAlignPostMg2``: - - .. code-block:: python - - class DualExponV2SparseCOBAPost(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): - super().__init__() - - self.proj = bp.dyn.ProjAlignPostMg2( - pre=pre, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - syn=bp.dyn.DualExponV2.desc(post.num, tau_decay=tau_decay, tau_rise=tau_rise), - out=bp.dyn.COBA.desc(E=E), - post=post, - ) - - See Also: - DualExpon - - Args: - %s - %s - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - tau_decay: Union[float, ArrayType, Callable] = 10.0, - tau_rise: Union[float, ArrayType, Callable] = 1., - A: Optional[Union[float, ArrayType, Callable]] = None, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - # parameters - self.tau_rise = self.init_param(tau_rise) - self.tau_decay = self.init_param(tau_decay) - self.a = _format_dual_exp_A(self, A) - - # integrator - self.integral = odeint(lambda g, t, tau: -g / tau, method=method) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.g_rise = self.init_variable(bm.zeros, batch_or_mode) - self.g_decay = self.init_variable(bm.zeros, batch_or_mode) - - def update(self, 
x=None): - self.g_rise.value = self.integral(self.g_rise.value, share['t'], self.tau_rise, share['dt']) - self.g_decay.value = self.integral(self.g_decay.value, share['t'], self.tau_decay, share['dt']) - if x is not None: - self.add_current(x) - return self.a * (self.g_decay - self.g_rise) - - def add_current(self, inp): - self.g_rise += inp - self.g_decay += inp - - def return_info(self): - return ReturnInfo(self.varshape, self.sharding, self.mode, - lambda shape: self.a * (self.g_decay - self.g_rise)) - - -DualExponV2.__doc__ = DualExponV2.__doc__ % (_docs.dual_exp_syn_doc, _docs.pneu_doc, _docs.dual_exp_args,) - + r"""Dual exponential synapse model. -class Alpha(SynDyn): - r"""Alpha synapse model. - - %s - - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: - - .. code-block:: python - - import numpy as np - import brainpy as bp - import brainpy.math as bm - - import matplotlib.pyplot as plt - - - class AlphaSparseCOBA(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, tau_decay, E): - super().__init__() - - self.proj = bp.dyn.ProjAlignPreMg2( - pre=pre, - delay=delay, - syn=bp.dyn.Alpha.desc(pre.num, tau_decay=tau_decay), - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - out=bp.dyn.COBA(E=E), - post=post, - ) - - - class SimpleNet(bp.DynSysGroup): - def __init__(self, syn_cls, E=0.): - super().__init__() - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., - tau_decay=5., E=E) - - def update(self): - self.pre() - self.syn() - self.post() - # monitor the following variables - conductance = self.syn.proj.refs['syn'].g - current = self.post.sum_inputs(self.post.V) - return conductance, current, self.post.V - - - indices = 
np.arange(1000) # 100 ms, dt= 0.1 ms - net = SimpleNet(AlphaSparseCOBA, E=0.) - conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) - ts = indices * bm.get_dt() - fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) - fig.add_subplot(gs[0, 0]) - plt.plot(ts, conductances) - plt.title('Syn conductance') - fig.add_subplot(gs[0, 1]) - plt.plot(ts, currents) - plt.title('Syn current') - fig.add_subplot(gs[0, 2]) - plt.plot(ts, potentials) - plt.title('Post V') - plt.show() - - - Args: %s - tau_decay: float, ArrayType, Callable. The time constant [ms] of the synaptic decay phase. - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - tau_decay: Union[float, ArrayType, Callable] = 10.0, - ): - super().__init__( - name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding - ) - # parameters - self.tau_decay = self.init_param(tau_decay) + .. note:: + + Different from ``DualExpon``, this model can be used in both modes of ``AlignPre`` and ``AlignPost`` projections. + + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + + .. 
code-block:: python + + import numpy as np + import brainpy as bp + import brainpy.math as bm + + import matplotlib.pyplot as plt + + + class DualExponV2SparseCOBA(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): + super().__init__() + + self.proj = bp.dyn.ProjAlignPreMg2( + pre=pre, + delay=delay, + syn=bp.dyn.DualExponV2.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise), + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + out=bp.dyn.COBA(E=E), + post=post, + ) + + + class SimpleNet(bp.DynSysGroup): + def __init__(self, syn_cls, E=0.): + super().__init__() + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., tau_decay=5., tau_rise=1., E=E) + + def update(self): + self.pre() + self.syn() + self.post() + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g_rise + current = self.post.sum_inputs(self.post.V) + return conductance, current, self.post.V + + indices = np.arange(1000) # 100 ms, dt= 0.1 ms + net = SimpleNet(DualExponV2SparseCOBAPost, E=0.) + conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) + ts = indices * bm.get_dt() + fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) + fig.add_subplot(gs[0, 0]) + plt.plot(ts, conductances) + plt.title('Syn conductance') + fig.add_subplot(gs[0, 1]) + plt.plot(ts, currents) + plt.title('Syn current') + fig.add_subplot(gs[0, 2]) + plt.plot(ts, potentials) + plt.title('Post V') + plt.show() + + Moreover, it can also be used with interface ``ProjAlignPostMg2``: + + .. 
code-block:: python + + class DualExponV2SparseCOBAPost(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): + super().__init__() + + self.proj = bp.dyn.ProjAlignPostMg2( + pre=pre, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + syn=bp.dyn.DualExponV2.desc(post.num, tau_decay=tau_decay, tau_rise=tau_rise), + out=bp.dyn.COBA.desc(E=E), + post=post, + ) + + See Also: + DualExpon + + Args: + %s + %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + tau_decay: Union[float, ArrayType, Callable] = 10.0, + tau_rise: Union[float, ArrayType, Callable] = 1., + A: Optional[Union[float, ArrayType, Callable]] = None, + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) + + # parameters + self.tau_rise = self.init_param(tau_rise) + self.tau_decay = self.init_param(tau_decay) + self.a = _format_dual_exp_A(self, A) + + # integrator + self.integral = odeint(lambda g, t, tau: -g / tau, method=method) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.g_rise = self.init_variable(bm.zeros, batch_or_mode) + self.g_decay = self.init_variable(bm.zeros, batch_or_mode) + + def update(self, x=None): + self.g_rise.value = self.integral(self.g_rise.value, share['t'], self.tau_rise, share['dt']) + self.g_decay.value = self.integral(self.g_decay.value, share['t'], self.tau_decay, share['dt']) + if x is not None: + self.add_current(x) + return self.a * (self.g_decay - self.g_rise) + + def add_current(self, inp): + self.g_rise += inp + self.g_decay += inp + + def return_info(self): + return ReturnInfo(self.varshape, self.sharding, self.mode, + lambda shape: self.a * (self.g_decay - self.g_rise)) - # 
integrator - self.integral = odeint(JointEq(self.dg, self.dh), method=method) - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.h = self.init_variable(bm.zeros, batch_or_mode) - self.g = self.init_variable(bm.zeros, batch_or_mode) +DualExponV2.__doc__ = DualExponV2.__doc__ % (_docs.dual_exp_syn_doc, _docs.pneu_doc, _docs.dual_exp_args,) - def dh(self, h, t): - return -h / self.tau_decay - def dg(self, g, t, h): - return -g / self.tau_decay + h / self.tau_decay +class Alpha(SynDyn): + r"""Alpha synapse model. - def update(self, x): - # update synaptic variables - self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) - self.h += x - return self.g.value + %s - def return_info(self): - return self.g + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + + .. code-block:: python + + import numpy as np + import brainpy as bp + import brainpy.math as bm + + import matplotlib.pyplot as plt + + + class AlphaSparseCOBA(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, tau_decay, E): + super().__init__() + + self.proj = bp.dyn.ProjAlignPreMg2( + pre=pre, + delay=delay, + syn=bp.dyn.Alpha.desc(pre.num, tau_decay=tau_decay), + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + out=bp.dyn.COBA(E=E), + post=post, + ) + + + class SimpleNet(bp.DynSysGroup): + def __init__(self, syn_cls, E=0.): + super().__init__() + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., + tau_decay=5., E=E) + + def update(self): + self.pre() + self.syn() + self.post() + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g + current = 
self.post.sum_inputs(self.post.V) + return conductance, current, self.post.V + + + indices = np.arange(1000) # 100 ms, dt= 0.1 ms + net = SimpleNet(AlphaSparseCOBA, E=0.) + conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) + ts = indices * bm.get_dt() + fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) + fig.add_subplot(gs[0, 0]) + plt.plot(ts, conductances) + plt.title('Syn conductance') + fig.add_subplot(gs[0, 1]) + plt.plot(ts, currents) + plt.title('Syn current') + fig.add_subplot(gs[0, 2]) + plt.plot(ts, potentials) + plt.title('Post V') + plt.show() + + + Args: + %s + tau_decay: float, ArrayType, Callable. The time constant [ms] of the synaptic decay phase. + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + tau_decay: Union[float, ArrayType, Callable] = 10.0, + ): + super().__init__( + name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding + ) + + # parameters + self.tau_decay = self.init_param(tau_decay) + + # integrator + self.integral = odeint(JointEq(self.dg, self.dh), method=method) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.h = self.init_variable(bm.zeros, batch_or_mode) + self.g = self.init_variable(bm.zeros, batch_or_mode) + + def dh(self, h, t): + return -h / self.tau_decay + + def dg(self, g, t, h): + return -g / self.tau_decay + h / self.tau_decay + + def update(self, x): + # update synaptic variables + self.g.value, self.h.value = self.integral(self.g.value, self.h.value, share['t'], dt=share['dt']) + self.h += x + return self.g.value + + def return_info(self): + return self.g Alpha.__doc__ = Alpha.__doc__ % (_docs.alpha_syn_doc, _docs.pneu_doc,) class NMDA(SynDyn): - r"""NMDA synapse model. + r"""NMDA synapse model. 
- **Model Descriptions** + **Model Descriptions** - The NMDA receptor is a glutamate receptor and ion channel found in neurons. - The NMDA receptor is one of three types of ionotropic glutamate receptors, - the other two being AMPA and kainate receptors. - - The NMDA receptor mediated conductance depends on the postsynaptic voltage. - The voltage dependence is due to the blocking of the pore of the NMDA receptor - from the outside by a positively charged magnesium ion. The channel is - nearly completely blocked at resting potential, but the magnesium block is - relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` - that are not blocked by magnesium can be fitted to + The NMDA receptor is a glutamate receptor and ion channel found in neurons. + The NMDA receptor is one of three types of ionotropic glutamate receptors, + the other two being AMPA and kainate receptors. + + The NMDA receptor mediated conductance depends on the postsynaptic voltage. + The voltage dependence is due to the blocking of the pore of the NMDA receptor + from the outside by a positively charged magnesium ion. The channel is + nearly completely blocked at resting potential, but the magnesium block is + relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` + that are not blocked by magnesium can be fitted to - .. math:: + .. math:: - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} - \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, - usually 1 mM. Thus, the channel acts as a - "coincidence detector" and only once both of these conditions are met, the - channel opens and it allows positively charged ions (cations) to flow through - the cell membrane [2]_. - - If we make the approximation that the magnesium block changes - instantaneously with voltage and is independent of the gating of the channel, - the net NMDA receptor-mediated synaptic current is given by - - .. 
math:: - - I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} - - where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the - reversal potential. - - Simultaneously, the kinetics of synaptic state :math:`g` is given by - - .. math:: - - & g_\mathrm{NMDA} (t) = g_{max} g \\ - & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\ - & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k}) - - where the decay time of NMDA currents is usually taken to be - :math:`\tau_{decay}` =100 ms, :math:`a= 0.5 ms^{-1}`, and :math:`\tau_{rise}` =2 ms. - - The NMDA receptor has been thought to be very important for controlling - synaptic plasticity and mediating learning and memory functions [3]_. - - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: - - .. code-block:: python - - import numpy as np - import brainpy as bp - import brainpy.math as bm - - import matplotlib.pyplot as plt - - class NMDASparseCOBA(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): - super().__init__() - - self.proj = bp.dyn.ProjAlignPreMg2( - pre=pre, - delay=delay, - syn=bp.dyn.NMDA.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise), - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - out=bp.dyn.COBA(E=E), - post=post, - ) - - - class SimpleNet(bp.DynSysGroup): - def __init__(self, syn_cls, E=0.): - super().__init__() - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., - tau_decay=5., tau_rise=1., E=E) - - def update(self): - self.pre() - self.syn() - self.post() - # monitor the following variables - conductance = self.syn.proj.refs['syn'].g - current = self.post.sum_inputs(self.post.V) - return 
conductance, current, self.post.V - - - indices = np.arange(1000) # 100 ms, dt= 0.1 ms - net = SimpleNet(NMDASparseCOBA, E=0.) - conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) - ts = indices * bm.get_dt() - fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) - fig.add_subplot(gs[0, 0]) - plt.plot(ts, conductances) - plt.title('Syn conductance') - fig.add_subplot(gs[0, 1]) - plt.plot(ts, currents) - plt.title('Syn current') - fig.add_subplot(gs[0, 2]) - plt.plot(ts, potentials) - plt.title('Post V') - plt.show() - - - - .. [1] Brunel N, Wang X J. Effects of neuromodulation in a - cortical network model of object working memory dominated - by recurrent inhibition[J]. - Journal of computational neuroscience, 2001, 11(1): 63-85. - .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and - Eric Gouaux. "Subunit arrangement and function in NMDA receptors." - Nature 438, no. 7065 (2005): 185-192. - .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New - England journal of medicine, 361(3), p.302. - .. [4] https://en.wikipedia.org/wiki/NMDA_receptor - - Args: - tau_decay: float, ArrayType, Callable. The time constant of the synaptic decay phase. Default 100 [ms] - tau_rise: float, ArrayType, Callable. The time constant of the synaptic rise phase. Default 2 [ms] - a: float, ArrayType, Callable. Default 0.5 ms^-1. 
- %s - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - a: Union[float, ArrayType, Callable] = 0.5, - tau_decay: Union[float, ArrayType, Callable] = 100., - tau_rise: Union[float, ArrayType, Callable] = 2., - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - # parameters - self.tau_decay = self.init_param(tau_decay) - self.tau_rise = self.init_param(tau_rise) - self.a = self.init_param(a) - - # integral - self.integral = odeint(method=method, f=JointEq(self.dg, self.dx)) - - self.reset_state(self.mode) - - def dg(self, g, t, x): - return -g / self.tau_decay + self.a * x * (1 - g) - - def dx(self, x, t): - return -x / self.tau_rise - - def reset_state(self, batch_or_mode=None, **kwargs): - self.g = self.init_variable(bm.zeros, batch_or_mode) - self.x = self.init_variable(bm.zeros, batch_or_mode) - - def update(self, pre_spike): - t = share.load('t') - dt = share.load('dt') - self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, dt=dt) - self.x += pre_spike - return self.g.value - - def return_info(self): - return self.g + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} + \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, + usually 1 mM. Thus, the channel acts as a + "coincidence detector" and only once both of these conditions are met, the + channel opens and it allows positively charged ions (cations) to flow through + the cell membrane [2]_. + + If we make the approximation that the magnesium block changes + instantaneously with voltage and is independent of the gating of the channel, + the net NMDA receptor-mediated synaptic current is given by + + .. 
math:: + + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} + + where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the + reversal potential. + + Simultaneously, the kinetics of synaptic state :math:`g` is given by + + .. math:: + + & g_\mathrm{NMDA} (t) = g_{max} g \\ + & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\ + & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k}) + + where the decay time of NMDA currents is usually taken to be + :math:`\tau_{decay}` =100 ms, :math:`a= 0.5 ms^{-1}`, and :math:`\tau_{rise}` =2 ms. + + The NMDA receptor has been thought to be very important for controlling + synaptic plasticity and mediating learning and memory functions [3]_. + + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + + .. code-block:: python + + import numpy as np + import brainpy as bp + import brainpy.math as bm + + import matplotlib.pyplot as plt + + class NMDASparseCOBA(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): + super().__init__() + + self.proj = bp.dyn.ProjAlignPreMg2( + pre=pre, + delay=delay, + syn=bp.dyn.NMDA.desc(pre.num, tau_decay=tau_decay, tau_rise=tau_rise), + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + out=bp.dyn.COBA(E=E), + post=post, + ) + + + class SimpleNet(bp.DynSysGroup): + def __init__(self, syn_cls, E=0.): + super().__init__() + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = syn_cls(self.pre, self.post, delay=None, prob=1., g_max=1., + tau_decay=5., tau_rise=1., E=E) + + def update(self): + self.pre() + self.syn() + self.post() + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g + current = self.post.sum_inputs(self.post.V) + return 
conductance, current, self.post.V + + + indices = np.arange(1000) # 100 ms, dt= 0.1 ms + net = SimpleNet(NMDASparseCOBA, E=0.) + conductances, currents, potentials = bm.for_loop(net.step_run, indices, progress_bar=True) + ts = indices * bm.get_dt() + fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) + fig.add_subplot(gs[0, 0]) + plt.plot(ts, conductances) + plt.title('Syn conductance') + fig.add_subplot(gs[0, 1]) + plt.plot(ts, currents) + plt.title('Syn current') + fig.add_subplot(gs[0, 2]) + plt.plot(ts, potentials) + plt.title('Post V') + plt.show() + + + + .. [1] Brunel N, Wang X J. Effects of neuromodulation in a + cortical network model of object working memory dominated + by recurrent inhibition[J]. + Journal of computational neuroscience, 2001, 11(1): 63-85. + .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and + Eric Gouaux. "Subunit arrangement and function in NMDA receptors." + Nature 438, no. 7065 (2005): 185-192. + .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New + England journal of medicine, 361(3), p.302. + .. [4] https://en.wikipedia.org/wiki/NMDA_receptor + + Args: + tau_decay: float, ArrayType, Callable. The time constant of the synaptic decay phase. Default 100 [ms] + tau_rise: float, ArrayType, Callable. The time constant of the synaptic rise phase. Default 2 [ms] + a: float, ArrayType, Callable. Default 0.5 ms^-1. 
+ %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + a: Union[float, ArrayType, Callable] = 0.5, + tau_decay: Union[float, ArrayType, Callable] = 100., + tau_rise: Union[float, ArrayType, Callable] = 2., + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) + + # parameters + self.tau_decay = self.init_param(tau_decay) + self.tau_rise = self.init_param(tau_rise) + self.a = self.init_param(a) + + # integral + self.integral = odeint(method=method, f=JointEq(self.dg, self.dx)) + + self.reset_state(self.mode) + + def dg(self, g, t, x): + return -g / self.tau_decay + self.a * x * (1 - g) + + def dx(self, x, t): + return -x / self.tau_rise + + def reset_state(self, batch_or_mode=None, **kwargs): + self.g = self.init_variable(bm.zeros, batch_or_mode) + self.x = self.init_variable(bm.zeros, batch_or_mode) + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, dt=dt) + self.x += pre_spike + return self.g.value + + def return_info(self): + return self.g NMDA.__doc__ = NMDA.__doc__ % (_docs.pneu_doc,) class STD(SynDyn): - r"""Synaptic output with short-term depression. + r"""Synaptic output with short-term depression. - %s - - Args: - tau: float, ArrayType, Callable. The time constant of recovery of the synaptic vesicles. - U: float, ArrayType, Callable. The fraction of resources used per action potential. %s - """ - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, + Args: + tau: float, ArrayType, Callable. 
The time constant of recovery of the synaptic vesicles. + U: float, ArrayType, Callable. The fraction of resources used per action potential. + %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, - # synapse parameters - tau: Union[float, ArrayType, Callable] = 200., - U: Union[float, ArrayType, Callable] = 0.07, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) + # synapse parameters + tau: Union[float, ArrayType, Callable] = 200., + U: Union[float, ArrayType, Callable] = 0.07, + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) - # parameters - self.tau = self.init_param(tau) - self.U = self.init_param(U) + # parameters + self.tau = self.init_param(tau) + self.U = self.init_param(U) - # integral function - self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=method) + # integral function + self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=method) - self.reset_state(self.mode) + self.reset_state(self.mode) - def reset_state(self, batch_or_mode=None, **kwargs): - self.x = self.init_variable(bm.ones, batch_or_mode) + def reset_state(self, batch_or_mode=None, **kwargs): + self.x = self.init_variable(bm.ones, batch_or_mode) - def update(self, pre_spike): - t = share.load('t') - dt = share.load('dt') - x = self.integral(self.x.value, t, dt) + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + x = self.integral(self.x.value, t, dt) - # --- original code: - # self.x.value = bm.where(pre_spike, x - self.U * self.x, x) + # --- original code: + # self.x.value = bm.where(pre_spike, x - self.U * self.x, x) - # --- simplified code: - self.x.value = x - pre_spike * self.U * self.x + # --- simplified code: + self.x.value = x - pre_spike * self.U * 
self.x - return self.x.value + return self.x.value - def return_info(self): - return self.x + def return_info(self): + return self.x STD.__doc__ = STD.__doc__ % (_docs.std_doc, _docs.pneu_doc,) class STP(SynDyn): - r"""Synaptic output with short-term plasticity. + r"""Synaptic output with short-term plasticity. - %s - - Args: - tau_f: float, ArrayType, Callable. The time constant of short-term facilitation. - tau_d: float, ArrayType, Callable. The time constant of short-term depression. - U: float, ArrayType, Callable. The fraction of resources used per action potential. %s - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - U: Union[float, ArrayType, Callable] = 0.15, - tau_f: Union[float, ArrayType, Callable] = 1500., - tau_d: Union[float, ArrayType, Callable] = 200., - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - # parameters - self.tau_f = self.init_param(tau_f) - self.tau_d = self.init_param(tau_d) - self.U = self.init_param(U) - self.method = method - - # integral function - self.integral = odeint(self.derivative, method=self.method) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.x = self.init_variable(bm.ones, batch_or_mode) - self.u = self.init_variable(bm.ones, batch_or_mode) - self.u.fill_(self.U) - - @property - def derivative(self): - du = lambda u, t: self.U - u / self.tau_f - dx = lambda x, t: (1 - x) / self.tau_d - return JointEq(du, dx) - - def update(self, pre_spike): - t = share.load('t') - dt = share.load('dt') - u, x = self.integral(self.u.value, self.x.value, t, dt) - - # --- original code: - # if pre_spike.dtype == jax.numpy.bool_: - # u = bm.where(pre_spike, u + self.U * (1 - self.u), u) - # x = bm.where(pre_spike, x - u * self.x, 
x) - # else: - # u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u - # x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x - - # --- simplified code: - u = pre_spike * self.U * (1 - self.u) + u - x = pre_spike * -u * self.x + x - - self.x.value = x - self.u.value = u - return u * x - - def return_info(self): - return ReturnInfo(self.varshape, self.sharding, self.mode, - lambda shape: self.u * self.x) + + Args: + tau_f: float, ArrayType, Callable. The time constant of short-term facilitation. + tau_d: float, ArrayType, Callable. The time constant of short-term depression. + U: float, ArrayType, Callable. The fraction of resources used per action potential. + %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + U: Union[float, ArrayType, Callable] = 0.15, + tau_f: Union[float, ArrayType, Callable] = 1500., + tau_d: Union[float, ArrayType, Callable] = 200., + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) + + # parameters + self.tau_f = self.init_param(tau_f) + self.tau_d = self.init_param(tau_d) + self.U = self.init_param(U) + self.method = method + + # integral function + self.integral = odeint(self.derivative, method=self.method) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.x = self.init_variable(bm.ones, batch_or_mode) + self.u = self.init_variable(bm.ones, batch_or_mode) + self.u.fill_(self.U) + + @property + def derivative(self): + du = lambda u, t: self.U - u / self.tau_f + dx = lambda x, t: (1 - x) / self.tau_d + return JointEq(du, dx) + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + u, x = self.integral(self.u.value, self.x.value, t, dt) + + # --- original code: + # if pre_spike.dtype == 
jax.numpy.bool_: + # u = bm.where(pre_spike, u + self.U * (1 - self.u), u) + # x = bm.where(pre_spike, x - u * self.x, x) + # else: + # u = pre_spike * (u + self.U * (1 - self.u)) + (1 - pre_spike) * u + # x = pre_spike * (x - u * self.x) + (1 - pre_spike) * x + + # --- simplified code: + u = pre_spike * self.U * (1 - self.u) + u + x = pre_spike * -u * self.x + x + + self.x.value = x + self.u.value = u + return u * x + + def return_info(self): + return ReturnInfo(self.varshape, self.sharding, self.mode, + lambda shape: self.u * self.x) STP.__doc__ = STP.__doc__ % (_docs.stp_doc, _docs.pneu_doc,) diff --git a/brainpy/_src/dyn/synapses/bio_models.py b/brainpy/_src/dyn/synapses/bio_models.py index cfb174dff..043347ff3 100644 --- a/brainpy/_src/dyn/synapses/bio_models.py +++ b/brainpy/_src/dyn/synapses/bio_models.py @@ -9,495 +9,495 @@ from brainpy.types import ArrayType __all__ = [ - 'AMPA', - 'GABAa', - 'BioNMDA', + 'AMPA', + 'GABAa', + 'BioNMDA', ] class AMPA(SynDyn): - r"""AMPA synapse model. + r"""AMPA synapse model. - **Model Descriptions** + **Model Descriptions** - AMPA receptor is an ionotropic receptor, which is an ion channel. - When it is bound by neurotransmitters, it will immediately open the - ion channel, causing the change of membrane potential of postsynaptic neurons. + AMPA receptor is an ionotropic receptor, which is an ion channel. + When it is bound by neurotransmitters, it will immediately open the + ion channel, causing the change of membrane potential of postsynaptic neurons. - A classical model is to use the Markov process to model ion channel switch. - Here :math:`g` represents the probability of channel opening, :math:`1-g` - represents the probability of ion channel closing, and :math:`\alpha` and - :math:`\beta` are the transition probability. Because neurotransmitters can - open ion channels, the transfer probability from :math:`1-g` to :math:`g` - is affected by the concentration of neurotransmitters. 
We denote the concentration - of neurotransmitters as :math:`[T]` and get the following Markov process. + A classical model is to use the Markov process to model ion channel switch. + Here :math:`g` represents the probability of channel opening, :math:`1-g` + represents the probability of ion channel closing, and :math:`\alpha` and + :math:`\beta` are the transition probability. Because neurotransmitters can + open ion channels, the transfer probability from :math:`1-g` to :math:`g` + is affected by the concentration of neurotransmitters. We denote the concentration + of neurotransmitters as :math:`[T]` and get the following Markov process. - .. image:: ../../_static/synapse_markov.png - :align: center + .. image:: ../../_static/synapse_markov.png + :align: center - We obtained the following formula when describing the process by a differential equation. - - .. math:: + We obtained the following formula when describing the process by a differential equation. - \frac{ds}{dt} =\alpha[T](1-g)-\beta g - - where :math:`\alpha [T]` denotes the transition probability from state :math:`(1-g)` - to state :math:`(g)`; and :math:`\beta` represents the transition probability of - the other direction. :math:`\alpha` is the binding constant. :math:`\beta` is the - unbinding constant. :math:`[T]` is the neurotransmitter concentration, and - has the duration of 0.5 ms. - - Moreover, the post-synaptic current on the post-synaptic neuron is formulated as - - .. math:: - - I_{syn} = g_{max} g (V-E) - - where :math:`g_{max}` is the maximum conductance, and `E` is the reverse potential. - - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: - - .. 
code-block:: python - - import numpy as np - import brainpy as bp - import brainpy.math as bm - - import matplotlib.pyplot as plt - - class AMPA(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, E=0.): - super().__init__() - self.proj = bp.dyn.ProjAlignPreMg2( - pre=pre, - delay=delay, - syn=bp.dyn.AMPA.desc(pre.num, alpha=0.98, beta=0.18, T=0.5, T_dur=0.5), - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - out=bp.dyn.COBA(E=E), - post=post, - ) - - class SimpleNet(bp.DynSysGroup): - def __init__(self, E=0.): - super().__init__() - - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = AMPA(self.pre, self.post, delay=None, prob=1., g_max=1., E=E) - - def update(self): - self.pre() - self.syn() - self.post() - - # monitor the following variables - conductance = self.syn.proj.refs['syn'].g - current = self.post.sum_inputs(self.post.V) - return conductance, current, self.post.V - - indices = np.arange(1000) # 100 ms, dt= 0.1 ms - conductances, currents, potentials = bm.for_loop(SimpleNet(E=0.).step_run, indices, progress_bar=True) - ts = indices * bm.get_dt() - - - fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) - fig.add_subplot(gs[0, 0]) - plt.plot(ts, conductances) - plt.title('Syn conductance') - fig.add_subplot(gs[0, 1]) - plt.plot(ts, currents) - plt.title('Syn current') - fig.add_subplot(gs[0, 2]) - plt.plot(ts, potentials) - plt.title('Post V') - plt.show() - - - .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations - and implications for stimulus processing[J]. Proceedings of the - National Academy of Sciences, 2012, 109(45): 18553-18558. - - Args: - alpha: float, ArrayType, Callable. Binding constant. - beta: float, ArrayType, Callable. Unbinding constant. - T: float, ArrayType, Callable. 
Transmitter concentration when synapse is triggered by - a pre-synaptic spike.. Default 1 [mM]. - T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. Default 1 [ms] - %s - """ - - supported_modes = (bm.NonBatchingMode, bm.BatchingMode) - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - alpha: Union[float, ArrayType, Callable] = 0.98, - beta: Union[float, ArrayType, Callable] = 0.18, - T: Union[float, ArrayType, Callable] = 0.5, - T_dur: Union[float, ArrayType, Callable] = 0.5, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - # parameters - self.alpha = self.init_param(alpha) - self.beta = self.init_param(beta) - self.T = self.init_param(T) - self.T_duration = self.init_param(T_dur) - - # functions - self.integral = odeint(method=method, f=self.dg) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.g = self.init_variable(bm.zeros, batch_or_mode) - self.spike_arrival_time = self.init_variable(bm.ones, batch_or_mode) - self.spike_arrival_time.fill(-1e7) - - def dg(self, g, t, TT): - return self.alpha * TT * (1 - g) - self.beta * g - - def update(self, pre_spike): - t = share.load('t') - dt = share.load('dt') - self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) - TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T - self.g.value = self.integral(self.g.value, t, TT, dt) - return self.g.value - - def return_info(self): - return self.g + .. math:: + \frac{ds}{dt} =\alpha[T](1-g)-\beta g -AMPA.__doc__ = AMPA.__doc__ % (pneu_doc,) - - -class GABAa(AMPA): - r"""GABAa synapse model. 
+ where :math:`\alpha [T]` denotes the transition probability from state :math:`(1-g)` + to state :math:`(g)`; and :math:`\beta` represents the transition probability of + the other direction. :math:`\alpha` is the binding constant. :math:`\beta` is the + unbinding constant. :math:`[T]` is the neurotransmitter concentration, and + has the duration of 0.5 ms. - **Model Descriptions** + Moreover, the post-synaptic current on the post-synaptic neuron is formulated as - GABAa synapse model has the same equation with the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_, + .. math:: - .. math:: + I_{syn} = g_{max} g (V-E) - \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\ - I_{syn}&= - g_{max} g (V - E) + where :math:`g_{max}` is the maximum conductance, and `E` is the reverse potential. - but with the difference of: + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: - - Reversal potential of synapse :math:`E` is usually low, typically -80. mV - - Activating rate constant :math:`\alpha=0.53` - - De-activating rate constant :math:`\beta=0.18` - - Transmitter concentration :math:`[T]=1\,\mu ho(\mu S)` when synapse is - triggered by a pre-synaptic spike, with the duration of 1. ms. + .. code-block:: python - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + import numpy as np + import brainpy as bp + import brainpy.math as bm - .. 
code-block:: python + import matplotlib.pyplot as plt - import numpy as np - import brainpy as bp - import brainpy.math as bm - - import matplotlib.pyplot as plt - - class GABAa(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, E=-80.): - super().__init__() - self.proj = bp.dyn.ProjAlignPreMg2( + class AMPA(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPreMg2( pre=pre, delay=delay, - syn=bp.dyn.GABAa.desc(pre.num, alpha=0.53, beta=0.18, T=1.0, T_dur=1.0), + syn=bp.dyn.AMPA.desc(pre.num, alpha=0.98, beta=0.18, T=0.5, T_dur=0.5), comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), out=bp.dyn.COBA(E=E), post=post, - ) - - - class SimpleNet(bp.DynSysGroup): - def __init__(self, E=0.): - super().__init__() - - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = AMPA(self.pre, self.post, delay=None, prob=1., g_max=1., E=E) - - def update(self): - self.pre() - self.syn() - self.post() - - # monitor the following variables - conductance = self.syn.proj.refs['syn'].g - current = self.post.sum_inputs(self.post.V) - return conductance, current, self.post.V - - - indices = np.arange(1000) # 100 ms, dt= 0.1 ms - conductances, currents, potentials = bm.for_loop(SimpleNet(E=0.).step_run, indices, progress_bar=True) - ts = indices * bm.get_dt() - - fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) - fig.add_subplot(gs[0, 0]) - plt.plot(ts, conductances) - plt.title('Syn conductance') - fig.add_subplot(gs[0, 1]) - plt.plot(ts, currents) - plt.title('Syn current') - fig.add_subplot(gs[0, 2]) - plt.plot(ts, potentials) - plt.title('Post V') - plt.show() - - - .. [1] Destexhe, Alain, and Denis Paré. 
"Impact of network activity - on the integrative properties of neocortical pyramidal neurons - in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547. - - Args: - alpha: float, ArrayType, Callable. Binding constant. Default 0.062 - beta: float, ArrayType, Callable. Unbinding constant. Default 3.57 - T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by - a pre-synaptic spike.. Default 1 [mM]. - T_dur: float, ArrayType, Callable. Transmitter concentration duration time - after being triggered. Default 1 [ms] - %s - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - alpha: Union[float, ArrayType, Callable] = 0.53, - beta: Union[float, ArrayType, Callable] = 0.18, - T: Union[float, ArrayType, Callable] = 1., - T_dur: Union[float, ArrayType, Callable] = 1., - ): - super().__init__(alpha=alpha, - beta=beta, - T=T, - T_dur=T_dur, - method=method, - name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - -GABAa.__doc__ = GABAa.__doc__ % (pneu_doc,) - - -class BioNMDA(SynDyn): - r"""Biological NMDA synapse model. - - **Model Descriptions** - - The NMDA receptor is a glutamate receptor and ion channel found in neurons. - The NMDA receptor is one of three types of ionotropic glutamate receptors, - the other two being AMPA and kainate receptors. 
+ ) + + class SimpleNet(bp.DynSysGroup): + def __init__(self, E=0.): + super().__init__() + + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = AMPA(self.pre, self.post, delay=None, prob=1., g_max=1., E=E) + + def update(self): + self.pre() + self.syn() + self.post() + + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g + current = self.post.sum_inputs(self.post.V) + return conductance, current, self.post.V + + indices = np.arange(1000) # 100 ms, dt= 0.1 ms + conductances, currents, potentials = bm.for_loop(SimpleNet(E=0.).step_run, indices, progress_bar=True) + ts = indices * bm.get_dt() + + + fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) + fig.add_subplot(gs[0, 0]) + plt.plot(ts, conductances) + plt.title('Syn conductance') + fig.add_subplot(gs[0, 1]) + plt.plot(ts, currents) + plt.title('Syn current') + fig.add_subplot(gs[0, 2]) + plt.plot(ts, potentials) + plt.title('Post V') + plt.show() + + + .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations + and implications for stimulus processing[J]. Proceedings of the + National Academy of Sciences, 2012, 109(45): 18553-18558. + + Args: + alpha: float, ArrayType, Callable. Binding constant. + beta: float, ArrayType, Callable. Unbinding constant. + T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by + a pre-synaptic spike.. Default 1 [mM]. + T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. 
Default 1 [ms] + %s + """ + + supported_modes = (bm.NonBatchingMode, bm.BatchingMode) + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + alpha: Union[float, ArrayType, Callable] = 0.98, + beta: Union[float, ArrayType, Callable] = 0.18, + T: Union[float, ArrayType, Callable] = 0.5, + T_dur: Union[float, ArrayType, Callable] = 0.5, + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) + + # parameters + self.alpha = self.init_param(alpha) + self.beta = self.init_param(beta) + self.T = self.init_param(T) + self.T_duration = self.init_param(T_dur) + + # functions + self.integral = odeint(method=method, f=self.dg) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.g = self.init_variable(bm.zeros, batch_or_mode) + self.spike_arrival_time = self.init_variable(bm.ones, batch_or_mode) + self.spike_arrival_time.fill(-1e7) + + def dg(self, g, t, TT): + return self.alpha * TT * (1 - g) - self.beta * g + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) + TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T + self.g.value = self.integral(self.g.value, t, TT, dt) + return self.g.value + + def return_info(self): + return self.g - The NMDA receptor mediated conductance depends on the postsynaptic voltage. - The voltage dependence is due to the blocking of the pore of the NMDA receptor - from the outside by a positively charged magnesium ion. The channel is - nearly completely blocked at resting potential, but the magnesium block is - relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` - that are not blocked by magnesium can be fitted to - .. 
math:: - - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V} - \frac{[{Mg}^{2+}]_{o}} {b})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, - usually 1 mM. Thus, the channel acts as a - "coincidence detector" and only once both of these conditions are met, the - channel opens and it allows positively charged ions (cations) to flow through - the cell membrane [2]_. - - If we make the approximation that the magnesium block changes - instantaneously with voltage and is independent of the gating of the channel, - the net NMDA receptor-mediated synaptic current is given by - - .. math:: +AMPA.__doc__ = AMPA.__doc__ % (pneu_doc,) - I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} - where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the - reversal potential. +class GABAa(AMPA): + r"""GABAa synapse model. + + **Model Descriptions** + + GABAa synapse model has the same equation with the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_, + + .. math:: + + \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\ + I_{syn}&= - g_{max} g (V - E) + + but with the difference of: + + - Reversal potential of synapse :math:`E` is usually low, typically -80. mV + - Activating rate constant :math:`\alpha=0.53` + - De-activating rate constant :math:`\beta=0.18` + - Transmitter concentration :math:`[T]=1\,\mu ho(\mu S)` when synapse is + triggered by a pre-synaptic spike, with the duration of 1. ms. + + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + + .. 
code-block:: python + + import numpy as np + import brainpy as bp + import brainpy.math as bm + + import matplotlib.pyplot as plt + + class GABAa(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, E=-80.): + super().__init__() + self.proj = bp.dyn.ProjAlignPreMg2( + pre=pre, + delay=delay, + syn=bp.dyn.GABAa.desc(pre.num, alpha=0.53, beta=0.18, T=1.0, T_dur=1.0), + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + out=bp.dyn.COBA(E=E), + post=post, + ) + + + class SimpleNet(bp.DynSysGroup): + def __init__(self, E=0.): + super().__init__() + + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = AMPA(self.pre, self.post, delay=None, prob=1., g_max=1., E=E) + + def update(self): + self.pre() + self.syn() + self.post() + + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g + current = self.post.sum_inputs(self.post.V) + return conductance, current, self.post.V + + + indices = np.arange(1000) # 100 ms, dt= 0.1 ms + conductances, currents, potentials = bm.for_loop(SimpleNet(E=0.).step_run, indices, progress_bar=True) + ts = indices * bm.get_dt() + + fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) + fig.add_subplot(gs[0, 0]) + plt.plot(ts, conductances) + plt.title('Syn conductance') + fig.add_subplot(gs[0, 1]) + plt.plot(ts, currents) + plt.title('Syn current') + fig.add_subplot(gs[0, 2]) + plt.plot(ts, potentials) + plt.title('Post V') + plt.show() + + + .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity + on the integrative properties of neocortical pyramidal neurons + in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547. + + Args: + alpha: float, ArrayType, Callable. Binding constant. Default 0.062 + beta: float, ArrayType, Callable. Unbinding constant. 
Default 3.57 + T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by + a pre-synaptic spike.. Default 1 [mM]. + T_dur: float, ArrayType, Callable. Transmitter concentration duration time + after being triggered. Default 1 [ms] + %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + alpha: Union[float, ArrayType, Callable] = 0.53, + beta: Union[float, ArrayType, Callable] = 0.18, + T: Union[float, ArrayType, Callable] = 1., + T_dur: Union[float, ArrayType, Callable] = 1., + ): + super().__init__(alpha=alpha, + beta=beta, + T=T, + T_dur=T_dur, + method=method, + name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) - Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_: - .. math:: +GABAa.__doc__ = GABAa.__doc__ % (pneu_doc,) - & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\ - & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x - where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and - :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x. +class BioNMDA(SynDyn): + r"""Biological NMDA synapse model. - The NMDA receptor has been thought to be very important for controlling - synaptic plasticity and mediating learning and memory functions [3]_. + **Model Descriptions** - This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + The NMDA receptor is a glutamate receptor and ion channel found in neurons. + The NMDA receptor is one of three types of ionotropic glutamate receptors, + the other two being AMPA and kainate receptors. - .. code-block:: python + The NMDA receptor mediated conductance depends on the postsynaptic voltage. 
+ The voltage dependence is due to the blocking of the pore of the NMDA receptor + from the outside by a positively charged magnesium ion. The channel is + nearly completely blocked at resting potential, but the magnesium block is + relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` + that are not blocked by magnesium can be fitted to - import numpy as np - import brainpy as bp - import brainpy.math as bm + .. math:: - import matplotlib.pyplot as plt + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V} + \frac{[{Mg}^{2+}]_{o}} {b})^{-1} + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, + usually 1 mM. Thus, the channel acts as a + "coincidence detector" and only once both of these conditions are met, the + channel opens and it allows positively charged ions (cations) to flow through + the cell membrane [2]_. - class BioNMDA(bp.Projection): - def __init__(self, pre, post, delay, prob, g_max, E=0.): - super().__init__() - self.proj = bp.dyn.ProjAlignPreMg2( - pre=pre, - delay=delay, - syn=bp.dyn.BioNMDA.desc(pre.num, alpha1=2, beta1=0.01, alpha2=0.2, beta2=0.5, T=1, T_dur=1), - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - out=bp.dyn.COBA(E=E), - post=post, - ) - - class SimpleNet(bp.DynSysGroup): - def __init__(self, E=0.): - super().__init__() - - self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) - self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Constant(-60.)) - self.syn = BioNMDA(self.pre, self.post, delay=None, prob=1., g_max=1., E=E) - - def update(self): - self.pre() - self.syn() - self.post() - - # monitor the following variables - conductance = self.syn.proj.refs['syn'].g - current = self.post.sum_inputs(self.post.V) - return conductance, current, self.post.V - - - indices = np.arange(1000) # 100 ms, dt= 0.1 ms - conductances, currents, potentials = 
bm.for_loop(SimpleNet(E=0.).step_run, indices, progress_bar=True) - ts = indices * bm.get_dt() - - fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) - fig.add_subplot(gs[0, 0]) - plt.plot(ts, conductances) - plt.title('Syn conductance') - fig.add_subplot(gs[0, 1]) - plt.plot(ts, currents) - plt.title('Syn current') - fig.add_subplot(gs[0, 2]) - plt.plot(ts, potentials) - plt.title('Post V') - plt.show() - - .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M]. - Springer New York, 2010: 162. - .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and - Eric Gouaux. "Subunit arrangement and function in NMDA receptors." - Nature 438, no. 7065 (2005): 185-192. - .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New - England journal of medicine, 361(3), p.302. - .. [4] https://en.wikipedia.org/wiki/NMDA_receptor - - - Args: - alpha1: float, ArrayType, Callable. The conversion rate of g from inactive to active. Default 2 ms^-1. - beta1: float, ArrayType, Callable. The conversion rate of g from active to inactive. Default 0.01 ms^-1. - alpha2: float, ArrayType, Callable. The conversion rate of x from inactive to active. Default 1 ms^-1. - beta2: float, ArrayType, Callable. The conversion rate of x from active to inactive. Default 0.5 ms^-1. - T: float, ArrayType, Callable. Transmitter concentration when synapse is - triggered by a pre-synaptic spike. Default 1 [mM]. - T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. 
Default 1 [ms] - %s - """ - supported_modes = (bm.NonBatchingMode, bm.BatchingMode) - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - - # synapse parameters - alpha1: Union[float, ArrayType, Callable] = 2., - beta1: Union[float, ArrayType, Callable] = 0.01, - alpha2: Union[float, ArrayType, Callable] = 1., - beta2: Union[float, ArrayType, Callable] = 0.5, - T: Union[float, ArrayType, Callable] = 1., - T_dur: Union[float, ArrayType, Callable] = 0.5, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - # parameters - self.beta1 = self.init_param(beta1) - self.beta2 = self.init_param(beta2) - self.alpha1 = self.init_param(alpha1) - self.alpha2 = self.init_param(alpha2) - self.T = self.init_param(T) - self.T_dur = self.init_param(T_dur) - - # integral - self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.g = self.init_variable(bm.zeros, batch_or_mode) - self.x = self.init_variable(bm.zeros, batch_or_mode) - self.spike_arrival_time = self.init_variable(bm.ones, batch_or_mode) - self.spike_arrival_time.fill(-1e7) - - def dg(self, g, t, x): - return self.alpha1 * x * (1 - g) - self.beta1 * g - - def dx(self, x, t, T): - return self.alpha2 * T * (1 - x) - self.beta2 * x - - def update(self, pre_spike): - t = share.load('t') - dt = share.load('dt') - self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) - T = ((t - self.spike_arrival_time) < self.T_dur) * self.T - self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, T, dt) - return self.g.value - - def return_info(self): - return self.g + If we make the approximation that the magnesium block changes + instantaneously with voltage and is 
independent of the gating of the channel, + the net NMDA receptor-mediated synaptic current is given by + + .. math:: + + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} + + where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the + reversal potential. + + Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_: + + .. math:: + + & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\ + & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x + + where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and + :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x. + + The NMDA receptor has been thought to be very important for controlling + synaptic plasticity and mediating learning and memory functions [3]_. + + This module can be used with interface ``brainpy.dyn.ProjAlignPreMg2``, as shown in the following example: + + .. code-block:: python + + import numpy as np + import brainpy as bp + import brainpy.math as bm + + import matplotlib.pyplot as plt + + + class BioNMDA(bp.Projection): + def __init__(self, pre, post, delay, prob, g_max, E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPreMg2( + pre=pre, + delay=delay, + syn=bp.dyn.BioNMDA.desc(pre.num, alpha1=2, beta1=0.01, alpha2=0.2, beta2=0.5, T=1, T_dur=1), + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + out=bp.dyn.COBA(E=E), + post=post, + ) + + class SimpleNet(bp.DynSysGroup): + def __init__(self, E=0.): + super().__init__() + + self.pre = bp.dyn.SpikeTimeGroup(1, indices=(0, 0, 0, 0), times=(10., 30., 50., 70.)) + self.post = bp.dyn.LifRef(1, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Constant(-60.)) + self.syn = BioNMDA(self.pre, self.post, delay=None, prob=1., g_max=1., E=E) + + def update(self): + self.pre() + self.syn() + self.post() + + # monitor the following variables + conductance = self.syn.proj.refs['syn'].g + current 
= self.post.sum_inputs(self.post.V) + return conductance, current, self.post.V + + + indices = np.arange(1000) # 100 ms, dt= 0.1 ms + conductances, currents, potentials = bm.for_loop(SimpleNet(E=0.).step_run, indices, progress_bar=True) + ts = indices * bm.get_dt() + + fig, gs = bp.visualize.get_figure(1, 3, 3.5, 4) + fig.add_subplot(gs[0, 0]) + plt.plot(ts, conductances) + plt.title('Syn conductance') + fig.add_subplot(gs[0, 1]) + plt.plot(ts, currents) + plt.title('Syn current') + fig.add_subplot(gs[0, 2]) + plt.plot(ts, potentials) + plt.title('Post V') + plt.show() + + .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M]. + Springer New York, 2010: 162. + .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and + Eric Gouaux. "Subunit arrangement and function in NMDA receptors." + Nature 438, no. 7065 (2005): 185-192. + .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New + England journal of medicine, 361(3), p.302. + .. [4] https://en.wikipedia.org/wiki/NMDA_receptor + + + Args: + alpha1: float, ArrayType, Callable. The conversion rate of g from inactive to active. Default 2 ms^-1. + beta1: float, ArrayType, Callable. The conversion rate of g from active to inactive. Default 0.01 ms^-1. + alpha2: float, ArrayType, Callable. The conversion rate of x from inactive to active. Default 1 ms^-1. + beta2: float, ArrayType, Callable. The conversion rate of x from active to inactive. Default 0.5 ms^-1. + T: float, ArrayType, Callable. Transmitter concentration when synapse is + triggered by a pre-synaptic spike. Default 1 [mM]. + T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. 
Default 1 [ms] + %s + """ + supported_modes = (bm.NonBatchingMode, bm.BatchingMode) + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + sharding: Optional[Sequence[str]] = None, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + alpha1: Union[float, ArrayType, Callable] = 2., + beta1: Union[float, ArrayType, Callable] = 0.01, + alpha2: Union[float, ArrayType, Callable] = 1., + beta2: Union[float, ArrayType, Callable] = 0.5, + T: Union[float, ArrayType, Callable] = 1., + T_dur: Union[float, ArrayType, Callable] = 0.5, + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + sharding=sharding) + + # parameters + self.beta1 = self.init_param(beta1) + self.beta2 = self.init_param(beta2) + self.alpha1 = self.init_param(alpha1) + self.alpha2 = self.init_param(alpha2) + self.T = self.init_param(T) + self.T_dur = self.init_param(T_dur) + + # integral + self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.g = self.init_variable(bm.zeros, batch_or_mode) + self.x = self.init_variable(bm.zeros, batch_or_mode) + self.spike_arrival_time = self.init_variable(bm.ones, batch_or_mode) + self.spike_arrival_time.fill(-1e7) + + def dg(self, g, t, x): + return self.alpha1 * x * (1 - g) - self.beta1 * g + + def dx(self, x, t, T): + return self.alpha2 * T * (1 - x) - self.beta2 * x + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) + T = ((t - self.spike_arrival_time) < self.T_dur) * self.T + self.g.value, self.x.value = self.integral(self.g.value, self.x.value, t, T, dt) + return self.g.value + + def return_info(self): + return self.g BioNMDA.__doc__ = BioNMDA.__doc__ % (pneu_doc,) diff --git 
a/brainpy/_src/dyn/synapses/delay_couplings.py b/brainpy/_src/dyn/synapses/delay_couplings.py index 0e4918b91..097936afd 100644 --- a/brainpy/_src/dyn/synapses/delay_couplings.py +++ b/brainpy/_src/dyn/synapses/delay_couplings.py @@ -13,285 +13,285 @@ from brainpy.types import ArrayType __all__ = [ - 'DelayCoupling', - 'DiffusiveCoupling', - 'AdditiveCoupling', + 'DelayCoupling', + 'DiffusiveCoupling', + 'AdditiveCoupling', ] class DelayCoupling(Projection): - """Delay coupling. - - Parameters:: - - delay_var: Variable - The delay variable. - var_to_output: Variable, sequence of Variable - The target variables to output. - conn_mat: ArrayType - The connection matrix. - required_shape: sequence of int - The required shape of `(pre, post)`. - delay_steps: int, ArrayType - The matrix of delay time steps. Must be int. - initial_delay_data: Initializer, Callable - The initializer of the initial delay data. - """ - - def __init__( - self, - delay_var: bm.Variable, - var_to_output: Union[bm.Variable, Sequence[bm.Variable]], - conn_mat: ArrayType, - required_shape: Tuple[int, ...], - delay_steps: Optional[Union[int, ArrayType, Callable]] = None, - initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # delay variable - if not isinstance(delay_var, bm.Variable): - raise ValueError(f'"delay_var" must be an instance of brainpy.math.Variable. 
' - f'But we got {type(delay_var)}') - self.delay_var = delay_var - - # output variables - if isinstance(var_to_output, bm.Variable): - var_to_output = [var_to_output] - is_sequence(var_to_output, 'output_var', elem_type=bm.Variable, allow_none=False) - self.output_var = var_to_output - - # Connection matrix - self.conn_mat = conn_mat - if self.conn_mat.shape != required_shape: - raise ValueError(f'we expect the structural connection matrix has the shape of ' - f'(pre.num, post.num), i.e., {required_shape}, ' - f'while we got {self.conn_mat.shape}.') - - # Delay matrix - if delay_steps is None: - self.delay_steps = None - self.delay_type = 'none' - num_delay_step = None - elif callable(delay_steps): - delay_steps = delay_steps(required_shape) - if delay_steps.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}') - self.delay_steps = delay_steps - self.delay_type = 'array' - num_delay_step = self.delay_steps.max() - elif isinstance(delay_steps, (bm.BaseArray, jnp.ndarray)): - if delay_steps.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}') - if delay_steps.ndim == 0: - self.delay_type = 'int' - else: - self.delay_type = 'array' - if delay_steps.shape != required_shape: - raise ValueError(f'we expect the delay matrix has the shape of ' - f'(pre.num, post.num), i.e., {required_shape}. 
' - f'While we got {delay_steps.shape}.') - self.delay_steps = delay_steps - num_delay_step = self.delay_steps.max() - elif isinstance(delay_steps, int): - self.delay_steps = delay_steps - num_delay_step = delay_steps - self.delay_type = 'int' - else: - raise ValueError(f'Unknown type of delay steps: {type(delay_steps)}') - - # delay variables - _ = self.register_delay(f'delay_{id(delay_var)}', - delay_step=num_delay_step, - delay_target=delay_var, - initial_delay_data=initial_delay_data) - - def reset_state(self, batch_size=None): - pass + """Delay coupling. + + Parameters:: + + delay_var: Variable + The delay variable. + var_to_output: Variable, sequence of Variable + The target variables to output. + conn_mat: ArrayType + The connection matrix. + required_shape: sequence of int + The required shape of `(pre, post)`. + delay_steps: int, ArrayType + The matrix of delay time steps. Must be int. + initial_delay_data: Initializer, Callable + The initializer of the initial delay data. + """ + + def __init__( + self, + delay_var: bm.Variable, + var_to_output: Union[bm.Variable, Sequence[bm.Variable]], + conn_mat: ArrayType, + required_shape: Tuple[int, ...], + delay_steps: Optional[Union[int, ArrayType, Callable]] = None, + initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # delay variable + if not isinstance(delay_var, bm.Variable): + raise ValueError(f'"delay_var" must be an instance of brainpy.math.Variable. 
' + f'But we got {type(delay_var)}') + self.delay_var = delay_var + + # output variables + if isinstance(var_to_output, bm.Variable): + var_to_output = [var_to_output] + is_sequence(var_to_output, 'output_var', elem_type=bm.Variable, allow_none=False) + self.output_var = var_to_output + + # Connection matrix + self.conn_mat = conn_mat + if self.conn_mat.shape != required_shape: + raise ValueError(f'we expect the structural connection matrix has the shape of ' + f'(pre.num, post.num), i.e., {required_shape}, ' + f'while we got {self.conn_mat.shape}.') + + # Delay matrix + if delay_steps is None: + self.delay_steps = None + self.delay_type = 'none' + num_delay_step = None + elif callable(delay_steps): + delay_steps = delay_steps(required_shape) + if delay_steps.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: + raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}') + self.delay_steps = delay_steps + self.delay_type = 'array' + num_delay_step = self.delay_steps.max() + elif isinstance(delay_steps, (bm.BaseArray, jnp.ndarray)): + if delay_steps.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: + raise ValueError(f'"delay_steps" must be integer typed. But we got {delay_steps.dtype}') + if delay_steps.ndim == 0: + self.delay_type = 'int' + else: + self.delay_type = 'array' + if delay_steps.shape != required_shape: + raise ValueError(f'we expect the delay matrix has the shape of ' + f'(pre.num, post.num), i.e., {required_shape}. 
' + f'While we got {delay_steps.shape}.') + self.delay_steps = delay_steps + num_delay_step = self.delay_steps.max() + elif isinstance(delay_steps, int): + self.delay_steps = delay_steps + num_delay_step = delay_steps + self.delay_type = 'int' + else: + raise ValueError(f'Unknown type of delay steps: {type(delay_steps)}') + + # delay variables + _ = self.register_delay(f'delay_{id(delay_var)}', + delay_step=num_delay_step, + delay_target=delay_var, + initial_delay_data=initial_delay_data) + + def reset_state(self, batch_size=None): + pass class DiffusiveCoupling(DelayCoupling): - """Diffusive coupling. - - This class simulates the model of:: - - coupling = g * (delayed_coupling_var1 - coupling_var2) - target_var += coupling - - - Examples:: - - >>> import brainpy as bp - >>> from brainpy import rates - >>> areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn') - >>> conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input, - >>> conn_mat=Cmat, delay_steps=Dmat, - >>> initial_delay_data=bp.init.Uniform(0, 0.05)) - >>> net = bp.Network(areas, conn) - - Parameters:: - - coupling_var1: Variable - The first coupling variable, used for delay. - coupling_var2: Variable - Another coupling variable. - var_to_output: Variable, sequence of Variable - The target variables to output. - conn_mat: ArrayType - The connection matrix. - delay_steps: int, ArrayType - The matrix of delay time steps. Must be int. - initial_delay_data: Initializer, Callable - The initializer of the initial delay data. - name: str - The name of the model. 
- """ - - def __init__( - self, - coupling_var1: bm.Variable, - coupling_var2: bm.Variable, - var_to_output: Union[bm.Variable, Sequence[bm.Variable]], - conn_mat: ArrayType, - delay_steps: Optional[Union[int, ArrayType, Initializer, Callable]] = None, - initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None, - name: str = None, - mode: bm.Mode = None, - ): - if not isinstance(coupling_var1, bm.Variable): - raise ValueError(f'"coupling_var1" must be an instance of brainpy.math.Variable. ' - f'But we got {type(coupling_var1)}') - if not isinstance(coupling_var2, bm.Variable): - raise ValueError(f'"coupling_var2" must be an instance of brainpy.math.Variable. ' - f'But we got {type(coupling_var2)}') - if jnp.ndim(coupling_var1) != 1: - raise ValueError(f'Only support 1d vector of coupling variable. ' - f'But we got {jnp.ndim(coupling_var1)}') - if jnp.ndim(coupling_var2) != 1: - raise ValueError(f'Only support 1d vector of coupling variable. ' - f'But we got {jnp.ndim(coupling_var2)}') - - super().__init__( - delay_var=coupling_var1, - var_to_output=var_to_output, - conn_mat=conn_mat, - required_shape=(coupling_var1.size, coupling_var2.size), - delay_steps=delay_steps, - initial_delay_data=initial_delay_data, - name=name, - mode=mode, - ) - - self.coupling_var1 = coupling_var1 - self.coupling_var2 = coupling_var2 - - def update(self): - # delays - axis = self.coupling_var1.ndim - delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}') - if self.delay_steps is None: - diffusive = (jnp.expand_dims(self.coupling_var1.value, axis=axis) - - jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) - diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) - elif self.delay_type == 'array': - if isinstance(self.mode, bm.TrainingMode): - indices = (slice(None, None, None), jnp.arange(self.coupling_var1.size),) - else: - indices = (jnp.arange(self.coupling_var1.size),) - f = vmap(lambda steps: delay_var.retrieve(steps, *indices), 
in_axes=1) # (..., pre.num) - delays = f(self.delay_steps) # (..., post.num, pre.num) - diffusive = (jnp.moveaxis(bm.as_jax(delays), axis - 1, axis) - - jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num) - diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) - elif self.delay_type == 'int': - delayed_data = delay_var.retrieve(self.delay_steps) # (..., pre.num) - diffusive = (jnp.expand_dims(delayed_data, axis=axis) - - jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num) - diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) - else: - raise ValueError(f'Unknown delay type {self.delay_type}') - - # output to target variable - for target in self.output_var: - target.value += diffusive + """Diffusive coupling. + + This class simulates the model of:: + + coupling = g * (delayed_coupling_var1 - coupling_var2) + target_var += coupling + + + Examples:: + + >>> import brainpy as bp + >>> from brainpy import rates + >>> areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn') + >>> conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input, + >>> conn_mat=Cmat, delay_steps=Dmat, + >>> initial_delay_data=bp.init.Uniform(0, 0.05)) + >>> net = bp.Network(areas, conn) + + Parameters:: + + coupling_var1: Variable + The first coupling variable, used for delay. + coupling_var2: Variable + Another coupling variable. + var_to_output: Variable, sequence of Variable + The target variables to output. + conn_mat: ArrayType + The connection matrix. + delay_steps: int, ArrayType + The matrix of delay time steps. Must be int. + initial_delay_data: Initializer, Callable + The initializer of the initial delay data. + name: str + The name of the model. 
+ """ + + def __init__( + self, + coupling_var1: bm.Variable, + coupling_var2: bm.Variable, + var_to_output: Union[bm.Variable, Sequence[bm.Variable]], + conn_mat: ArrayType, + delay_steps: Optional[Union[int, ArrayType, Initializer, Callable]] = None, + initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None, + name: str = None, + mode: bm.Mode = None, + ): + if not isinstance(coupling_var1, bm.Variable): + raise ValueError(f'"coupling_var1" must be an instance of brainpy.math.Variable. ' + f'But we got {type(coupling_var1)}') + if not isinstance(coupling_var2, bm.Variable): + raise ValueError(f'"coupling_var2" must be an instance of brainpy.math.Variable. ' + f'But we got {type(coupling_var2)}') + if jnp.ndim(coupling_var1) != 1: + raise ValueError(f'Only support 1d vector of coupling variable. ' + f'But we got {jnp.ndim(coupling_var1)}') + if jnp.ndim(coupling_var2) != 1: + raise ValueError(f'Only support 1d vector of coupling variable. ' + f'But we got {jnp.ndim(coupling_var2)}') + + super().__init__( + delay_var=coupling_var1, + var_to_output=var_to_output, + conn_mat=conn_mat, + required_shape=(coupling_var1.size, coupling_var2.size), + delay_steps=delay_steps, + initial_delay_data=initial_delay_data, + name=name, + mode=mode, + ) + + self.coupling_var1 = coupling_var1 + self.coupling_var2 = coupling_var2 + + def update(self): + # delays + axis = self.coupling_var1.ndim + delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}') + if self.delay_steps is None: + diffusive = (jnp.expand_dims(self.coupling_var1.value, axis=axis) - + jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) + diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) + elif self.delay_type == 'array': + if isinstance(self.mode, bm.TrainingMode): + indices = (slice(None, None, None), jnp.arange(self.coupling_var1.size),) + else: + indices = (jnp.arange(self.coupling_var1.size),) + f = vmap(lambda steps: delay_var.retrieve(steps, *indices), 
in_axes=1) # (..., pre.num) + delays = f(self.delay_steps) # (..., post.num, pre.num) + diffusive = (jnp.moveaxis(bm.as_jax(delays), axis - 1, axis) - + jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num) + diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) + elif self.delay_type == 'int': + delayed_data = delay_var.retrieve(self.delay_steps) # (..., pre.num) + diffusive = (jnp.expand_dims(delayed_data, axis=axis) - + jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num) + diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) + else: + raise ValueError(f'Unknown delay type {self.delay_type}') + + # output to target variable + for target in self.output_var: + target.value += diffusive class AdditiveCoupling(DelayCoupling): - """Additive coupling. - - This class simulates the model of:: - - coupling = g * delayed_coupling_var - target_var += coupling - - Parameters:: - - coupling_var: Variable - The coupling variable, used for delay. - var_to_output: Variable, sequence of Variable - The target variables to output. - conn_mat: ArrayType - The connection matrix. - delay_steps: int, ArrayType - The matrix of delay time steps. Must be int. - initial_delay_data: Initializer, Callable - The initializer of the initial delay data. - name: str - The name of the model. - """ - - def __init__( - self, - coupling_var: bm.Variable, - var_to_output: Union[bm.Variable, Sequence[bm.Variable]], - conn_mat: ArrayType, - delay_steps: Optional[Union[int, ArrayType, Initializer, Callable]] = None, - initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None, - name: str = None, - mode: bm.Mode = None, - ): - if not isinstance(coupling_var, bm.Variable): - raise ValueError(f'"coupling_var" must be an instance of brainpy.math.Variable. ' - f'But we got {type(coupling_var)}') - if jnp.ndim(coupling_var) != 1: - raise ValueError(f'Only support 1d vector of coupling variable. 
' - f'But we got {jnp.ndim(coupling_var)}') - - super().__init__( - delay_var=coupling_var, - var_to_output=var_to_output, - conn_mat=conn_mat, - required_shape=(coupling_var.size, coupling_var.size), - delay_steps=delay_steps, - initial_delay_data=initial_delay_data, - name=name, - mode=mode, - ) - - self.coupling_var = coupling_var - - def update(self): - # delay function - axis = self.coupling_var.ndim - delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}') - if self.delay_steps is None: - additive = self.coupling_var @ self.conn_mat - elif self.delay_type == 'array': - if isinstance(self.mode, bm.TrainingMode): - indices = (slice(None, None, None), jnp.arange(self.coupling_var.size),) - else: - indices = (jnp.arange(self.coupling_var.size),) - f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1) # (.., pre.num,) - delays = f(self.delay_steps) # (..., post.num, pre.num) - additive = (self.conn_mat * jnp.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1) - elif self.delay_type == 'int': - delayed_var = delay_var.retrieve(self.delay_steps) # (..., pre.num) - additive = delayed_var @ self.conn_mat - else: - raise ValueError - - # output to target variable - for target in self.output_var: - target.value += additive + """Additive coupling. + + This class simulates the model of:: + + coupling = g * delayed_coupling_var + target_var += coupling + + Parameters:: + + coupling_var: Variable + The coupling variable, used for delay. + var_to_output: Variable, sequence of Variable + The target variables to output. + conn_mat: ArrayType + The connection matrix. + delay_steps: int, ArrayType + The matrix of delay time steps. Must be int. + initial_delay_data: Initializer, Callable + The initializer of the initial delay data. + name: str + The name of the model. 
+ """ + + def __init__( + self, + coupling_var: bm.Variable, + var_to_output: Union[bm.Variable, Sequence[bm.Variable]], + conn_mat: ArrayType, + delay_steps: Optional[Union[int, ArrayType, Initializer, Callable]] = None, + initial_delay_data: Union[Initializer, Callable, ArrayType, float, int, bool] = None, + name: str = None, + mode: bm.Mode = None, + ): + if not isinstance(coupling_var, bm.Variable): + raise ValueError(f'"coupling_var" must be an instance of brainpy.math.Variable. ' + f'But we got {type(coupling_var)}') + if jnp.ndim(coupling_var) != 1: + raise ValueError(f'Only support 1d vector of coupling variable. ' + f'But we got {jnp.ndim(coupling_var)}') + + super().__init__( + delay_var=coupling_var, + var_to_output=var_to_output, + conn_mat=conn_mat, + required_shape=(coupling_var.size, coupling_var.size), + delay_steps=delay_steps, + initial_delay_data=initial_delay_data, + name=name, + mode=mode, + ) + + self.coupling_var = coupling_var + + def update(self): + # delay function + axis = self.coupling_var.ndim + delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}') + if self.delay_steps is None: + additive = self.coupling_var @ self.conn_mat + elif self.delay_type == 'array': + if isinstance(self.mode, bm.TrainingMode): + indices = (slice(None, None, None), jnp.arange(self.coupling_var.size),) + else: + indices = (jnp.arange(self.coupling_var.size),) + f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1) # (.., pre.num,) + delays = f(self.delay_steps) # (..., post.num, pre.num) + additive = (self.conn_mat * jnp.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1) + elif self.delay_type == 'int': + delayed_var = delay_var.retrieve(self.delay_steps) # (..., pre.num) + additive = delayed_var @ self.conn_mat + else: + raise ValueError + + # output to target variable + for target in self.output_var: + target.value += additive diff --git a/brainpy/_src/dyn/synapses/tests/test_abstract_models.py 
b/brainpy/_src/dyn/synapses/tests/test_abstract_models.py index 7bdd34922..b3972eb72 100644 --- a/brainpy/_src/dyn/synapses/tests/test_abstract_models.py +++ b/brainpy/_src/dyn/synapses/tests/test_abstract_models.py @@ -9,79 +9,79 @@ class TestDualExpon(unittest.TestCase): - def test_dual_expon(self): - bm.set(dt=0.01) + def test_dual_expon(self): + bm.set(dt=0.01) - class Net(bp.DynSysGroup): - def __init__(self, tau_r, tau_d, n_spk): - super().__init__() + class Net(bp.DynSysGroup): + def __init__(self, tau_r, tau_d, n_spk): + super().__init__() - self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk)) - self.proj = bp.dyn.DualExpon(1, tau_rise=tau_r, tau_decay=tau_d) + self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk)) + self.proj = bp.dyn.DualExpon(1, tau_rise=tau_r, tau_decay=tau_d) - def update(self): - self.proj(self.inp()) - return self.proj.h.value, self.proj.g.value + def update(self): + self.proj(self.inp()) + return self.proj.h.value, self.proj.g.value - for tau_r, tau_d in [(1., 10.), (10., 100.)]: - for n_spk in [1, 10, 100]: - net = Net(tau_r, tau_d, n_spk) - indices = bm.as_numpy(bm.arange(1000)) - hs, gs = bm.for_loop(net.step_run, indices, progress_bar=True) + for tau_r, tau_d in [(1., 10.), (10., 100.)]: + for n_spk in [1, 10, 100]: + net = Net(tau_r, tau_d, n_spk) + indices = bm.as_numpy(bm.arange(1000)) + hs, gs = bm.for_loop(net.step_run, indices, progress_bar=True) - bp.visualize.line_plot(indices * bm.get_dt(), hs, legend='h') - bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show) - plt.close('all') + bp.visualize.line_plot(indices * bm.get_dt(), hs, legend='h') + bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show) + plt.close('all') + def test_dual_expon_v2(self): + class Net(bp.DynSysGroup): + def __init__(self, tau_r, tau_d, n_spk): + super().__init__() - def test_dual_expon_v2(self): - class Net(bp.DynSysGroup): - def 
__init__(self, tau_r, tau_d, n_spk): - super().__init__() + self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk)) + self.syn = bp.dyn.DualExponV2(1, tau_rise=tau_r, tau_decay=tau_d) - self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk)) - self.syn = bp.dyn.DualExponV2(1, tau_rise=tau_r, tau_decay=tau_d) + def update(self): + return self.syn(self.inp()) - def update(self): - return self.syn(self.inp()) + for tau_r, tau_d in [(1., 10.), (5., 50.), (10., 100.)]: + for n_spk in [1, 10, 100]: + net = Net(tau_r, tau_d, n_spk) + indices = bm.as_numpy(bm.arange(1000)) + gs = bm.for_loop(net.step_run, indices, progress_bar=True) - for tau_r, tau_d in [(1., 10.), (5., 50.), (10., 100.)]: - for n_spk in [1, 10, 100]: - net = Net(tau_r, tau_d, n_spk) - indices = bm.as_numpy(bm.arange(1000)) - gs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show) - bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show) + plt.close('all') - plt.close('all') class TestAlpha(unittest.TestCase): - def test_v1(self): - class Net(bp.DynSysGroup): - def __init__(self, tau, n_spk): - super().__init__() - - self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk)) - self.neu = bp.dyn.LifRef(1) - self.proj = bp.dyn.FullProjAlignPreDS(self.inp, None, - bp.dyn.Alpha(1, tau_decay=tau), - bp.dnn.AllToAll(1, 1, 1.), - bp.dyn.CUBA(), self.neu) - - def update(self): - self.inp() - self.proj() - self.neu() - return self.proj.syn.h.value, self.proj.syn.g.value - - for tau in [10.]: - for n_spk in [1, 10, 50]: - net = Net(tau=tau, n_spk=n_spk) - indices = bm.as_numpy(bm.arange(1000)) - hs, gs = bm.for_loop(net.step_run, indices, progress_bar=True) - - bp.visualize.line_plot(indices * bm.get_dt(), hs, legend='h') - bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show) - - 
plt.close('all') + def test_v1(self): + class Net(bp.DynSysGroup): + def __init__(self, tau, n_spk): + super().__init__() + + self.inp = bp.dyn.SpikeTimeGroup(1, bm.zeros(n_spk, dtype=int), bm.linspace(2., 100., n_spk)) + self.neu = bp.dyn.LifRef(1) + self.proj = bp.dyn.FullProjAlignPreDS(self.inp, None, + bp.dyn.Alpha(1, tau_decay=tau), + bp.dnn.AllToAll(1, 1, 1.), + bp.dyn.CUBA(), self.neu) + + def update(self): + self.inp() + self.proj() + self.neu() + return self.proj.syn.h.value, self.proj.syn.g.value + + for tau in [10.]: + for n_spk in [1, 10, 50]: + net = Net(tau=tau, n_spk=n_spk) + indices = bm.as_numpy(bm.arange(1000)) + hs, gs = bm.for_loop(net.step_run, indices, progress_bar=True) + + bp.visualize.line_plot(indices * bm.get_dt(), hs, legend='h') + bp.visualize.line_plot(indices * bm.get_dt(), gs, legend='g', show=show) + + plt.close('all') diff --git a/brainpy/_src/dyn/synapses/tests/test_delay_couplings.py b/brainpy/_src/dyn/synapses/tests/test_delay_couplings.py index de42fbb7e..f5839cadf 100644 --- a/brainpy/_src/dyn/synapses/tests/test_delay_couplings.py +++ b/brainpy/_src/dyn/synapses/tests/test_delay_couplings.py @@ -3,43 +3,42 @@ from absl.testing import parameterized -import brainpy import brainpy as bp import brainpy.math as bm - class Test_delay_couplings(parameterized.TestCase): - def test_DiffusiveCoupling(self): - bm.random.seed() - bm.set_dt(0.1) - - areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn1') - conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input, - conn_mat=bp.conn.All2All(pre=areas.num, post=areas.num).require('conn_mat'), - initial_delay_data=bp.init.Uniform(0, 0.05)) - net = bp.Network(areas, conn) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['fhn1.x'], - inputs=('fhn1.input', 35.)) - runner(10.) 
- self.assertTupleEqual(runner.mon['fhn1.x'].shape, (100, 80)) - - def test_AdditiveCoupling(self): - bm.random.seed() - bm.set_dt(0.1) - - areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn2') - conn = bp.synapses.AdditiveCoupling(areas.x, areas.input, - conn_mat=bp.conn.All2All(pre=areas.num, post=areas.num).require('conn_mat'), - initial_delay_data=bp.init.Uniform(0, 0.05)) - net = bp.Network(areas, conn) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['fhn2.x'], - inputs=('fhn2.input', 35.)) - runner(10.) - self.assertTupleEqual(runner.mon['fhn2.x'].shape, (100, 80)) + def test_DiffusiveCoupling(self): + bm.random.seed() + bm.set_dt(0.1) + + areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn1') + conn = bp.synapses.DiffusiveCoupling(areas.x, areas.x, areas.input, + conn_mat=bp.conn.All2All(pre=areas.num, post=areas.num).require( + 'conn_mat'), + initial_delay_data=bp.init.Uniform(0, 0.05)) + net = bp.Network(areas, conn) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['fhn1.x'], + inputs=('fhn1.input', 35.)) + runner(10.) + self.assertTupleEqual(runner.mon['fhn1.x'].shape, (100, 80)) + + def test_AdditiveCoupling(self): + bm.random.seed() + bm.set_dt(0.1) + + areas = bp.rates.FHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn2') + conn = bp.synapses.AdditiveCoupling(areas.x, areas.input, + conn_mat=bp.conn.All2All(pre=areas.num, post=areas.num).require('conn_mat'), + initial_delay_data=bp.init.Uniform(0, 0.05)) + net = bp.Network(areas, conn) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['fhn2.x'], + inputs=('fhn2.input', 35.)) + runner(10.) 
+ self.assertTupleEqual(runner.mon['fhn2.x'].shape, (100, 80)) diff --git a/brainpy/_src/dyn/utils.py b/brainpy/_src/dyn/utils.py index 0af1d4532..51d43acb3 100644 --- a/brainpy/_src/dyn/utils.py +++ b/brainpy/_src/dyn/utils.py @@ -1,16 +1,17 @@ -from typing import Optional, Union +from typing import Optional + import brainpy.math as bm __all__ = [ - 'get_spk_type', + 'get_spk_type', ] def get_spk_type(spk_type: Optional[type] = None, mode: Optional[bm.Mode] = None): - if mode is None: - return bm.bool - elif isinstance(mode, bm.TrainingMode): - return bm.float_ if (spk_type is None) else spk_type - else: - assert isinstance(mode, bm.Mode) - return bm.bool if (spk_type is None) else spk_type + if mode is None: + return bm.bool + elif isinstance(mode, bm.TrainingMode): + return bm.float_ if (spk_type is None) else spk_type + else: + assert isinstance(mode, bm.Mode) + return bm.bool if (spk_type is None) else spk_type diff --git a/brainpy/_src/dynold/experimental/abstract_synapses.py b/brainpy/_src/dynold/experimental/abstract_synapses.py index 012b1e76d..c67ea4c94 100644 --- a/brainpy/_src/dynold/experimental/abstract_synapses.py +++ b/brainpy/_src/dynold/experimental/abstract_synapses.py @@ -15,384 +15,384 @@ class Exponential(SynConnNS): - r"""Exponential decay synapse model. - - **Model Descriptions** - - The single exponential decay synapse model assumes the release of neurotransmitter, - its diffusion across the cleft, the receptor binding, and channel opening all happen - very quickly, so that the channels instantaneously jump from the closed to the open state. - Therefore, its expression is given by - - .. math:: - - g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau} - - where :math:`\tau_{delay}` is the time constant of the synaptic state decay, - :math:`t_0` is the time of the pre-synaptic spike, - :math:`g_{\mathrm{max}}` is the maximal conductance. - - Accordingly, the differential form of the exponential synapse is given by - - .. 
math:: - - \begin{aligned} - & g_{\mathrm{syn}}(t) = g_{max} g * \mathrm{STP} \\ - & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}). - \end{aligned} - - where :math:`\mathrm{STP}` is used to model the short-term plasticity effect. - - Parameters:: - - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `sparse`. - tau: float, ArrayType - The time constant of decay. [ms] - g_max: float, ArrayType, Initializer, Callable - The synaptic strength (the maximum conductance). Default is 1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. - - References:: - - .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. - "The Synapse." Principles of Computational Modelling in Neuroscience. - Cambridge: Cambridge UP, 2011. 172-95. Print. - - """ - - def __init__( - self, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - out: Optional[SynOutNS] = None, - stp: Optional[SynSTPNS] = None, - comp_method: str = 'sparse', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - tau: Union[float, ArrayType] = 8.0, - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(conn=conn, - out=out, - stp=stp, - name=name, - mode=mode) - - # parameters - self.comp_method = comp_method - self.tau = is_float(tau, allow_int=True) - - # connections and weights - self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, data_if_sparse='csr') - - # function - self.integral = odeint(lambda g, t: -g / self.tau, method=method) - - # variables - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - self.g = variable_(bm.zeros, self.post_num, batch_size) - if self.out is not None: - self.out.reset_state(batch_size) - if self.stp is not None: - 
self.stp.reset_state(batch_size) - - def update(self, pre_spike, post_v=None): - if self.stp is not None: - syn_value = self.stp(pre_spike) * pre_spike - else: - syn_value = pre_spike - - # post values - if isinstance(self.conn, All2All): - post_vs = self._syn2post_with_all2all(syn_value, self.g_max, self.conn.include_self) - elif isinstance(self.conn, One2One): - post_vs = self._syn2post_with_one2one(syn_value, self.g_max) - else: - if self.comp_method == 'sparse': - if self.stp is None: - f = lambda s: bm.event.csrmv(self.g_max, - self.conn_mask[0], - self.conn_mask[1], - s, - shape=(self.pre_num, self.post_num), - transpose=True) - if isinstance(self.mode, bm.BatchingMode): - f = vmap(f) + r"""Exponential decay synapse model. + + **Model Descriptions** + + The single exponential decay synapse model assumes the release of neurotransmitter, + its diffusion across the cleft, the receptor binding, and channel opening all happen + very quickly, so that the channels instantaneously jump from the closed to the open state. + Therefore, its expression is given by + + .. math:: + + g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau} + + where :math:`\tau_{delay}` is the time constant of the synaptic state decay, + :math:`t_0` is the time of the pre-synaptic spike, + :math:`g_{\mathrm{max}}` is the maximal conductance. + + Accordingly, the differential form of the exponential synapse is given by + + .. math:: + + \begin{aligned} + & g_{\mathrm{syn}}(t) = g_{max} g * \mathrm{STP} \\ + & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}). + \end{aligned} + + where :math:`\mathrm{STP}` is used to model the short-term plasticity effect. + + Parameters:: + + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + tau: float, ArrayType + The time constant of decay. 
[ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References:: + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. + + """ + + def __init__( + self, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + out: Optional[SynOutNS] = None, + stp: Optional[SynSTPNS] = None, + comp_method: str = 'sparse', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau: Union[float, ArrayType] = 8.0, + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(conn=conn, + out=out, + stp=stp, + name=name, + mode=mode) + + # parameters + self.comp_method = comp_method + self.tau = is_float(tau, allow_int=True) + + # connections and weights + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, data_if_sparse='csr') + + # function + self.integral = odeint(lambda g, t: -g / self.tau, method=method) + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.g = variable_(bm.zeros, self.post_num, batch_size) + if self.out is not None: + self.out.reset_state(batch_size) + if self.stp is not None: + self.stp.reset_state(batch_size) + + def update(self, pre_spike, post_v=None): + if self.stp is not None: + syn_value = self.stp(pre_spike) * pre_spike else: - f = lambda s: bm.sparse.csrmv(self.g_max, - self.conn_mask[0], - self.conn_mask[1], - s, - shape=(self.pre_num, self.post_num), - transpose=True, - method='cusparse') - if isinstance(self.mode, bm.BatchingMode): - f = vmap(f) - post_vs = f(pre_spike) - else: - post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) - - # updates - self.g.value = self.integral(self.g.value, share.load('t'), 
bm.dt) + post_vs - - # outputs - if self.out is not None: - return self.out(self.g.value, post_v) - else: - return self.g.value + syn_value = pre_spike + + # post values + if isinstance(self.conn, All2All): + post_vs = self._syn2post_with_all2all(syn_value, self.g_max, self.conn.include_self) + elif isinstance(self.conn, One2One): + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) + else: + if self.comp_method == 'sparse': + if self.stp is None: + f = lambda s: bm.event.csrmv(self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.pre_num, self.post_num), + transpose=True) + if isinstance(self.mode, bm.BatchingMode): + f = vmap(f) + else: + f = lambda s: bm.sparse.csrmv(self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.pre_num, self.post_num), + transpose=True, + method='cusparse') + if isinstance(self.mode, bm.BatchingMode): + f = vmap(f) + post_vs = f(pre_spike) + else: + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + + # updates + self.g.value = self.integral(self.g.value, share.load('t'), bm.dt) + post_vs + + # outputs + if self.out is not None: + return self.out(self.g.value, post_v) + else: + return self.g.value class DualExponential(SynConnNS): - r"""Dual exponential synapse model. - - **Model Descriptions** - - The dual exponential synapse model [1]_, also named as *difference of two exponentials* model, - is given by: - - .. math:: - - g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{ - \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right) - -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right) - - where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2` - is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic - spike, :math:`g_{\mathrm{max}}` is the maximal conductance. - - However, in practice, this formula is hard to implement. 
The equivalent solution is - two coupled linear differential equations [2]_: - - .. math:: - - \begin{aligned} - &g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g * \mathrm{STP} \\ - &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\ - &\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right), - \end{aligned} - - where :math:`\mathrm{STP}` is used to model the short-term plasticity effect of synapses. - - Parameters:: - - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `sparse`. - tau_decay: float, ArrayArray, ndarray - The time constant of the synaptic decay phase. [ms] - tau_rise: float, ArrayArray, ndarray - The time constant of the synaptic rise phase. [ms] - g_max: float, ArrayType, Initializer, Callable - The synaptic strength (the maximum conductance). Default is 1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. - - References:: - - .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. - "The Synapse." Principles of Computational Modelling in Neuroscience. - Cambridge: Cambridge UP, 2011. 172-95. Print. - .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational - Modeling Methods for Neuroscientists. 
- - """ - - def __init__( - self, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - out: Optional[SynOutNS] = None, - stp: Optional[SynSTPNS] = None, - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - tau_decay: Union[float, ArrayType] = 10.0, - tau_rise: Union[float, ArrayType] = 1., - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super(DualExponential, self).__init__(conn=conn, - out=out, - stp=stp, - name=name, - mode=mode) - # parameters - self.comp_method = comp_method - self.tau_rise = is_float(tau_rise, allow_int=True, allow_none=False) - self.tau_decay = is_float(tau_decay, allow_int=True, allow_none=False) - - # connections and weights - self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, data_if_sparse='csr') - - # function - self.integral = odeint(JointEq(self.dg, self.dh), method=method) - - # variables - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - self.h = variable_(bm.zeros, self.conn.pre_num, batch_size) - self.g = variable_(bm.zeros, self.conn.pre_num, batch_size) - if self.out is not None: - self.out.reset_state(batch_size) - if self.stp is not None: - self.stp.reset_state(batch_size) - - def dh(self, h, t): - return -h / self.tau_rise - - def dg(self, g, t, h): - return -g / self.tau_decay + h - - def update(self, pre_spike, post_v=None): - t = share.load('t') - dt = share.load('dt') - - # update synaptic variables - self.g.value, self.h.value = self.integral(self.g.value, self.h.value, t, dt=dt) - self.h += pre_spike - - # post values - syn_value = self.g.value - if self.stp is not None: - syn_value = self.stp(syn_value) - - if isinstance(self.conn, All2All): - post_vs = self._syn2post_with_all2all(syn_value, self.g_max, self.conn.include_self) - elif isinstance(self.conn, One2One): - post_vs = self._syn2post_with_one2one(syn_value, self.g_max) - else: - if self.comp_method == 'sparse': - f = lambda s: 
bm.sparse.csrmv( - self.g_max, - self.conn_mask[0], - self.conn_mask[1], - s, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=True, - method='cusparse' - ) - if isinstance(self.mode, bm.BatchingMode): - f = vmap(f) - post_vs = f(syn_value) - else: - post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) - - # outputs - if self.out is not None: - return self.out(post_vs, post_v) - else: - return post_vs + r"""Dual exponential synapse model. + + **Model Descriptions** + + The dual exponential synapse model [1]_, also named as *difference of two exponentials* model, + is given by: + + .. math:: + + g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{ + \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right) + -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right) + + where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2` + is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic + spike, :math:`g_{\mathrm{max}}` is the maximal conductance. + + However, in practice, this formula is hard to implement. The equivalent solution is + two coupled linear differential equations [2]_: + + .. math:: + + \begin{aligned} + &g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g * \mathrm{STP} \\ + &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\ + &\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right), + \end{aligned} + + where :math:`\mathrm{STP}` is used to model the short-term plasticity effect of synapses. + + Parameters:: + + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + tau_decay: float, ArrayArray, ndarray + The time constant of the synaptic decay phase. [ms] + tau_rise: float, ArrayArray, ndarray + The time constant of the synaptic rise phase. 
[ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References:: + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. + .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational + Modeling Methods for Neuroscientists. + + """ + + def __init__( + self, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + out: Optional[SynOutNS] = None, + stp: Optional[SynSTPNS] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau_decay: Union[float, ArrayType] = 10.0, + tau_rise: Union[float, ArrayType] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(DualExponential, self).__init__(conn=conn, + out=out, + stp=stp, + name=name, + mode=mode) + # parameters + self.comp_method = comp_method + self.tau_rise = is_float(tau_rise, allow_int=True, allow_none=False) + self.tau_decay = is_float(tau_decay, allow_int=True, allow_none=False) + + # connections and weights + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, data_if_sparse='csr') + + # function + self.integral = odeint(JointEq(self.dg, self.dh), method=method) + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.h = variable_(bm.zeros, self.conn.pre_num, batch_size) + self.g = variable_(bm.zeros, self.conn.pre_num, batch_size) + if self.out is not None: + self.out.reset_state(batch_size) + if self.stp is not None: + self.stp.reset_state(batch_size) + + def dh(self, h, t): + return -h / self.tau_rise + + def dg(self, g, t, h): + return -g / self.tau_decay + h + + def update(self, pre_spike, post_v=None): + t = share.load('t') + 
dt = share.load('dt') + + # update synaptic variables + self.g.value, self.h.value = self.integral(self.g.value, self.h.value, t, dt=dt) + self.h += pre_spike + + # post values + syn_value = self.g.value + if self.stp is not None: + syn_value = self.stp(syn_value) + + if isinstance(self.conn, All2All): + post_vs = self._syn2post_with_all2all(syn_value, self.g_max, self.conn.include_self) + elif isinstance(self.conn, One2One): + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) + else: + if self.comp_method == 'sparse': + f = lambda s: bm.sparse.csrmv( + self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=True, + method='cusparse' + ) + if isinstance(self.mode, bm.BatchingMode): + f = vmap(f) + post_vs = f(syn_value) + else: + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + + # outputs + if self.out is not None: + return self.out(post_vs, post_v) + else: + return post_vs class Alpha(DualExponential): - r"""Alpha synapse model. - - **Model Descriptions** - - The analytical expression of alpha synapse is given by: - - .. math:: - - g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right). - - While, this equation is hard to implement. So, let's try to convert it into the - differential forms: - - .. math:: - - \begin{aligned} - &g_{\mathrm{syn}}(t)= g_{\mathrm{max}} g \\ - &\frac{d g}{d t}=-\frac{g}{\tau}+h \\ - &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right) - \end{aligned} - - **Model Examples** - - .. 
plot:: - :include-source: True - - >>> import brainpy as bp - >>> from brainpy import neurons, synapses, synouts - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = neurons.LIF(1) - >>> neu2 = neurons.LIF(1) - >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) - >>> runner.run(150.) - >>> - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') - >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') - >>> plt.legend() - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') - >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau_decay: float, ArrayType - The time constant of the synaptic decay phase. [ms] - g_max: float, ArrayType, Initializer, Callable - The synaptic strength (the maximum conductance). Default is 1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. - - References:: - - .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. - "The Synapse." Principles of Computational Modelling in Neuroscience. - Cambridge: Cambridge UP, 2011. 172-95. Print. 
- """ - - def __init__( - self, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - out: Optional[SynOutNS] = None, - stp: Optional[SynSTPNS] = None, - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - tau_decay: Union[float, ArrayType] = 10.0, - method: str = 'exp_auto', - - # other parameters - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(conn=conn, - comp_method=comp_method, - g_max=g_max, - tau_decay=tau_decay, - tau_rise=tau_decay, - method=method, - out=out, - stp=stp, - name=name, - mode=mode) + r"""Alpha synapse model. + + **Model Descriptions** + + The analytical expression of alpha synapse is given by: + + .. math:: + + g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right). + + While, this equation is hard to implement. So, let's try to convert it into the + differential forms: + + .. math:: + + \begin{aligned} + &g_{\mathrm{syn}}(t)= g_{\mathrm{max}} g \\ + &\frac{d g}{d t}=-\frac{g}{\tau}+h \\ + &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right) + \end{aligned} + + **Model Examples** + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses, synouts + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.LIF(1) + >>> neu2 = neurons.LIF(1) + >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) + >>> runner.run(150.) 
+ >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ArrayType, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau_decay: float, ArrayType + The time constant of the synaptic decay phase. [ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References:: + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. 
+ """ + + def __init__( + self, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + out: Optional[SynOutNS] = None, + stp: Optional[SynSTPNS] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau_decay: Union[float, ArrayType] = 10.0, + method: str = 'exp_auto', + + # other parameters + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(conn=conn, + comp_method=comp_method, + g_max=g_max, + tau_decay=tau_decay, + tau_rise=tau_decay, + method=method, + out=out, + stp=stp, + name=name, + mode=mode) diff --git a/brainpy/_src/dynold/experimental/base.py b/brainpy/_src/dynold/experimental/base.py index 0ff0d6cbc..c7350ddb7 100644 --- a/brainpy/_src/dynold/experimental/base.py +++ b/brainpy/_src/dynold/experimental/base.py @@ -11,123 +11,123 @@ class SynConnNS(DynamicalSystem): - def __init__( - self, - conn: TwoEndConnector, - out: Optional['SynOutNS'] = None, - stp: Optional['SynSTPNS'] = None, - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(name=name, mode=mode) - - # parameters - assert isinstance(conn, TwoEndConnector) - self.conn = self._init_conn(conn) - self.pre_size = conn.pre_size - self.post_size = conn.post_size - self.pre_num = conn.pre_num - self.post_num = conn.post_num - assert out is None or isinstance(out, SynOutNS) - assert stp is None or isinstance(stp, SynSTPNS) - self.out = out - self.stp = stp - - def _init_conn(self, conn): - if isinstance(conn, TwoEndConnector): - pass - elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)): - if (self.pre_num, self.post_num) != conn.shape: - raise ValueError(f'"conn" is provided as a matrix, and it is expected ' - f'to be an array with shape of (self.pre_num, self.post_num) = ' - f'{(self.pre_num, self.post_num)}, however we got {conn.shape}') - conn = MatConn(conn_mat=conn) - elif isinstance(conn, dict): - if not ('i' in conn and 'j' in conn): - raise ValueError(f'"conn" is provided as a dict, and it is 
expected to ' - f'be a dictionary with "i" and "j" specification, ' - f'however we got {conn}') - conn = IJConn(i=conn['i'], j=conn['j']) - elif conn is None: - conn = None - else: - raise ValueError(f'Unknown "conn" type: {conn}') - return conn - - def _init_weights( - self, - weight: Union[float, ArrayType, Initializer, Callable], - comp_method: str, - data_if_sparse: str = 'csr' - ) -> Tuple[Union[float, ArrayType], ArrayType]: - if comp_method not in ['sparse', 'dense']: - raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') - if data_if_sparse not in ['csr', 'ij', 'coo']: - raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {data_if_sparse}') - - # connections and weights - if isinstance(self.conn, One2One): - weight = parameter(weight, (self.pre_num,), allow_none=False) - conn_mask = None - - elif isinstance(self.conn, All2All): - weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) - conn_mask = None - - else: - if comp_method == 'sparse': - if data_if_sparse == 'csr': - conn_mask = self.conn.require('pre2post') - elif data_if_sparse in ['ij', 'coo']: - conn_mask = self.conn.require('post_ids', 'pre_ids') + def __init__( + self, + conn: TwoEndConnector, + out: Optional['SynOutNS'] = None, + stp: Optional['SynSTPNS'] = None, + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(name=name, mode=mode) + + # parameters + assert isinstance(conn, TwoEndConnector) + self.conn = self._init_conn(conn) + self.pre_size = conn.pre_size + self.post_size = conn.post_size + self.pre_num = conn.pre_num + self.post_num = conn.post_num + assert out is None or isinstance(out, SynOutNS) + assert stp is None or isinstance(stp, SynSTPNS) + self.out = out + self.stp = stp + + def _init_conn(self, conn): + if isinstance(conn, TwoEndConnector): + pass + elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)): + if (self.pre_num, self.post_num) != conn.shape: + raise 
ValueError(f'"conn" is provided as a matrix, and it is expected ' + f'to be an array with shape of (self.pre_num, self.post_num) = ' + f'{(self.pre_num, self.post_num)}, however we got {conn.shape}') + conn = MatConn(conn_mat=conn) + elif isinstance(conn, dict): + if not ('i' in conn and 'j' in conn): + raise ValueError(f'"conn" is provided as a dict, and it is expected to ' + f'be a dictionary with "i" and "j" specification, ' + f'however we got {conn}') + conn = IJConn(i=conn['i'], j=conn['j']) + elif conn is None: + conn = None else: - ValueError(f'Unknown sparse data type: {data_if_sparse}') - weight = parameter(weight, conn_mask[0].shape, allow_none=False) - elif comp_method == 'dense': - weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) - conn_mask = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {comp_method}') - - # training weights - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - return weight, conn_mask - - def _syn2post_with_all2all(self, syn_value, syn_weight, include_self): - if bm.ndim(syn_weight) == 0: - if isinstance(self.mode, bm.BatchingMode): - post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) - else: - post_vs = bm.sum(syn_value) - if not include_self: - post_vs = post_vs - syn_value - post_vs = syn_weight * post_vs - else: - post_vs = syn_value @ syn_weight - return post_vs - - def _syn2post_with_one2one(self, syn_value, syn_weight): - return syn_value * syn_weight - - def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): - if bm.ndim(syn_weight) == 0: - post_vs = (syn_weight * syn_value) @ conn_mat - else: - post_vs = syn_value @ (syn_weight * conn_mat) - return post_vs + raise ValueError(f'Unknown "conn" type: {conn}') + return conn + + def _init_weights( + self, + weight: Union[float, ArrayType, Initializer, Callable], + comp_method: str, + data_if_sparse: str = 'csr' + ) -> Tuple[Union[float, ArrayType], 
ArrayType]: + if comp_method not in ['sparse', 'dense']: + raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') + if data_if_sparse not in ['csr', 'ij', 'coo']: + raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {data_if_sparse}') + + # connections and weights + if isinstance(self.conn, One2One): + weight = parameter(weight, (self.pre_num,), allow_none=False) + conn_mask = None + + elif isinstance(self.conn, All2All): + weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) + conn_mask = None + + else: + if comp_method == 'sparse': + if data_if_sparse == 'csr': + conn_mask = self.conn.require('pre2post') + elif data_if_sparse in ['ij', 'coo']: + conn_mask = self.conn.require('post_ids', 'pre_ids') + else: + ValueError(f'Unknown sparse data type: {data_if_sparse}') + weight = parameter(weight, conn_mask[0].shape, allow_none=False) + elif comp_method == 'dense': + weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) + conn_mask = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {comp_method}') + + # training weights + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + return weight, conn_mask + + def _syn2post_with_all2all(self, syn_value, syn_weight, include_self): + if bm.ndim(syn_weight) == 0: + if isinstance(self.mode, bm.BatchingMode): + post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) + else: + post_vs = bm.sum(syn_value) + if not include_self: + post_vs = post_vs - syn_value + post_vs = syn_weight * post_vs + else: + post_vs = syn_value @ syn_weight + return post_vs + + def _syn2post_with_one2one(self, syn_value, syn_weight): + return syn_value * syn_weight + + def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): + if bm.ndim(syn_weight) == 0: + post_vs = (syn_weight * syn_value) @ conn_mat + else: + post_vs = syn_value @ (syn_weight * conn_mat) + 
return post_vs class SynOutNS(DynamicalSystem): - def update(self, post_g, post_v): - raise NotImplementedError + def update(self, post_g, post_v): + raise NotImplementedError - def reset_state(self, batch_size: Optional[int] = None): - pass + def reset_state(self, batch_size: Optional[int] = None): + pass class SynSTPNS(DynamicalSystem): - """Base class for synaptic short-term plasticity.""" + """Base class for synaptic short-term plasticity.""" - def update(self, pre_spike): - raise NotImplementedError + def update(self, pre_spike): + raise NotImplementedError diff --git a/brainpy/_src/dynold/experimental/others.py b/brainpy/_src/dynold/experimental/others.py index bbd45ab50..30d9e17ad 100644 --- a/brainpy/_src/dynold/experimental/others.py +++ b/brainpy/_src/dynold/experimental/others.py @@ -1,83 +1,80 @@ - from typing import Union, Optional import brainpy.math as bm -from brainpy._src.dynsys import DynamicalSystem from brainpy._src.context import share +from brainpy._src.dynsys import DynamicalSystem from brainpy.check import is_float, is_integer class PoissonInput(DynamicalSystem): - """Poisson Input. - - Adds independent Poisson input to a target variable. For large - numbers of inputs, this is much more efficient than creating a - `PoissonGroup`. The synaptic events are generated randomly during the - simulation and are not preloaded and stored in memory. All the inputs must - target the same variable, have the same frequency and same synaptic weight. - All neurons in the target variable receive independent realizations of - Poisson spike trains. - - Parameters:: - - num_input: int - The number of inputs. - freq: float - The frequency of each of the inputs. Must be a scalar. - weight: float - The synaptic weight. Must be a scalar. 
- """ - - def __init__( - self, - target_shape, - num_input: int, - freq: Union[int, float], - weight: Union[int, float], - seed: Optional[int] = None, - mode: bm.Mode = None, - name: str = None - ): - super(PoissonInput, self).__init__(name=name, mode=mode) - - # check data - is_integer(num_input, 'num_input', min_bound=1) - is_float(freq, 'freq', min_bound=0., allow_int=True) - is_float(weight, 'weight', allow_int=True) - assert self.mode.is_parent_of(bm.NonBatchingMode, bm.BatchingMode) - - # parameters - self.target_shape = target_shape - self.num_input = num_input - self.freq = freq - self.weight = weight - self.seed = seed - - def update(self): - p = self.freq * share.dt / 1e3 - a = self.num_input * p - b = self.num_input * (1 - p) - if isinstance(share.dt, (int, float)): # dt is not in tracing - if (a > 5) and (b > 5): - inp = bm.random.normal(a, b * p, self.target_shape) - else: - inp = bm.random.binomial(self.num_input, p, self.target_shape) - - else: # dt is in tracing - inp = bm.cond((a > 5) * (b > 5), - lambda _: bm.random.normal(a, b * p, self.target_shape), - lambda _: bm.random.binomial(self.num_input, p, self.target_shape), - None) - return inp * self.weight - - def __repr__(self): - names = self.__class__.__name__ - return f'{names}(shape={self.target_shape}, num_input={self.num_input}, freq={self.freq}, weight={self.weight})' - - def reset_state(self, batch_size=None): - pass - - def reset(self, batch_size=None): - self.reset_state(batch_size) - - + """Poisson Input. + + Adds independent Poisson input to a target variable. For large + numbers of inputs, this is much more efficient than creating a + `PoissonGroup`. The synaptic events are generated randomly during the + simulation and are not preloaded and stored in memory. All the inputs must + target the same variable, have the same frequency and same synaptic weight. + All neurons in the target variable receive independent realizations of + Poisson spike trains. 
+ + Parameters:: + + num_input: int + The number of inputs. + freq: float + The frequency of each of the inputs. Must be a scalar. + weight: float + The synaptic weight. Must be a scalar. + """ + + def __init__( + self, + target_shape, + num_input: int, + freq: Union[int, float], + weight: Union[int, float], + seed: Optional[int] = None, + mode: bm.Mode = None, + name: str = None + ): + super(PoissonInput, self).__init__(name=name, mode=mode) + + # check data + is_integer(num_input, 'num_input', min_bound=1) + is_float(freq, 'freq', min_bound=0., allow_int=True) + is_float(weight, 'weight', allow_int=True) + assert self.mode.is_parent_of(bm.NonBatchingMode, bm.BatchingMode) + + # parameters + self.target_shape = target_shape + self.num_input = num_input + self.freq = freq + self.weight = weight + self.seed = seed + + def update(self): + p = self.freq * share.dt / 1e3 + a = self.num_input * p + b = self.num_input * (1 - p) + if isinstance(share.dt, (int, float)): # dt is not in tracing + if (a > 5) and (b > 5): + inp = bm.random.normal(a, b * p, self.target_shape) + else: + inp = bm.random.binomial(self.num_input, p, self.target_shape) + + else: # dt is in tracing + inp = bm.cond((a > 5) * (b > 5), + lambda _: bm.random.normal(a, b * p, self.target_shape), + lambda _: bm.random.binomial(self.num_input, p, self.target_shape), + None) + return inp * self.weight + + def __repr__(self): + names = self.__class__.__name__ + return f'{names}(shape={self.target_shape}, num_input={self.num_input}, freq={self.freq}, weight={self.weight})' + + def reset_state(self, batch_size=None): + pass + + def reset(self, batch_size=None): + self.reset_state(batch_size) diff --git a/brainpy/_src/dynold/experimental/syn_outs.py b/brainpy/_src/dynold/experimental/syn_outs.py index a7c6d3de8..763ebdb30 100644 --- a/brainpy/_src/dynold/experimental/syn_outs.py +++ b/brainpy/_src/dynold/experimental/syn_outs.py @@ -7,113 +7,112 @@ from brainpy.types import ArrayType __all__ = [ - 'COBA', - 
'CUBA', - 'MgBlock', + 'COBA', + 'CUBA', + 'MgBlock', ] class COBA(SynOutNS): - r"""Conductance-based synaptic output. + r"""Conductance-based synaptic output. - Given the synaptic conductance, the model output the post-synaptic current with + Given the synaptic conductance, the model output the post-synaptic current with - .. math:: + .. math:: - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) - Parameters:: + Parameters:: - E: float, ArrayType, ndarray - The reversal potential. - name: str - The model name. + E: float, ArrayType, ndarray + The reversal potential. + name: str + The model name. - See Also:: + See Also:: - CUBA - """ + CUBA + """ - def __init__(self, E: Union[float, ArrayType] = 0., name: str = None, ): - super().__init__(name=name) - self.E = E + def __init__(self, E: Union[float, ArrayType] = 0., name: str = None, ): + super().__init__(name=name) + self.E = E - def update(self, conductance, potential): - return conductance * (self.E - potential) + def update(self, conductance, potential): + return conductance * (self.E - potential) class CUBA(SynOutNS): - r"""Current-based synaptic output. + r"""Current-based synaptic output. - Given the conductance, this model outputs the post-synaptic current with a identity function: + Given the conductance, this model outputs the post-synaptic current with a identity function: - .. math:: + .. math:: - I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) + I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) - Parameters:: + Parameters:: - name: str - The model name. + name: str + The model name. 
- See Also:: + See Also:: - COBA - """ + COBA + """ - def __init__(self, name: str = None, ): - super().__init__(name=name) + def __init__(self, name: str = None, ): + super().__init__(name=name) - def update(self, conductance, potential=None): - return conductance + def update(self, conductance, potential=None): + return conductance class MgBlock(SynOutNS): - r"""Synaptic output based on Magnesium blocking. + r"""Synaptic output based on Magnesium blocking. - Given the synaptic conductance, the model output the post-synaptic current with + Given the synaptic conductance, the model output the post-synaptic current with - .. math:: + .. math:: - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) - where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to + where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to - .. math:: + .. math:: - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. - Parameters:: + Parameters:: - E: float, ArrayType - The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType - Binding constant. Default 0.062 - beta: float, ArrayType - Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType - Concentration of Magnesium ion. Default 1.2 [mM]. - name: str - The model name. - """ + E: float, ArrayType + The reversal potential for the synaptic current. [mV] + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType + Unbinding constant. Default 3.57 + cc_Mg: float, ArrayType + Concentration of Magnesium ion. Default 1.2 [mM]. 
+ name: str + The model name. + """ - def __init__( - self, - E: Union[float, ArrayType] = 0., - cc_Mg: Union[float, ArrayType] = 1.2, - alpha: Union[float, ArrayType] = 0.062, - beta: Union[float, ArrayType] = 3.57, - name: str = None, - ): - super().__init__(name=name) - self.E = E - self.cc_Mg = cc_Mg - self.alpha = alpha - self.beta = beta - - def update(self, conductance, potential): - return conductance * (self.E - potential) / (1 + self.cc_Mg / self.beta * exp(-self.alpha * potential)) + def __init__( + self, + E: Union[float, ArrayType] = 0., + cc_Mg: Union[float, ArrayType] = 1.2, + alpha: Union[float, ArrayType] = 0.062, + beta: Union[float, ArrayType] = 3.57, + name: str = None, + ): + super().__init__(name=name) + self.E = E + self.cc_Mg = cc_Mg + self.alpha = alpha + self.beta = beta + def update(self, conductance, potential): + return conductance * (self.E - potential) / (1 + self.cc_Mg / self.beta * exp(-self.alpha * potential)) diff --git a/brainpy/_src/dynold/experimental/syn_plasticity.py b/brainpy/_src/dynold/experimental/syn_plasticity.py index 4bde77182..20a60b591 100644 --- a/brainpy/_src/dynold/experimental/syn_plasticity.py +++ b/brainpy/_src/dynold/experimental/syn_plasticity.py @@ -12,157 +12,157 @@ from brainpy.types import ArrayType, Shape __all__ = [ - 'STD', - 'STP', + 'STD', + 'STP', ] class STD(SynSTPNS): - r"""Synaptic output with short-term depression. + r"""Synaptic output with short-term depression. - This model filters the synaptic current by the following equation: + This model filters the synaptic current by the following equation: - .. math:: + .. math:: - I_{syn}^+(t) = I_{syn}^-(t) * x + I_{syn}^+(t) = I_{syn}^-(t) * x - where :math:`x` is the normalized variable between 0 and 1, and - :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before - and after STD filtering. 
+ where :math:`x` is the normalized variable between 0 and 1, and + :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STD filtering. - Moreover, :math:`x` is updated according to the dynamics of: + Moreover, :math:`x` is updated according to the dynamics of: - .. math:: + .. math:: - \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike}) + \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike}) - where :math:`U` is the fraction of resources used per action potential, - :math:`\tau` is the time constant of recovery of the synaptic vesicles. + where :math:`U` is the fraction of resources used per action potential, + :math:`\tau` is the time constant of recovery of the synaptic vesicles. - Parameters:: + Parameters:: - tau: float - The time constant of recovery of the synaptic vesicles. - U: float - The fraction of resources used per action potential. + tau: float + The time constant of recovery of the synaptic vesicles. + U: float + The fraction of resources used per action potential. 
- See Also:: + See Also:: - STP - """ + STP + """ - def __init__( - self, - pre_size: Shape, - tau: float = 200., - U: float = 0.07, - method: str = 'exp_auto', - name: str = None - ): - super().__init__(name=name) + def __init__( + self, + pre_size: Shape, + tau: float = 200., + U: float = 0.07, + method: str = 'exp_auto', + name: str = None + ): + super().__init__(name=name) - # parameters - self.pre_size = tools.to_size(pre_size) - self.num = tools.size2num(self.pre_size) - self.U = parameter(U, self.num) - self.tau = parameter(tau, self.num) - self.method = method + # parameters + self.pre_size = tools.to_size(pre_size) + self.num = tools.size2num(self.pre_size) + self.U = parameter(U, self.num) + self.tau = parameter(tau, self.num) + self.method = method - # integral function - self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=self.method) + # integral function + self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=self.method) - # variables - self.reset_state(self.mode) + # variables + self.reset_state(self.mode) - def reset_state(self, batch_size=None): - self.x = variable_(jnp.ones, self.num, batch_size) + def reset_state(self, batch_size=None): + self.x = variable_(jnp.ones, self.num, batch_size) - def update(self, pre_spike): - x = self.integral(self.x.value, share.load('t'), share.load('dt')) - self.x.value = bm.where(pre_spike, x - self.U * self.x, x) - return self.x.value + def update(self, pre_spike): + x = self.integral(self.x.value, share.load('t'), share.load('dt')) + self.x.value = bm.where(pre_spike, x - self.U * self.x, x) + return self.x.value class STP(SynSTPNS): - r"""Synaptic output with short-term plasticity. - - This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. - - .. 
math:: - - I_{syn}^+(t) = I_{syn}^-(t) * x * u - - where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before - and after STP filtering, :math:`x` denotes the fraction of resources that remain available - after neurotransmitter depletion, and :math:`u` represents the fraction of available - resources ready for use (release probability). - - The dynamics of :math:`u` and :math:`x` are governed by - - .. math:: - - \begin{aligned} - \frac{du}{dt} & = & -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\ - \frac{dx}{dt} & = & \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ - \tag{1}\end{aligned} - - where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment - of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding - variables just before the arrival of the spike, and :math:`u^+` - refers to the moment just after the spike. - - Parameters:: - - tau_f: float - The time constant of short-term facilitation. - tau_d: float - The time constant of short-term depression. - U: float - The fraction of resources used per action potential. - method: str - The numerical integral method. 
- - See Also:: - - STD - """ - - def __init__( - self, - pre_size: Shape, - U: Union[float, ArrayType] = 0.15, - tau_f: Union[float, ArrayType] = 1500., - tau_d: Union[float, ArrayType] = 200., - method: str = 'exp_auto', - name: str = None - ): - super().__init__(name=name) - - # parameters - self.pre_size = tools.to_size(pre_size) - self.num = tools.size2num(self.pre_size) - self.tau_f = parameter(tau_f, self.num) - self.tau_d = parameter(tau_d, self.num) - self.U = parameter(U, self.num) - self.method = method - - # integral function - self.integral = odeint(JointEq([self.du, self.dx]), method=self.method) - - # variables - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - self.x = variable_(jnp.ones, batch_size, self.num) - self.u = variable_(OneInit(self.U), batch_size, self.num) - - du = lambda self, u, t: self.U - u / self.tau_f - dx = lambda self, x, t: (1 - x) / self.tau_d - - def update(self, pre_spike): - u, x = self.integral(self.u.value, self.x.value, share.load('t'), bm.get_dt()) - u = bm.where(pre_spike, u + self.U * (1 - self.u), u) - x = bm.where(pre_spike, x - u * self.x, x) - self.x.value = x - self.u.value = u - return self.x.value * self.u.value + r"""Synaptic output with short-term plasticity. + + This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. + + .. math:: + + I_{syn}^+(t) = I_{syn}^-(t) * x * u + + where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STP filtering, :math:`x` denotes the fraction of resources that remain available + after neurotransmitter depletion, and :math:`u` represents the fraction of available + resources ready for use (release probability). + + The dynamics of :math:`u` and :math:`x` are governed by + + .. 
math:: + + \begin{aligned} + \frac{du}{dt} & = & -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\ + \frac{dx}{dt} & = & \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ + \tag{1}\end{aligned} + + where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment + of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding + variables just before the arrival of the spike, and :math:`u^+` + refers to the moment just after the spike. + + Parameters:: + + tau_f: float + The time constant of short-term facilitation. + tau_d: float + The time constant of short-term depression. + U: float + The fraction of resources used per action potential. + method: str + The numerical integral method. + + See Also:: + + STD + """ + + def __init__( + self, + pre_size: Shape, + U: Union[float, ArrayType] = 0.15, + tau_f: Union[float, ArrayType] = 1500., + tau_d: Union[float, ArrayType] = 200., + method: str = 'exp_auto', + name: str = None + ): + super().__init__(name=name) + + # parameters + self.pre_size = tools.to_size(pre_size) + self.num = tools.size2num(self.pre_size) + self.tau_f = parameter(tau_f, self.num) + self.tau_d = parameter(tau_d, self.num) + self.U = parameter(U, self.num) + self.method = method + + # integral function + self.integral = odeint(JointEq([self.du, self.dx]), method=self.method) + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.x = variable_(jnp.ones, batch_size, self.num) + self.u = variable_(OneInit(self.U), batch_size, self.num) + + du = lambda self, u, t: self.U - u / self.tau_f + dx = lambda self, x, t: (1 - x) / self.tau_d + + def update(self, pre_spike): + u, x = self.integral(self.u.value, self.x.value, share.load('t'), bm.get_dt()) + u = bm.where(pre_spike, u + self.U * (1 - self.u), u) + x = bm.where(pre_spike, x - u * self.x, x) + self.x.value = x + self.u.value = u + return self.x.value * self.u.value diff --git a/brainpy/_src/dynold/neurons/biological_models.py 
b/brainpy/_src/dynold/neurons/biological_models.py index 907e7ede3..775048398 100644 --- a/brainpy/_src/dynold/neurons/biological_models.py +++ b/brainpy/_src/dynold/neurons/biological_models.py @@ -5,8 +5,8 @@ import brainpy.math as bm from brainpy import check from brainpy._src.context import share -from brainpy._src.dyn.neurons import hh from brainpy._src.dyn.base import NeuDyn +from brainpy._src.dyn.neurons import hh from brainpy._src.initialize import (OneInit, Initializer, parameter, @@ -18,807 +18,807 @@ from brainpy.types import Shape, ArrayType __all__ = [ - 'HH', - 'MorrisLecar', - 'PinskyRinzelModel', - 'WangBuzsakiModel', + 'HH', + 'MorrisLecar', + 'PinskyRinzelModel', + 'WangBuzsakiModel', ] class HH(hh.HH): - r"""Hodgkin–Huxley neuron model. + r"""Hodgkin–Huxley neuron model. - **Model Descriptions** + **Model Descriptions** - The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model [1]_ for the generation of - the nerve action potential is one of the most successful mathematical models of - a complex biological process that has ever been formulated. The basic concepts - expressed in the model have proved a valid approach to the study of bio-electrical - activity from the most primitive single-celled organisms such as *Paramecium*, - right through to the neurons within our own brains. + The Hodgkin-Huxley (HH; Hodgkin & Huxley, 1952) model [1]_ for the generation of + the nerve action potential is one of the most successful mathematical models of + a complex biological process that has ever been formulated. The basic concepts + expressed in the model have proved a valid approach to the study of bio-electrical + activity from the most primitive single-celled organisms such as *Paramecium*, + right through to the neurons within our own brains. - Mathematically, the model is given by, - - .. 
math:: - - C \frac {dV} {dt} = -(\bar{g}_{Na} m^3 h (V &-E_{Na}) - + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) - - \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x, \quad x\in {\rm{\{m, h, n\}}} - - &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} - - &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) - - &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) - - &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} - - &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} - - &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) - - The illustrated example of HH neuron model please see `this notebook <../neurons/HH_model.ipynb>`_. - - The Hodgkin–Huxley model can be thought of as a differential equation system with - four state variables, :math:`V_{m}(t),n(t),m(t)`, and :math:`h(t)`, that change - with respect to time :math:`t`. The system is difficult to study because it is a - nonlinear system and cannot be solved analytically. However, there are many numeric - methods available to analyze the system. Certain properties and general behaviors, - such as limit cycles, can be proven to exist. - - *1. Center manifold* - - Because there are four state variables, visualizing the path in phase space can - be difficult. Usually two variables are chosen, voltage :math:`V_{m}(t)` and the - potassium gating variable :math:`n(t)`, allowing one to visualize the limit cycle. - However, one must be careful because this is an ad-hoc method of visualizing the - 4-dimensional system. This does not prove the existence of the limit cycle. - - .. image:: ../../../_static/Hodgkin_Huxley_Limit_Cycle.png - :align: center - - A better projection can be constructed from a careful analysis of the Jacobian of - the system, evaluated at the equilibrium point. Specifically, the eigenvalues of - the Jacobian are indicative of the center manifold's existence. Likewise, the - eigenvectors of the Jacobian reveal the center manifold's orientation. 
The - Hodgkin–Huxley model has two negative eigenvalues and two complex eigenvalues - with slightly positive real parts. The eigenvectors associated with the two - negative eigenvalues will reduce to zero as time :math:`t` increases. The remaining - two complex eigenvectors define the center manifold. In other words, the - 4-dimensional system collapses onto a 2-dimensional plane. Any solution - starting off the center manifold will decay towards the *center manifold*. - Furthermore, the limit cycle is contained on the center manifold. - - *2. Bifurcations* - - If the injected current :math:`I` were used as a bifurcation parameter, then the - Hodgkin–Huxley model undergoes a Hopf bifurcation. As with most neuronal models, - increasing the injected current will increase the firing rate of the neuron. - One consequence of the Hopf bifurcation is that there is a minimum firing rate. - This means that either the neuron is not firing at all (corresponding to zero - frequency), or firing at the minimum firing rate. Because of the all-or-none - principle, there is no smooth increase in action potential amplitude, but - rather there is a sudden "jump" in amplitude. The resulting transition is - known as a `canard `_. - - .. image:: ../../../_static/Hodgkins_Huxley_bifurcation_by_I.gif - :align: center - - The following image shows the bifurcation diagram of the Hodgkin–Huxley model - as a function of the external drive :math:`I` [3]_. The green lines show the amplitude - of a stable limit cycle and the blue lines indicate unstable limit-cycle behaviour, - both born from Hopf bifurcations. The solid red line shows the stable fixed point - and the black line shows the unstable fixed point. - - .. image:: ../../../_static/Hodgkin_Huxley_bifurcation.png - :align: center - - **Model Examples** - - >>> import brainpy as bp - >>> group = bp.neurons.HH(2) - >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 10.)) - >>> runner.run(200.) 
- >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) - - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> - >>> group = bp.neurons.HH(2) - >>> - >>> I1 = bp.inputs.spike_input(sp_times=[500., 550., 1000, 1030, 1060, 1100, 1200], sp_lens=5, sp_sizes=5., duration=2000, ) - >>> I2 = bp.inputs.spike_input(sp_times=[600., 900, 950, 1500], sp_lens=5, sp_sizes=5., duration=2000, ) - >>> I1 += bp.math.random.normal(0, 3, size=I1.shape) - >>> I2 += bp.math.random.normal(0, 3, size=I2.shape) - >>> I = bm.stack((I1, I2), axis=-1) - >>> - >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', I, 'iter')) - >>> runner.run(2000.) - >>> - >>> fig, gs = bp.visualize.get_figure(1, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon.V[:, 0]) - >>> plt.plot(runner.mon.ts, runner.mon.V[:, 1] + 130) - >>> plt.xlim(10, 2000) - >>> plt.xticks([]) - >>> plt.yticks([]) - >>> plt.show() - - Parameters:: - - size: sequence of int, int - The size of the neuron group. - ENa: float, ArrayType, Initializer, callable - The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable - The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable - The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable - The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable - The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable - The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable - The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable - The membrane capacitance. Default is 1 ufarad. 
- V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - m_initializer: ArrayType, Initializer, callable - The initializer of m channel. - h_initializer: ArrayType, Initializer, callable - The initializer of h channel. - n_initializer: ArrayType, Initializer, callable - The initializer of n channel. - method: str - The numerical integration method. - name: str - The group name. - - References:: - - .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description - of membrane current and its application to conduction and excitation - in nerve." The Journal of physiology 117.4 (1952): 500. - .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model - .. [3] Ashwin, Peter, Stephen Coombes, and Rachel Nicks. "Mathematical - frameworks for oscillatory network dynamics in neuroscience." - The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92. - """ - - def __init__( - self, - *args, - input_var: bool = True, - **kwargs, - ): - self.input_var = input_var - super().__init__(*args, **kwargs, init_var=False) - - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - super().reset_state(batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - return super().update(x) - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) + Mathematically, the model is given by, + + .. 
math:: + + C \frac {dV} {dt} = -(\bar{g}_{Na} m^3 h (V &-E_{Na}) + + \bar{g}_K n^4 (V-E_K) + g_{leak} (V - E_{leak})) + I(t) + + \frac {dx} {dt} &= \alpha_x (1-x) - \beta_x, \quad x\in {\rm{\{m, h, n\}}} + + &\alpha_m(V) = \frac {0.1(V+40)}{1-\exp(\frac{-(V + 40)} {10})} + + &\beta_m(V) = 4.0 \exp(\frac{-(V + 65)} {18}) + + &\alpha_h(V) = 0.07 \exp(\frac{-(V+65)}{20}) + + &\beta_h(V) = \frac 1 {1 + \exp(\frac{-(V + 35)} {10})} + + &\alpha_n(V) = \frac {0.01(V+55)}{1-\exp(-(V+55)/10)} + + &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) + + The illustrated example of HH neuron model please see `this notebook <../neurons/HH_model.ipynb>`_. + + The Hodgkin–Huxley model can be thought of as a differential equation system with + four state variables, :math:`V_{m}(t),n(t),m(t)`, and :math:`h(t)`, that change + with respect to time :math:`t`. The system is difficult to study because it is a + nonlinear system and cannot be solved analytically. However, there are many numeric + methods available to analyze the system. Certain properties and general behaviors, + such as limit cycles, can be proven to exist. + + *1. Center manifold* + + Because there are four state variables, visualizing the path in phase space can + be difficult. Usually two variables are chosen, voltage :math:`V_{m}(t)` and the + potassium gating variable :math:`n(t)`, allowing one to visualize the limit cycle. + However, one must be careful because this is an ad-hoc method of visualizing the + 4-dimensional system. This does not prove the existence of the limit cycle. + + .. image:: ../../../_static/Hodgkin_Huxley_Limit_Cycle.png + :align: center + + A better projection can be constructed from a careful analysis of the Jacobian of + the system, evaluated at the equilibrium point. Specifically, the eigenvalues of + the Jacobian are indicative of the center manifold's existence. Likewise, the + eigenvectors of the Jacobian reveal the center manifold's orientation. 
The + Hodgkin–Huxley model has two negative eigenvalues and two complex eigenvalues + with slightly positive real parts. The eigenvectors associated with the two + negative eigenvalues will reduce to zero as time :math:`t` increases. The remaining + two complex eigenvectors define the center manifold. In other words, the + 4-dimensional system collapses onto a 2-dimensional plane. Any solution + starting off the center manifold will decay towards the *center manifold*. + Furthermore, the limit cycle is contained on the center manifold. + + *2. Bifurcations* + + If the injected current :math:`I` were used as a bifurcation parameter, then the + Hodgkin–Huxley model undergoes a Hopf bifurcation. As with most neuronal models, + increasing the injected current will increase the firing rate of the neuron. + One consequence of the Hopf bifurcation is that there is a minimum firing rate. + This means that either the neuron is not firing at all (corresponding to zero + frequency), or firing at the minimum firing rate. Because of the all-or-none + principle, there is no smooth increase in action potential amplitude, but + rather there is a sudden "jump" in amplitude. The resulting transition is + known as a `canard `_. + + .. image:: ../../../_static/Hodgkins_Huxley_bifurcation_by_I.gif + :align: center + + The following image shows the bifurcation diagram of the Hodgkin–Huxley model + as a function of the external drive :math:`I` [3]_. The green lines show the amplitude + of a stable limit cycle and the blue lines indicate unstable limit-cycle behaviour, + both born from Hopf bifurcations. The solid red line shows the stable fixed point + and the black line shows the unstable fixed point. + + .. image:: ../../../_static/Hodgkin_Huxley_bifurcation.png + :align: center + + **Model Examples** + + >>> import brainpy as bp + >>> group = bp.neurons.HH(2) + >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 10.)) + >>> runner.run(200.) 
+ >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) + + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> + >>> group = bp.neurons.HH(2) + >>> + >>> I1 = bp.inputs.spike_input(sp_times=[500., 550., 1000, 1030, 1060, 1100, 1200], sp_lens=5, sp_sizes=5., duration=2000, ) + >>> I2 = bp.inputs.spike_input(sp_times=[600., 900, 950, 1500], sp_lens=5, sp_sizes=5., duration=2000, ) + >>> I1 += bp.math.random.normal(0, 3, size=I1.shape) + >>> I2 += bp.math.random.normal(0, 3, size=I2.shape) + >>> I = bm.stack((I1, I2), axis=-1) + >>> + >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', I, 'iter')) + >>> runner.run(2000.) + >>> + >>> fig, gs = bp.visualize.get_figure(1, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon.V[:, 0]) + >>> plt.plot(runner.mon.ts, runner.mon.V[:, 1] + 130) + >>> plt.xlim(10, 2000) + >>> plt.xticks([]) + >>> plt.yticks([]) + >>> plt.show() + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium. Default is 50 mV. + gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. Default is 120 msiemens. + EK: float, ArrayType, Initializer, callable + The reversal potential of potassium. Default is -77 mV. + gK: float, ArrayType, Initializer, callable + The maximum conductance of potassium channel. Default is 36 msiemens. + EL: float, ArrayType, Initializer, callable + The reversal potential of learky channel. Default is -54.387 mV. + gL: float, ArrayType, Initializer, callable + The conductance of learky channel. Default is 0.03 msiemens. + V_th: float, ArrayType, Initializer, callable + The threshold of the membrane spike. Default is 20 mV. + C: float, ArrayType, Initializer, callable + The membrane capacitance. Default is 1 ufarad. 
+ V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + m_initializer: ArrayType, Initializer, callable + The initializer of m channel. + h_initializer: ArrayType, Initializer, callable + The initializer of h channel. + n_initializer: ArrayType, Initializer, callable + The initializer of n channel. + method: str + The numerical integration method. + name: str + The group name. + + References:: + + .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description + of membrane current and its application to conduction and excitation + in nerve." The Journal of physiology 117.4 (1952): 500. + .. [2] https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model + .. [3] Ashwin, Peter, Stephen Coombes, and Rachel Nicks. "Mathematical + frameworks for oscillatory network dynamics in neuroscience." + The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92. + """ + + def __init__( + self, + *args, + input_var: bool = True, + **kwargs, + ): + self.input_var = input_var + super().__init__(*args, **kwargs, init_var=False) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + return super().update(x) + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) class MorrisLecar(hh.MorrisLecar): - r"""The Morris-Lecar neuron model. + r"""The Morris-Lecar neuron model. + + **Model Descriptions** + + The Morris-Lecar model [4]_ (Also known as :math:`I_{Ca}+I_K`-model) + is a two-dimensional "reduced" excitation model applicable to + systems having two non-inactivating voltage-sensitive conductances. + This model was named after Cathy Morris and Harold Lecar, who + derived it in 1981. 
Because it is two-dimensional, the Morris-Lecar + model is one of the favorite conductance-based models in computational neuroscience. + + The original form of the model employed an instantaneously + responding voltage-sensitive Ca2+ conductance for excitation and a delayed + voltage-dependent K+ conductance for recovery. The equations of the model are: + + .. math:: + + \begin{aligned} + C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - + g_{Leak} (V - V_{Leak}) + I_{ext} \\ + \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} + \end{aligned} + + Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", + which is almost invariably the normalized :math:`K^+`-ion conductance, and + :math:`I_{ext}` is the applied current stimulus. + + **Model Examples** + + + >>> import brainpy as bp + >>> + >>> group = bp.neurons.MorrisLecar(1) + >>> runner = bp.DSRunner(group, monitors=['V', 'W'], inputs=('input', 100.)) + >>> runner.run(1000) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.W, ylabel='W') + >>> fig.add_subplot(gs[1, 0]) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) + + + **Model Parameters** + + ============= ============== ======== ======================================================= + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------- + V_Ca 130 mV Equilibrium potentials of Ca+.(mV) + g_Ca 4.4 \ Maximum conductance of corresponding Ca+.(mS/cm2) + V_K -84 mV Equilibrium potentials of K+.(mV) + g_K 8 \ Maximum conductance of corresponding K+.(mS/cm2) + V_Leak -60 mV Equilibrium potentials of leak current.(mV) + g_Leak 2 \ Maximum conductance of leak current.(mS/cm2) + C 20 \ Membrane capacitance.(uF/cm2) + V1 -1.2 \ Potential at which M_inf = 0.5.(mV) + V2 18 \ Reciprocal of slope of 
voltage dependence of M_inf.(mV) + V3 2 \ Potential at which W_inf = 0.5.(mV) + V4 30 \ Reciprocal of slope of voltage dependence of W_inf.(mV) + phi 0.04 \ A temperature factor. (1/s) + V_th 10 mV The spike threshold. + ============= ============== ======== ======================================================= + + References:: + + .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. + .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model + .. [6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model + """ + + def __init__( + self, + *args, + input_var: bool = True, + **kwargs, + ): + self.input_var = input_var + super().__init__(*args, **kwargs, init_var=False) + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + return super().update(x) + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) - **Model Descriptions** - The Morris-Lecar model [4]_ (Also known as :math:`I_{Ca}+I_K`-model) - is a two-dimensional "reduced" excitation model applicable to - systems having two non-inactivating voltage-sensitive conductances. - This model was named after Cathy Morris and Harold Lecar, who - derived it in 1981. Because it is two-dimensional, the Morris-Lecar - model is one of the favorite conductance-based models in computational neuroscience. +class PinskyRinzelModel(NeuDyn): + r"""The Pinsky and Rinsel (1994) model. - The original form of the model employed an instantaneously - responding voltage-sensitive Ca2+ conductance for excitation and a delayed - voltage-dependent K+ conductance for recovery. 
The equations of the model are: + The Pinsky and Rinsel (1994) model [7]_ is a 2-compartment (soma and dendrite), + conductance-based (Hodgin-Huxley type) model of a hippocampal CA3 pyramidal + neuron. It is a reduced version of an earlier, 19-compartment model by + Traub, et. al. (1991) [8]_. This model demonstrates how similar qualitative + and quantitative spiking behaviors can be obtained despite the reduction + in model complexity. - .. math:: + Specifically, this model demonstrates calcium bursting behavior and how + the 'ping-pong' interplay between somatic and dendritic currents results + in a complex shape of the burst. - \begin{aligned} - C\frac{dV}{dt} =& - g_{Ca} M_{\infty} (V - V_{Ca})- g_{K} W(V - V_{K}) - - g_{Leak} (V - V_{Leak}) + I_{ext} \\ - \frac{dW}{dt} =& \frac{W_{\infty}(V) - W}{ \tau_W(V)} - \end{aligned} + .. image:: ../../../_static/Pinsky-Rinzel-model-illustration.png + :align: center - Here, :math:`V` is the membrane potential, :math:`W` is the "recovery variable", - which is almost invariably the normalized :math:`K^+`-ion conductance, and - :math:`I_{ext}` is the applied current stimulus. 
- - **Model Examples** - - - >>> import brainpy as bp - >>> - >>> group = bp.neurons.MorrisLecar(1) - >>> runner = bp.DSRunner(group, monitors=['V', 'W'], inputs=('input', 100.)) - >>> runner.run(1000) - >>> - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.W, ylabel='W') - >>> fig.add_subplot(gs[1, 0]) - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) - - - **Model Parameters** - - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_Ca 130 mV Equilibrium potentials of Ca+.(mV) - g_Ca 4.4 \ Maximum conductance of corresponding Ca+.(mS/cm2) - V_K -84 mV Equilibrium potentials of K+.(mV) - g_K 8 \ Maximum conductance of corresponding K+.(mS/cm2) - V_Leak -60 mV Equilibrium potentials of leak current.(mV) - g_Leak 2 \ Maximum conductance of leak current.(mS/cm2) - C 20 \ Membrane capacitance.(uF/cm2) - V1 -1.2 \ Potential at which M_inf = 0.5.(mV) - V2 18 \ Reciprocal of slope of voltage dependence of M_inf.(mV) - V3 2 \ Potential at which W_inf = 0.5.(mV) - V4 30 \ Reciprocal of slope of voltage dependence of W_inf.(mV) - phi 0.04 \ A temperature factor. (1/s) - V_th 10 mV The spike threshold. - ============= ============== ======== ======================================================= - - References:: - - .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. - .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model - .. 
[6] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model - """ - - def __init__( - self, - *args, - input_var: bool = True, - **kwargs, - ): - self.input_var = input_var - super().__init__(*args, **kwargs, init_var=False) - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - super().reset_state(batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - return super().update(x) - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) + Mathematically, the model is given by: + .. math:: -class PinskyRinzelModel(NeuDyn): - r"""The Pinsky and Rinsel (1994) model. - - The Pinsky and Rinsel (1994) model [7]_ is a 2-compartment (soma and dendrite), - conductance-based (Hodgin-Huxley type) model of a hippocampal CA3 pyramidal - neuron. It is a reduced version of an earlier, 19-compartment model by - Traub, et. al. (1991) [8]_. This model demonstrates how similar qualitative - and quantitative spiking behaviors can be obtained despite the reduction - in model complexity. - - Specifically, this model demonstrates calcium bursting behavior and how - the 'ping-pong' interplay between somatic and dendritic currents results - in a complex shape of the burst. 
+ \begin{aligned} + &\mathrm{C}_{\mathrm{m}} \mathrm{V}_{\mathrm{s}}^{\prime}=-\mathrm{I}_{\mathrm{Leak}}-\mathrm{I}_{\mathrm{Na}}-\mathrm{I}_{\mathrm{K}_{\mathrm{DR}}}-\frac{\mathrm{I}_{\mathrm{DS}}}{\mathrm{p}}+\frac{\mathrm{I}_{\mathrm{S}_{\mathrm{app}}}}{\mathrm{p}} \\ + &\mathrm{C}_{\mathrm{m}} \mathrm{V}_{\mathrm{d}}^{\prime}=-\mathrm{I}_{\mathrm{Leak}}-\mathrm{I}_{\mathrm{Ca}}-\mathrm{I}_{\mathrm{K}_{\mathrm{Ca}}}-\mathrm{I}_{\mathrm{K}_{\mathrm{AHP}}}+\frac{\mathrm{I}_{\mathrm{SD}}}{(1-\mathrm{p})}+\frac{\mathrm{I}_{\mathrm{D}_{\mathrm{app}}}}{(1-\mathrm{p})} \\ + &\frac{\mathrm{dCa}}{\mathrm{dt}}=-0.13 \mathrm{I}_{\mathrm{Ca}}-0.075 \mathrm{Ca} + \end{aligned} - .. image:: ../../../_static/Pinsky-Rinzel-model-illustration.png - :align: center + The currents of the model are functions of potentials as follows: - Mathematically, the model is given by: + .. math:: - .. math:: + \begin{aligned} + \mathrm{I}_{\mathrm{Na}} &=\mathrm{g}_{\mathrm{Na}} m_{\infty}^{2}\left(\mathrm{~V}_{\mathrm{s}}\right) h\left(\mathrm{~V}_{\mathrm{s}}-\mathrm{V}_{\mathrm{Na}}\right) \\ + \mathrm{I}_{\mathrm{K}_{\mathrm{DR}}} &=\mathrm{g}_{\mathrm{K}_{\mathrm{DR}}} n\left(\mathrm{~V}_{\mathrm{s}}-\mathrm{V}_{\mathrm{K}}\right) \\ + \mathrm{I}_{\mathrm{Ca}} &=\mathrm{g}_{\mathrm{Ca}}{ }^{2}\left(\mathrm{~V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{N}}\right) \\ + \mathrm{I}_{\mathrm{K}_{\mathrm{Ca}}} &=\mathrm{g}_{\mathrm{k}_{\mathrm{Ca}}} C \chi(\mathrm{Ca})\left(\mathrm{V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{Ca}}\right) \\ + \mathrm{I}_{\mathrm{K}_{\mathrm{AHP}}} &=\mathrm{g}_{\mathrm{K}_{\mathrm{AHP}}} q\left(\mathrm{~V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{K}}\right) \\ + \mathrm{I}_{\mathrm{SD}} &=-\mathrm{I}_{\mathrm{DS}}=\mathrm{g}_{\mathrm{c}}\left(\mathrm{V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{s}}\right) \\ + \mathrm{I}_{\mathrm{Leak}} &=\mathrm{g}_{\mathrm{L}}\left(\mathrm{V}-\mathrm{V}_{\mathrm{L}}\right) + \end{aligned} - \begin{aligned} - &\mathrm{C}_{\mathrm{m}} 
\mathrm{V}_{\mathrm{s}}^{\prime}=-\mathrm{I}_{\mathrm{Leak}}-\mathrm{I}_{\mathrm{Na}}-\mathrm{I}_{\mathrm{K}_{\mathrm{DR}}}-\frac{\mathrm{I}_{\mathrm{DS}}}{\mathrm{p}}+\frac{\mathrm{I}_{\mathrm{S}_{\mathrm{app}}}}{\mathrm{p}} \\ - &\mathrm{C}_{\mathrm{m}} \mathrm{V}_{\mathrm{d}}^{\prime}=-\mathrm{I}_{\mathrm{Leak}}-\mathrm{I}_{\mathrm{Ca}}-\mathrm{I}_{\mathrm{K}_{\mathrm{Ca}}}-\mathrm{I}_{\mathrm{K}_{\mathrm{AHP}}}+\frac{\mathrm{I}_{\mathrm{SD}}}{(1-\mathrm{p})}+\frac{\mathrm{I}_{\mathrm{D}_{\mathrm{app}}}}{(1-\mathrm{p})} \\ - &\frac{\mathrm{dCa}}{\mathrm{dt}}=-0.13 \mathrm{I}_{\mathrm{Ca}}-0.075 \mathrm{Ca} - \end{aligned} + The activation and inactivation variables should satisfy these equations - The currents of the model are functions of potentials as follows: + .. math:: - .. math:: + \begin{aligned} + \omega^{\prime}(\mathrm{V}) &=\frac{\omega_{\infty}(\mathrm{V})-\omega}{\tau_{\omega}(\mathrm{V})} \\ + \omega_{\infty}(\mathrm{V}) &=\frac{\alpha_{\omega}(\mathrm{V})}{\alpha_{\omega}(\mathrm{V})+\beta_{\omega}(\mathrm{V})} \\ + \tau_{\omega}(\mathrm{V}) &=\frac{1}{\alpha_{\omega}(\mathrm{V})+\beta_{\omega}(\mathrm{V})} + \end{aligned} - \begin{aligned} - \mathrm{I}_{\mathrm{Na}} &=\mathrm{g}_{\mathrm{Na}} m_{\infty}^{2}\left(\mathrm{~V}_{\mathrm{s}}\right) h\left(\mathrm{~V}_{\mathrm{s}}-\mathrm{V}_{\mathrm{Na}}\right) \\ - \mathrm{I}_{\mathrm{K}_{\mathrm{DR}}} &=\mathrm{g}_{\mathrm{K}_{\mathrm{DR}}} n\left(\mathrm{~V}_{\mathrm{s}}-\mathrm{V}_{\mathrm{K}}\right) \\ - \mathrm{I}_{\mathrm{Ca}} &=\mathrm{g}_{\mathrm{Ca}}{ }^{2}\left(\mathrm{~V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{N}}\right) \\ - \mathrm{I}_{\mathrm{K}_{\mathrm{Ca}}} &=\mathrm{g}_{\mathrm{k}_{\mathrm{Ca}}} C \chi(\mathrm{Ca})\left(\mathrm{V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{Ca}}\right) \\ - \mathrm{I}_{\mathrm{K}_{\mathrm{AHP}}} &=\mathrm{g}_{\mathrm{K}_{\mathrm{AHP}}} q\left(\mathrm{~V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{K}}\right) \\ - \mathrm{I}_{\mathrm{SD}} 
&=-\mathrm{I}_{\mathrm{DS}}=\mathrm{g}_{\mathrm{c}}\left(\mathrm{V}_{\mathrm{d}}-\mathrm{V}_{\mathrm{s}}\right) \\ - \mathrm{I}_{\mathrm{Leak}} &=\mathrm{g}_{\mathrm{L}}\left(\mathrm{V}-\mathrm{V}_{\mathrm{L}}\right) + where, independently, we consider :math:`\omega = h, n, s, m, c, q`. + + The rate functions are defined as follows + + .. math:: + + \begin{aligned} + \alpha_{m}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{0.32\left(-46.9-\mathrm{V}_{\mathrm{s}}\right)}{\exp \left(\frac{-46.9-\mathrm{V}_{\mathrm{s}}}{4}\right)-1} \\ + \beta_{m}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{0.28\left(\mathrm{~V}_{\mathrm{s}}+19.9\right)}{\exp \left(\frac{\mathrm{V}_{\mathrm{s}}+19.9}{5}\right)-1}, \\ + \alpha_{n}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{0.016\left(-24.9-\mathrm{V}_{\mathrm{s}}\right)}{\exp \left(\frac{-24.9-\mathrm{V}_{\mathrm{s}}}{5}\right)-1} \\ + \beta_{n}\left(\mathrm{~V}_{\mathrm{s}}\right) &=0.25 \exp \left(-1-0.025 \mathrm{~V}_{\mathrm{s}}\right) \\ + \alpha_{h}\left(\mathrm{~V}_{\mathrm{s}}\right) &=0.128 \exp \left(\frac{-43-\mathrm{V}_{\mathrm{s}}}{18}\right) \\ + \beta_{h}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{4}{1+\exp \left(\frac{\left(-20-\mathrm{V}_{\mathrm{s}}\right.}{5}\right)}, \\ + \alpha_{s}\left(\mathrm{~V}_{\mathrm{d}}\right) &=\frac{1.6}{1+\exp \left(-0.072\left(\mathrm{~V}_{\mathrm{d}}-5\right)\right)} \\ + \beta_{s}\left(\mathrm{~V}_{\mathrm{d}}\right) &=\frac{0.02\left(\mathrm{~V}_{\mathrm{d}}+8.9\right)}{\exp \left(\frac{\left(\mathrm{V}_{\mathrm{d}}+8.9\right)}{5}\right)-1}, \\ + \alpha_{C}\left(\mathrm{~V}_{\mathrm{d}}\right) &=\frac{\left(1-H\left(\mathrm{~V}_{\mathrm{d}}+10\right)\right) \exp \left(\frac{\left(\mathrm{V}_{\mathrm{d}}+50\right)}{11}-\frac{\left(\mathrm{V}_{\mathrm{d}}+53.5\right)}{27}\right)}{18.975}+H\left(\mathrm{~V}_{\mathrm{d}}+10\right)\left(2 \exp \left(\frac{\left(-53.5-\mathrm{V}_{\mathrm{d}}\right.}{27}\right)\right) \\ + \beta_{C}\left(\mathrm{~V}_{\mathrm{d}}\right) 
&=\left(1-H\left(\mathrm{~V}_{\mathrm{d}}+10\right)\right)\left(2 \exp \left(\frac{\left(-53.5-\mathrm{V}_{\mathrm{d}}\right)}{27}\right)-\alpha_{c}\left(\mathrm{~V}_{\mathrm{d}}\right)\right) \\ + \alpha_{q}(\mathrm{Ca}) &=\min (0.00002 \mathrm{Ca}, 0.01) \\ + \beta_{q}(\mathrm{Ca}) &=0.001 \\ + \chi(\mathrm{Ca}) &=\min \left(\frac{\mathrm{Ca}}{250}, 1\right) \end{aligned} - The activation and inactivation variables should satisfy these equations - - .. math:: - - \begin{aligned} - \omega^{\prime}(\mathrm{V}) &=\frac{\omega_{\infty}(\mathrm{V})-\omega}{\tau_{\omega}(\mathrm{V})} \\ - \omega_{\infty}(\mathrm{V}) &=\frac{\alpha_{\omega}(\mathrm{V})}{\alpha_{\omega}(\mathrm{V})+\beta_{\omega}(\mathrm{V})} \\ - \tau_{\omega}(\mathrm{V}) &=\frac{1}{\alpha_{\omega}(\mathrm{V})+\beta_{\omega}(\mathrm{V})} - \end{aligned} - - where, independently, we consider :math:`\omega = h, n, s, m, c, q`. - - The rate functions are defined as follows - - .. math:: - - \begin{aligned} - \alpha_{m}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{0.32\left(-46.9-\mathrm{V}_{\mathrm{s}}\right)}{\exp \left(\frac{-46.9-\mathrm{V}_{\mathrm{s}}}{4}\right)-1} \\ - \beta_{m}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{0.28\left(\mathrm{~V}_{\mathrm{s}}+19.9\right)}{\exp \left(\frac{\mathrm{V}_{\mathrm{s}}+19.9}{5}\right)-1}, \\ - \alpha_{n}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{0.016\left(-24.9-\mathrm{V}_{\mathrm{s}}\right)}{\exp \left(\frac{-24.9-\mathrm{V}_{\mathrm{s}}}{5}\right)-1} \\ - \beta_{n}\left(\mathrm{~V}_{\mathrm{s}}\right) &=0.25 \exp \left(-1-0.025 \mathrm{~V}_{\mathrm{s}}\right) \\ - \alpha_{h}\left(\mathrm{~V}_{\mathrm{s}}\right) &=0.128 \exp \left(\frac{-43-\mathrm{V}_{\mathrm{s}}}{18}\right) \\ - \beta_{h}\left(\mathrm{~V}_{\mathrm{s}}\right) &=\frac{4}{1+\exp \left(\frac{\left(-20-\mathrm{V}_{\mathrm{s}}\right.}{5}\right)}, \\ - \alpha_{s}\left(\mathrm{~V}_{\mathrm{d}}\right) &=\frac{1.6}{1+\exp \left(-0.072\left(\mathrm{~V}_{\mathrm{d}}-5\right)\right)} \\ - 
\beta_{s}\left(\mathrm{~V}_{\mathrm{d}}\right) &=\frac{0.02\left(\mathrm{~V}_{\mathrm{d}}+8.9\right)}{\exp \left(\frac{\left(\mathrm{V}_{\mathrm{d}}+8.9\right)}{5}\right)-1}, \\ - \alpha_{C}\left(\mathrm{~V}_{\mathrm{d}}\right) &=\frac{\left(1-H\left(\mathrm{~V}_{\mathrm{d}}+10\right)\right) \exp \left(\frac{\left(\mathrm{V}_{\mathrm{d}}+50\right)}{11}-\frac{\left(\mathrm{V}_{\mathrm{d}}+53.5\right)}{27}\right)}{18.975}+H\left(\mathrm{~V}_{\mathrm{d}}+10\right)\left(2 \exp \left(\frac{\left(-53.5-\mathrm{V}_{\mathrm{d}}\right.}{27}\right)\right) \\ - \beta_{C}\left(\mathrm{~V}_{\mathrm{d}}\right) &=\left(1-H\left(\mathrm{~V}_{\mathrm{d}}+10\right)\right)\left(2 \exp \left(\frac{\left(-53.5-\mathrm{V}_{\mathrm{d}}\right)}{27}\right)-\alpha_{c}\left(\mathrm{~V}_{\mathrm{d}}\right)\right) \\ - \alpha_{q}(\mathrm{Ca}) &=\min (0.00002 \mathrm{Ca}, 0.01) \\ - \beta_{q}(\mathrm{Ca}) &=0.001 \\ - \chi(\mathrm{Ca}) &=\min \left(\frac{\mathrm{Ca}}{250}, 1\right) - \end{aligned} - - The standard values of the parameters are given below. The maximal conductances - (in :math:`\mathrm{mS} / \mathrm{cm}^{2}`) are - :math:`\bar{g}_{L}=0.1`, :math:`\bar{g}_{\mathrm{Na}}=30`, - :math:`\bar{g}_{\mathrm{K}-\mathrm{DR}}=15`, - :math:`\bar{g}_{\mathrm{Ca}}=10`, - :math:`\bar{g}_{\mathrm{K}-\mathrm{AHP}}=0.8`, - :math:`\bar{g}_{\mathrm{K}-\mathrm{C}}=15`, - :math:`\bar{g}_{\mathrm{NMDA}}=0.0` and - :math:`\bar{g}_{\mathrm{AMPA}}=0.0`. - The reversal potentials (in :math:`\mathrm{mV}` ) are - :math:`V_{\mathrm{Na}}=120, V_{\mathrm{C}}=140, V_{\mathrm{K}}=-15 \mathrm{mV})` - are :math:`V_{\mathrm{Na}}=120, V_{\mathrm{Ca}}=140, V_{\mathrm{K}}=-15, $V_{L}=0` - and :math:`V_{\text {Syn }}=60`. The applied currents - (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ) are :math:`I_{s}=-0.5` and :math:`I_{d}=0.0`. - The coupling parameters are :math:`g_{c}=2.1 \mathrm{mS} / \mathrm{cm}^{2}` and - :math:`p=0.5`. 
The capacitance, :math:`C_{M}`, is - :math:`3 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`\chi(C a)=\min (C a / 250,1)`. - Values for these parameters, and these function definitions, are taken from Traub et al, 1991. - - - Parameters:: - - size: sequence of int, int - The size of the neuron group. - gNa: float, ArrayType, Initializer, callable - The maximum conductance of sodium channel. - gK: float, ArrayType, Initializer, callable - The maximum conductance of potassium delayed-rectifier channel. - gCa: float, ArrayType, Initializer, callable - The maximum conductance of calcium channel. - gAHP: float, ArrayType, Initializer, callable - The maximum conductance of potassium after-hyper-polarization channel. - gC: float, ArrayType, Initializer, callable - The maximum conductance of calcium activated potassium channel. - gL: float, ArrayType, Initializer, callable - The conductance of leaky channel. - ENa: float, ArrayType, Initializer, callable - The reversal potential of sodium channel. - EK: float, ArrayType, Initializer, callable - The reversal potential of potassium delayed-rectifier channel. - ECa: float, ArrayType, Initializer, callable - The reversal potential of calcium channel. - EL: float, ArrayType, Initializer, callable - The reversal potential of leaky channel. - gc: float, ArrayType, Initializer, callable - The coupling strength between the soma and dendrite. - V_th: float, ArrayType, Initializer, callable - The threshold of the membrane spike. - Cm: float, ArrayType, Initializer, callable - The threshold of the membrane spike. - A: float, ArrayType, Initializer, callable - The total cell membrane area, which is normalized to 1. - p: float, ArrayType, Initializer, callable - The proportion of cell area taken up by the soma. - Vs_initializer: ArrayType, Initializer, callable - The initializer of somatic membrane potential. - Vd_initializer: ArrayType, Initializer, callable - The initializer of dendritic membrane potential. 
- Ca_initializer: ArrayType, Initializer, callable - The initializer of Calcium concentration. - method: str - The numerical integration method. - name: str - The group name. - - References:: - - .. [7] Pinsky, Paul F., and John Rinzel. "Intrinsic and network - rhythmogenesis in a reduced Traub model for CA3 neurons." - Journal of computational neuroscience 1.1 (1994): 39-60. - .. [8] Traub, R. D., Wong, R. K., Miles, R., & Michelson, H. (1991). - A model of a CA3 hippocampal pyramidal neuron incorporating - voltage-clamp data on intrinsic conductances. Journal of - neurophysiology, 66(2), 635-650. - """ - - supported_modes = (bm.BatchingMode, bm.NonBatchingMode) - - def __init__( - self, - size: Shape, - keep_size: bool = False, - # maximum conductance - gNa: Union[float, ArrayType, Initializer, Callable] = 30., - gK: Union[float, ArrayType, Initializer, Callable] = 15., - gCa: Union[float, ArrayType, Initializer, Callable] = 10., - gAHP: Union[float, ArrayType, Initializer, Callable] = 0.8, - gC: Union[float, ArrayType, Initializer, Callable] = 15., - gL: Union[float, ArrayType, Initializer, Callable] = 0.1, - # reversal potential - ENa: Union[float, ArrayType, Initializer, Callable] = 60., - EK: Union[float, ArrayType, Initializer, Callable] = -75., - ECa: Union[float, ArrayType, Initializer, Callable] = 80., - EL: Union[float, ArrayType, Initializer, Callable] = -60., - # other parameters - gc: Union[float, ArrayType, Initializer, Callable] = 2.1, - V_th: Union[float, ArrayType, Initializer, Callable] = 20., - Cm: Union[float, ArrayType, Initializer, Callable] = 3.0, - p: Union[float, ArrayType, Initializer, Callable] = 0.5, - A: Union[float, ArrayType, Initializer, Callable] = 1., - # initializers - Vs_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-64.6), - Vd_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-64.5), - Ca_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.2), - # others - noise: Union[float, 
ArrayType, Initializer, Callable] = None, - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - # initialization - super(PinskyRinzelModel, self).__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) - - # conductance parameters - self.gAHP = parameter(gAHP, self.varshape, allow_none=False) - self.gCa = parameter(gCa, self.varshape, allow_none=False) - self.gNa = parameter(gNa, self.varshape, allow_none=False) - self.gK = parameter(gK, self.varshape, allow_none=False) - self.gL = parameter(gL, self.varshape, allow_none=False) - self.gC = parameter(gC, self.varshape, allow_none=False) - - # reversal potential parameters - self.ENa = parameter(ENa, self.varshape, allow_none=False) - self.ECa = parameter(ECa, self.varshape, allow_none=False) - self.EK = parameter(EK, self.varshape, allow_none=False) - self.EL = parameter(EL, self.varshape, allow_none=False) - - # other neuronal parameters - self.V_th = parameter(V_th, self.varshape, allow_none=False) - self.Cm = parameter(Cm, self.varshape, allow_none=False) - self.gc = parameter(gc, self.varshape, allow_none=False) - self.p = parameter(p, self.varshape, allow_none=False) - self.A = parameter(A, self.varshape, allow_none=False) - self.noise = init_noise(noise, self.varshape, num_vars=8) - - # initializers - check.is_initializer(Vs_initializer, 'Vs_initializer', allow_none=False) - check.is_initializer(Vd_initializer, 'Vd_initializer', allow_none=False) - check.is_initializer(Ca_initializer, 'Ca_initializer', allow_none=False) - self._Vs_initializer = Vs_initializer - self._Vd_initializer = Vd_initializer - self._Ca_initializer = Ca_initializer - - # variables - self.Vs = variable_(self._Vs_initializer, self.varshape, self.mode) - self.Vd = variable_(self._Vd_initializer, self.varshape, self.mode) - self.Ca = variable_(self._Ca_initializer, self.varshape, self.mode) - self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) - 
self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) - self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) - self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) - self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) - self.Id = variable_(bm.zeros, self.varshape, self.mode) # input to soma - self.Is = variable_(bm.zeros, self.varshape, self.mode) # input to dendrite - - # integral - if self.noise is None: - self.integral = odeint(method=method, f=self.derivative) - else: - self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - - def reset_state(self, batch_size=None): - self.Vd.value = variable_(self._Vd_initializer, self.varshape, batch_size) - self.Vs.value = variable_(self._Vs_initializer, self.varshape, batch_size) - self.Ca.value = variable_(self._Ca_initializer, self.varshape, batch_size) - batch_axis = 0 if isinstance(self.mode, bm.BatchingMode) else None - self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis) - self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis) - self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis) - self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis) - self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis) - self.Id.value = variable_(bm.zeros, self.varshape, batch_size) - self.Is.value = variable_(bm.zeros, self.varshape, batch_size) - - def dCa(self, Ca, t, s, Vd): - ICa = self.gCa * s * s * (Vd - self.ECa) - return -0.13 * ICa - 0.075 * Ca - - def dh(self, h, t, Vs): - return self.alpha_h(Vs) * (1 - h) - self.beta_h(Vs) * h - - def dn(self, n, t, Vs): - return self.alpha_n(Vs) * (1 - n) - self.beta_n(Vs) * n - - def ds(self, s, t, Vd): - return self.alpha_s(Vd) * (1 - s) - self.beta_s(Vd) * s - - def dc(self, c, t, 
Vd): - return self.alpha_c(Vd) * (1 - c) - self.beta_c(Vd) * c - - def dq(self, q, t, Ca): - return self.alpha_q(Ca) * (1 - q) - self.beta_q(Ca) * q - - def dVs(self, Vs, t, h, n, Vd): - I_Na = (self.gNa * self.inf_m(Vs) ** 2 * h) * (Vs - self.ENa) - I_KDR = (self.gK * n) * (Vs - self.EK) - I_leak = self.gL * (Vs - self.EL) - I_gj = self.gc / self.p * (Vd - Vs) - dVdt = (- I_Na - I_KDR - I_leak + I_gj + self.Is / self.p) / self.Cm - return dVdt - - def dVd(self, Vd, t, s, q, c, Ca, Vs): - I_leak = self.gL * (Vd - self.EL) - I_Ca = self.gCa * s * s * (Vd - self.ECa) - I_AHP = self.gAHP * q * (Vd - self.EK) - I_C = self.gC * bm.minimum(Ca / 250., 1.) * (Vd - self.EK) - p = 1 - self.p - I_gj = self.gc / p * (Vs - Vd) - dVdt = (- I_leak - I_Ca - I_AHP - I_C + I_gj + self.Id / p) / self.Cm - return dVdt - - @property - def derivative(self): - return JointEq(self.dVs, self.dVd, self.dCa, self.dh, self.dn, self.ds, self.dc, self.dq) - - def update(self, x=None): - assert x is None - t = share.load('t') - dt = share.load('dt') - Vs, Vd, Ca, h, n, s, c, q = self.integral(Vs=self.Vs.value, - Vd=self.Vd.value, - Ca=self.Ca.value, - h=self.h.value, - n=self.n.value, - s=self.s.value, - c=self.c.value, - q=self.q.value, - t=t, - dt=dt) - self.Vs.value = Vs - self.Vd.value = Vd - self.Ca.value = Ca - self.h.value = h - self.n.value = n - self.s.value = s - self.c.value = c - self.q.value = q - - def clear_input(self): - self.Id.value = bm.zeros_like(self.Id) - self.Is.value = bm.zeros_like(self.Is) - - def alpha_m(self, Vs): - return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.) - - def beta_m(self, Vs): - return 0.28 * ((Vs + 60.) - 40.1) / (bm.exp(((Vs + 60.) - 40.1) / 5.) - 1.) 
- - def inf_m(self, Vs): - alpha = self.alpha_m(Vs) - beta = self.beta_m(Vs) - return alpha / (alpha + beta) - - def alpha_n(self, Vs): - return 0.016 * (35.1 - (Vs + 60.)) / (bm.exp((35.1 - (Vs + 60.)) / 5) - 1) - - def beta_n(self, Vs): - return 0.25 * bm.exp(0.5 - 0.025 * (Vs + 60.)) - - def inf_n(self, Vs): - alpha = self.alpha_n(Vs) - beta = self.beta_n(Vs) - return alpha / (alpha + beta) - - def alpha_h(self, Vs): - return 0.128 * bm.exp((17. - (Vs + 60.)) / 18.) - - def beta_h(self, Vs): - return 4. / (1 + bm.exp((40. - (Vs + 60.)) / 5)) - - def inf_h(self, Vs): - alpha = self.alpha_h(Vs) - beta = self.beta_h(Vs) - return alpha / (alpha + beta) - - def alpha_s(self, Vd): - return 1.6 / (1 + bm.exp(-0.072 * ((Vd + 60.) - 65.))) - - def beta_s(self, Vd): - return 0.02 * ((Vd + 60.) - 51.1) / (bm.exp(((Vd + 60.) - 51.1) / 5.) - 1.) - - def inf_s(self, Vd): - alpha = self.alpha_s(Vd) - beta = self.beta_s(Vd) - return alpha / (alpha + beta) - - def alpha_c(self, Vd): - return bm.where((Vd + 60.) <= 50., - (bm.exp(((Vd + 60.) - 10.) / 11.) - bm.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975, - 2. * bm.exp((6.5 - (Vd + 60.)) / 27.)) - - def beta_c(self, Vd): - alpha_c = (bm.exp(((Vd + 60.) - 10.) / 11.) - bm.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975 - return bm.where((Vd + 60.) <= 50., 2. * bm.exp((6.5 - (Vd + 60.)) / 27.) - alpha_c, 0.) - - def inf_c(self, Vd): - alpha_c = self.alpha_c(Vd) - beta_c = self.beta_c(Vd) - return alpha_c / (alpha_c + beta_c) - - def alpha_q(self, Ca): - return bm.minimum(2e-5 * Ca, 1e-2) - - def beta_q(self, Ca): - return 1e-3 - - def inf_q(self, Ca): - alpha = self.alpha_q(Ca) - beta = self.beta_q(Ca) - return alpha / (alpha + beta) + The standard values of the parameters are given below. 
The maximal conductances + (in :math:`\mathrm{mS} / \mathrm{cm}^{2}`) are + :math:`\bar{g}_{L}=0.1`, :math:`\bar{g}_{\mathrm{Na}}=30`, + :math:`\bar{g}_{\mathrm{K}-\mathrm{DR}}=15`, + :math:`\bar{g}_{\mathrm{Ca}}=10`, + :math:`\bar{g}_{\mathrm{K}-\mathrm{AHP}}=0.8`, + :math:`\bar{g}_{\mathrm{K}-\mathrm{C}}=15`, + :math:`\bar{g}_{\mathrm{NMDA}}=0.0` and + :math:`\bar{g}_{\mathrm{AMPA}}=0.0`. + The reversal potentials (in :math:`\mathrm{mV}` ) are + :math:`V_{\mathrm{Na}}=120, V_{\mathrm{C}}=140, V_{\mathrm{K}}=-15 \mathrm{mV})` + are :math:`V_{\mathrm{Na}}=120, V_{\mathrm{Ca}}=140, V_{\mathrm{K}}=-15, $V_{L}=0` + and :math:`V_{\text {Syn }}=60`. The applied currents + (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ) are :math:`I_{s}=-0.5` and :math:`I_{d}=0.0`. + The coupling parameters are :math:`g_{c}=2.1 \mathrm{mS} / \mathrm{cm}^{2}` and + :math:`p=0.5`. The capacitance, :math:`C_{M}`, is + :math:`3 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`\chi(C a)=\min (C a / 250,1)`. + Values for these parameters, and these function definitions, are taken from Traub et al, 1991. + + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. + gK: float, ArrayType, Initializer, callable + The maximum conductance of potassium delayed-rectifier channel. + gCa: float, ArrayType, Initializer, callable + The maximum conductance of calcium channel. + gAHP: float, ArrayType, Initializer, callable + The maximum conductance of potassium after-hyper-polarization channel. + gC: float, ArrayType, Initializer, callable + The maximum conductance of calcium activated potassium channel. + gL: float, ArrayType, Initializer, callable + The conductance of leaky channel. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium channel. + EK: float, ArrayType, Initializer, callable + The reversal potential of potassium delayed-rectifier channel. 
+ ECa: float, ArrayType, Initializer, callable + The reversal potential of calcium channel. + EL: float, ArrayType, Initializer, callable + The reversal potential of leaky channel. + gc: float, ArrayType, Initializer, callable + The coupling strength between the soma and dendrite. + V_th: float, ArrayType, Initializer, callable + The threshold of the membrane spike. + Cm: float, ArrayType, Initializer, callable + The threshold of the membrane spike. + A: float, ArrayType, Initializer, callable + The total cell membrane area, which is normalized to 1. + p: float, ArrayType, Initializer, callable + The proportion of cell area taken up by the soma. + Vs_initializer: ArrayType, Initializer, callable + The initializer of somatic membrane potential. + Vd_initializer: ArrayType, Initializer, callable + The initializer of dendritic membrane potential. + Ca_initializer: ArrayType, Initializer, callable + The initializer of Calcium concentration. + method: str + The numerical integration method. + name: str + The group name. + + References:: + + .. [7] Pinsky, Paul F., and John Rinzel. "Intrinsic and network + rhythmogenesis in a reduced Traub model for CA3 neurons." + Journal of computational neuroscience 1.1 (1994): 39-60. + .. [8] Traub, R. D., Wong, R. K., Miles, R., & Michelson, H. (1991). + A model of a CA3 hippocampal pyramidal neuron incorporating + voltage-clamp data on intrinsic conductances. Journal of + neurophysiology, 66(2), 635-650. 
+ """ + + supported_modes = (bm.BatchingMode, bm.NonBatchingMode) + + def __init__( + self, + size: Shape, + keep_size: bool = False, + # maximum conductance + gNa: Union[float, ArrayType, Initializer, Callable] = 30., + gK: Union[float, ArrayType, Initializer, Callable] = 15., + gCa: Union[float, ArrayType, Initializer, Callable] = 10., + gAHP: Union[float, ArrayType, Initializer, Callable] = 0.8, + gC: Union[float, ArrayType, Initializer, Callable] = 15., + gL: Union[float, ArrayType, Initializer, Callable] = 0.1, + # reversal potential + ENa: Union[float, ArrayType, Initializer, Callable] = 60., + EK: Union[float, ArrayType, Initializer, Callable] = -75., + ECa: Union[float, ArrayType, Initializer, Callable] = 80., + EL: Union[float, ArrayType, Initializer, Callable] = -60., + # other parameters + gc: Union[float, ArrayType, Initializer, Callable] = 2.1, + V_th: Union[float, ArrayType, Initializer, Callable] = 20., + Cm: Union[float, ArrayType, Initializer, Callable] = 3.0, + p: Union[float, ArrayType, Initializer, Callable] = 0.5, + A: Union[float, ArrayType, Initializer, Callable] = 1., + # initializers + Vs_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-64.6), + Vd_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-64.5), + Ca_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.2), + # others + noise: Union[float, ArrayType, Initializer, Callable] = None, + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + # initialization + super(PinskyRinzelModel, self).__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) + + # conductance parameters + self.gAHP = parameter(gAHP, self.varshape, allow_none=False) + self.gCa = parameter(gCa, self.varshape, allow_none=False) + self.gNa = parameter(gNa, self.varshape, allow_none=False) + self.gK = parameter(gK, self.varshape, allow_none=False) + self.gL = parameter(gL, self.varshape, allow_none=False) + self.gC = parameter(gC, self.varshape, 
allow_none=False) + + # reversal potential parameters + self.ENa = parameter(ENa, self.varshape, allow_none=False) + self.ECa = parameter(ECa, self.varshape, allow_none=False) + self.EK = parameter(EK, self.varshape, allow_none=False) + self.EL = parameter(EL, self.varshape, allow_none=False) + + # other neuronal parameters + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.Cm = parameter(Cm, self.varshape, allow_none=False) + self.gc = parameter(gc, self.varshape, allow_none=False) + self.p = parameter(p, self.varshape, allow_none=False) + self.A = parameter(A, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=8) + + # initializers + check.is_initializer(Vs_initializer, 'Vs_initializer', allow_none=False) + check.is_initializer(Vd_initializer, 'Vd_initializer', allow_none=False) + check.is_initializer(Ca_initializer, 'Ca_initializer', allow_none=False) + self._Vs_initializer = Vs_initializer + self._Vd_initializer = Vd_initializer + self._Ca_initializer = Ca_initializer + + # variables + self.Vs = variable_(self._Vs_initializer, self.varshape, self.mode) + self.Vd = variable_(self._Vd_initializer, self.varshape, self.mode) + self.Ca = variable_(self._Ca_initializer, self.varshape, self.mode) + self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) + self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) + self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) + self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) + self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(self.mode, bm.BatchingMode) else None) + self.Id = variable_(bm.zeros, self.varshape, self.mode) # input to soma + self.Is = variable_(bm.zeros, self.varshape, self.mode) # input to dendrite + + # integral + if 
self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + + def reset_state(self, batch_size=None): + self.Vd.value = variable_(self._Vd_initializer, self.varshape, batch_size) + self.Vs.value = variable_(self._Vs_initializer, self.varshape, batch_size) + self.Ca.value = variable_(self._Ca_initializer, self.varshape, batch_size) + batch_axis = 0 if isinstance(self.mode, bm.BatchingMode) else None + self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis) + self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis) + self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis) + self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis) + self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis) + self.Id.value = variable_(bm.zeros, self.varshape, batch_size) + self.Is.value = variable_(bm.zeros, self.varshape, batch_size) + + def dCa(self, Ca, t, s, Vd): + ICa = self.gCa * s * s * (Vd - self.ECa) + return -0.13 * ICa - 0.075 * Ca + + def dh(self, h, t, Vs): + return self.alpha_h(Vs) * (1 - h) - self.beta_h(Vs) * h + + def dn(self, n, t, Vs): + return self.alpha_n(Vs) * (1 - n) - self.beta_n(Vs) * n + + def ds(self, s, t, Vd): + return self.alpha_s(Vd) * (1 - s) - self.beta_s(Vd) * s + + def dc(self, c, t, Vd): + return self.alpha_c(Vd) * (1 - c) - self.beta_c(Vd) * c + + def dq(self, q, t, Ca): + return self.alpha_q(Ca) * (1 - q) - self.beta_q(Ca) * q + + def dVs(self, Vs, t, h, n, Vd): + I_Na = (self.gNa * self.inf_m(Vs) ** 2 * h) * (Vs - self.ENa) + I_KDR = (self.gK * n) * (Vs - self.EK) + I_leak = self.gL * (Vs - self.EL) + I_gj = self.gc / self.p * (Vd - Vs) + dVdt = (- I_Na - I_KDR - I_leak + I_gj + self.Is / self.p) / self.Cm + return dVdt + + def dVd(self, Vd, t, s, q, c, Ca, Vs): + I_leak = self.gL * (Vd - self.EL) + I_Ca = self.gCa * s * s * (Vd - self.ECa) + I_AHP = self.gAHP * q * (Vd - 
self.EK) + I_C = self.gC * bm.minimum(Ca / 250., 1.) * (Vd - self.EK) + p = 1 - self.p + I_gj = self.gc / p * (Vs - Vd) + dVdt = (- I_leak - I_Ca - I_AHP - I_C + I_gj + self.Id / p) / self.Cm + return dVdt + + @property + def derivative(self): + return JointEq(self.dVs, self.dVd, self.dCa, self.dh, self.dn, self.ds, self.dc, self.dq) + + def update(self, x=None): + assert x is None + t = share.load('t') + dt = share.load('dt') + Vs, Vd, Ca, h, n, s, c, q = self.integral(Vs=self.Vs.value, + Vd=self.Vd.value, + Ca=self.Ca.value, + h=self.h.value, + n=self.n.value, + s=self.s.value, + c=self.c.value, + q=self.q.value, + t=t, + dt=dt) + self.Vs.value = Vs + self.Vd.value = Vd + self.Ca.value = Ca + self.h.value = h + self.n.value = n + self.s.value = s + self.c.value = c + self.q.value = q + + def clear_input(self): + self.Id.value = bm.zeros_like(self.Id) + self.Is.value = bm.zeros_like(self.Is) + + def alpha_m(self, Vs): + return 0.32 * (13.1 - (Vs + 60.)) / (bm.exp((13.1 - (Vs + 60.)) / 4.) - 1.) + + def beta_m(self, Vs): + return 0.28 * ((Vs + 60.) - 40.1) / (bm.exp(((Vs + 60.) - 40.1) / 5.) - 1.) + + def inf_m(self, Vs): + alpha = self.alpha_m(Vs) + beta = self.beta_m(Vs) + return alpha / (alpha + beta) + + def alpha_n(self, Vs): + return 0.016 * (35.1 - (Vs + 60.)) / (bm.exp((35.1 - (Vs + 60.)) / 5) - 1) + + def beta_n(self, Vs): + return 0.25 * bm.exp(0.5 - 0.025 * (Vs + 60.)) + + def inf_n(self, Vs): + alpha = self.alpha_n(Vs) + beta = self.beta_n(Vs) + return alpha / (alpha + beta) + + def alpha_h(self, Vs): + return 0.128 * bm.exp((17. - (Vs + 60.)) / 18.) + + def beta_h(self, Vs): + return 4. / (1 + bm.exp((40. - (Vs + 60.)) / 5)) + + def inf_h(self, Vs): + alpha = self.alpha_h(Vs) + beta = self.beta_h(Vs) + return alpha / (alpha + beta) + + def alpha_s(self, Vd): + return 1.6 / (1 + bm.exp(-0.072 * ((Vd + 60.) - 65.))) + + def beta_s(self, Vd): + return 0.02 * ((Vd + 60.) - 51.1) / (bm.exp(((Vd + 60.) - 51.1) / 5.) - 1.) 
+ + def inf_s(self, Vd): + alpha = self.alpha_s(Vd) + beta = self.beta_s(Vd) + return alpha / (alpha + beta) + + def alpha_c(self, Vd): + return bm.where((Vd + 60.) <= 50., + (bm.exp(((Vd + 60.) - 10.) / 11.) - bm.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975, + 2. * bm.exp((6.5 - (Vd + 60.)) / 27.)) + + def beta_c(self, Vd): + alpha_c = (bm.exp(((Vd + 60.) - 10.) / 11.) - bm.exp(((Vd + 60.) - 6.5) / 27.)) / 18.975 + return bm.where((Vd + 60.) <= 50., 2. * bm.exp((6.5 - (Vd + 60.)) / 27.) - alpha_c, 0.) + + def inf_c(self, Vd): + alpha_c = self.alpha_c(Vd) + beta_c = self.beta_c(Vd) + return alpha_c / (alpha_c + beta_c) + + def alpha_q(self, Ca): + return bm.minimum(2e-5 * Ca, 1e-2) + + def beta_q(self, Ca): + return 1e-3 + + def inf_q(self, Ca): + alpha = self.alpha_q(Ca) + beta = self.beta_q(Ca) + return alpha / (alpha + beta) class WangBuzsakiModel(hh.WangBuzsakiHH): - r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model. - - Each model is described by a single compartment and obeys the current balance equation: - - .. math:: - - C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}} - - where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the - injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). The leak current - :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance - :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant - :math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`. - - The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion - currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the - Hodgkin-Huxley type (Hodgkin and Huxley, 1952). 
The transient sodium current - :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, - where the activation variable :math:`m` is assumed fast and substituted by its steady-state - function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ; - :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`. - The inactivation variable :math:`h` obeys a first-order kinetics: - - .. math:: - - \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) - - where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and - :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ; - :math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .` - - The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, - where the activation variable :math:`n` obeys the following equation: - - .. math:: - - \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) - - with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and - :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and - :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. - - - Parameters:: - - size: sequence of int, int - The size of the neuron group. - ENa: float, ArrayType, Initializer, callable - The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable - The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable - The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable - The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable - The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable - The conductance of learky channel. Default is 0.03 msiemens. 
- V_th: float, ArrayType, Initializer, callable - The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable - The membrane capacitance. Default is 1 ufarad. - phi: float, ArrayType, Initializer, callable - The temperature regulator constant. - V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - h_initializer: ArrayType, Initializer, callable - The initializer of h channel. - n_initializer: ArrayType, Initializer, callable - The initializer of n channel. - method: str - The numerical integration method. - name: str - The group name. - - References:: - - .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic - inhibition in a hippocampal interneuronal network model. Journal of - neuroscience, 16(20), pp.6402-6413. - - """ - - def __init__( - self, - *args, - input_var: bool = True, - - **kwargs, - ): - self.input_var = input_var - super().__init__(*args, **kwargs, init_var=False) - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - super().reset_state(batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - return super().update(x) - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) + r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model. + + Each model is described by a single compartment and obeys the current balance equation: + + .. math:: + + C_{m} \frac{d V}{d t}=-I_{\mathrm{Na}}-I_{\mathrm{K}}-I_{\mathrm{L}}-I_{\mathrm{syn}}+I_{\mathrm{app}} + + where :math:`C_{m}=1 \mu \mathrm{F} / \mathrm{cm}^{2}` and :math:`I_{\mathrm{app}}` is the + injected current (in :math:`\mu \mathrm{A} / \mathrm{cm}^{2}` ). 
The leak current + :math:`I_{\mathrm{L}}=g_{\mathrm{L}}\left(V-E_{\mathrm{L}}\right)` has a conductance + :math:`g_{\mathrm{L}}=0.1 \mathrm{mS} / \mathrm{cm}^{2}`, so that the passive time constant + :math:`\tau_{0}=C_{m} / g_{\mathrm{L}}=10 \mathrm{msec} ; E_{\mathrm{L}}=-65 \mathrm{mV}`. + + The spike-generating :math:`\mathrm{Na}^{+}` and :math:`\mathrm{K}^{+}` voltage-dependent ion + currents :math:`\left(I_{\mathrm{Na}}\right.` and :math:`I_{\mathrm{K}}` ) are of the + Hodgkin-Huxley type (Hodgkin and Huxley, 1952). The transient sodium current + :math:`I_{\mathrm{Na}}=g_{\mathrm{Na}} m_{\infty}^{3} h\left(V-E_{\mathrm{Na}}\right)`, + where the activation variable :math:`m` is assumed fast and substituted by its steady-state + function :math:`m_{\infty}=\alpha_{m} /\left(\alpha_{m}+\beta_{m}\right)` ; + :math:`\alpha_{m}(V)=-0.1(V+35) /(\exp (-0.1(V+35))-1), \beta_{m}(V)=4 \exp (-(V+60) / 18)`. + The inactivation variable :math:`h` obeys a first-order kinetics: + + .. math:: + + \frac{d h}{d t}=\phi\left(\alpha_{h}(1-h)-\beta_{h} h\right) + + where :math:`\alpha_{h}(V)=0.07 \exp (-(V+58) / 20)` and + :math:`\beta_{h}(V)=1 /(\exp (-0.1(V+28)) +1) \cdot g_{\mathrm{Na}}=35 \mathrm{mS} / \mathrm{cm}^{2}` ; + :math:`E_{\mathrm{Na}}=55 \mathrm{mV}, \phi=5 .` + + The delayed rectifier :math:`I_{\mathrm{K}}=g_{\mathrm{K}} n^{4}\left(V-E_{\mathrm{K}}\right)`, + where the activation variable :math:`n` obeys the following equation: + + .. math:: + + \frac{d n}{d t}=\phi\left(\alpha_{n}(1-n)-\beta_{n} n\right) + + with :math:`\alpha_{n}(V)=-0.01(V+34) /(\exp (-0.1(V+34))-1)` and + :math:`\beta_{n}(V)=0.125\exp (-(V+44) / 80)` ; :math:`g_{\mathrm{K}}=9 \mathrm{mS} / \mathrm{cm}^{2}`, and + :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. + + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + ENa: float, ArrayType, Initializer, callable + The reversal potential of sodium. Default is 50 mV. 
+ gNa: float, ArrayType, Initializer, callable + The maximum conductance of sodium channel. Default is 120 msiemens. + EK: float, ArrayType, Initializer, callable + The reversal potential of potassium. Default is -77 mV. + gK: float, ArrayType, Initializer, callable + The maximum conductance of potassium channel. Default is 36 msiemens. + EL: float, ArrayType, Initializer, callable + The reversal potential of leaky channel. Default is -54.387 mV. + gL: float, ArrayType, Initializer, callable + The conductance of leaky channel. Default is 0.03 msiemens. + V_th: float, ArrayType, Initializer, callable + The threshold of the membrane spike. Default is 20 mV. + C: float, ArrayType, Initializer, callable + The membrane capacitance. Default is 1 ufarad. + phi: float, ArrayType, Initializer, callable + The temperature regulator constant. + V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + h_initializer: ArrayType, Initializer, callable + The initializer of h channel. + n_initializer: ArrayType, Initializer, callable + The initializer of n channel. + method: str + The numerical integration method. + name: str + The group name. + + References:: + + .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic + inhibition in a hippocampal interneuronal network model. Journal of + neuroscience, 16(20), pp.6402-6413. + + """ + + def __init__( + self, + *args, + input_var: bool = True, + + **kwargs, + ): + self.input_var = input_var + super().__init__(*args, **kwargs, init_var=False) + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. 
if x is None else x + return super().update(x) + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) diff --git a/brainpy/_src/dynold/neurons/fractional_models.py b/brainpy/_src/dynold/neurons/fractional_models.py index 2497cbb88..3b68c465f 100644 --- a/brainpy/_src/dynold/neurons/fractional_models.py +++ b/brainpy/_src/dynold/neurons/fractional_models.py @@ -15,321 +15,321 @@ from brainpy.types import Shape, ArrayType __all__ = [ - 'FractionalNeuron', - 'FractionalFHR', - 'FractionalIzhikevich', + 'FractionalNeuron', + 'FractionalFHR', + 'FractionalIzhikevich', ] class FractionalNeuron(NeuDyn): - """Fractional-order neuron model.""" - pass + """Fractional-order neuron model.""" + pass class FractionalFHR(FractionalNeuron): - r"""The fractional-order FH-R model [1]_. - - FitzHugh and Rinzel introduced FH-R model (1976, in an unpublished article), - which is the modification of the classical FHN neuron model. The fractional-order - FH-R model is described as - - .. math:: - - \begin{array}{rcl} - \frac{{d}^{\alpha }v}{d{t}^{\alpha }} & = & v-{v}^{3}/3-w+y+I={f}_{1}(v,w,y),\\ - \frac{{d}^{\alpha }w}{d{t}^{\alpha }} & = & \delta (a+v-bw)={f}_{2}(v,w,y),\\ - \frac{{d}^{\alpha }y}{d{t}^{\alpha }} & = & \mu (c-v-dy)={f}_{3}(v,w,y), - \end{array} - - where :math:`v, w` and :math:`y` represent the membrane voltage, recovery variable - and slow modulation of the current respectively. - :math:`I` measures the constant magnitude of external stimulus current, and :math:`\alpha` - is the fractional exponent which ranges in the interval :math:`(0 < \alpha \le 1)`. - :math:`a, b, c, d, \delta` and :math:`\mu` are the system parameters. - - The system reduces to the original classical order system when :math:`\alpha=1`. - - :math:`\mu` indicates a small parameter that determines the pace of the slow system - variable :math:`y`. 
The fast subsystem (:math:`v-w`) presents a relaxation oscillator - in the phase plane where :math:`\delta` is a small parameter. - :math:`v` is expressed in mV (millivolt) scale. Time :math:`t` is in ms (millisecond) scale. - It exhibits tonic spiking or quiescent state depending on the parameter sets for a fixed - value of :math:`I`. The parameter :math:`a` in the 2D FHN model corresponds to the - parameter :math:`c` of the FH-R neuron model. If we decrease the value of :math:`a`, - it causes longer intervals between two burstings, however there exists :math:`a` - relatively fixed time of bursting duration. With the increasing of :math:`a`, the - interburst intervals become shorter and periodic bursting changes to tonic spiking. - - Examples:: - - - [(Mondal, et, al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2019_Fractional_order_FHR_model.html) - - - Parameters:: - - size: int, sequence of int - The size of the neuron group. - alpha: float, tensor - The fractional order. - num_memory: int - The total number of the short memory. - - References:: - - .. [1] Mondal, A., Sharma, S.K., Upadhyay, R.K. *et al.* Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. *Sci Rep* **9,** 15721 (2019). 
https://doi.org/10.1038/s41598-019-52061-4 - """ - - def __init__( - self, - size: Shape, - alpha: Union[float, Sequence[float]], - num_memory: int = 1000, - a: Union[float, ArrayType, Initializer, Callable] = 0.7, - b: Union[float, ArrayType, Initializer, Callable] = 0.8, - c: Union[float, ArrayType, Initializer, Callable] = -0.775, - d: Union[float, ArrayType, Initializer, Callable] = 1., - delta: Union[float, ArrayType, Initializer, Callable] = 0.08, - mu: Union[float, ArrayType, Initializer, Callable] = 0.0001, - Vth: Union[float, ArrayType, Initializer, Callable] = 1.8, - V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.5), - w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), - y_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), - input_var: bool = True, - name: str = None, - keep_size: bool = False, - ): - super(FractionalFHR, self).__init__(size, keep_size=keep_size, name=name) - - assert self.mode.is_one_of(bm.NonBatchingMode, ) - - # fractional order - self.alpha = alpha - is_integer(num_memory, 'num_memory', allow_none=False) - - # parameters - self.a = parameter(a, self.varshape, allow_none=False) - self.b = parameter(b, self.varshape, allow_none=False) - self.c = parameter(c, self.varshape, allow_none=False) - self.d = parameter(d, self.varshape, allow_none=False) - self.mu = parameter(mu, self.varshape, allow_none=False) - self.Vth = parameter(Vth, self.varshape, allow_none=False) - self.delta = parameter(delta, self.varshape, allow_none=False) - self.input_var = input_var - - # initializers - is_initializer(V_initializer, 'V_initializer', allow_none=False) - is_initializer(w_initializer, 'w_initializer', allow_none=False) - is_initializer(y_initializer, 'y_initializer', allow_none=False) - self._V_initializer = V_initializer - self._w_initializer = w_initializer - self._y_initializer = y_initializer - - # variables - self.V = bm.Variable(parameter(V_initializer, self.varshape)) - self.w = 
bm.Variable(parameter(w_initializer, self.varshape)) - self.y = bm.Variable(parameter(y_initializer, self.varshape)) - self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool)) - if self.input_var: - self.input = bm.Variable(jnp.zeros(self.varshape)) - - # integral function - self.integral = GLShortMemory(self.derivative, - alpha=alpha, - num_memory=num_memory, - inits=[self.V, self.w, self.y]) - - def reset_state(self, batch_size=None): - self.V.value = parameter(self._V_initializer, self.varshape) - self.w.value = parameter(self._w_initializer, self.varshape) - self.y.value = parameter(self._y_initializer, self.varshape) - self.spike[:] = False - if self.input_var: - self.input[:] = 0 - # integral function reset - self.integral.reset([self.V, self.w, self.y]) - - def dV(self, V, t, w, y, I): - return V - V ** 3 / 3 - w + y + I - - def dw(self, w, t, V): - return self.delta * (self.a + V - self.b * w) - - def dy(self, y, t, V): - return self.mu * (self.c - V - self.d * y) - - @property - def derivative(self): - return JointEq([self.dV, self.dw, self.dy]) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - V, w, y = self.integral(self.V, self.w, self.y, t, I=x, dt=dt) - self.spike.value = jnp.logical_and(V >= self.Vth, self.V < self.Vth) - self.V.value = V - self.w.value = w - self.y.value = y - return self.spike.value - - def clear_input(self): - if self.input_var: - self.input[:] = 0. + r"""The fractional-order FH-R model [1]_. + + FitzHugh and Rinzel introduced FH-R model (1976, in an unpublished article), + which is the modification of the classical FHN neuron model. The fractional-order + FH-R model is described as + + .. 
math:: + + \begin{array}{rcl} + \frac{{d}^{\alpha }v}{d{t}^{\alpha }} & = & v-{v}^{3}/3-w+y+I={f}_{1}(v,w,y),\\ + \frac{{d}^{\alpha }w}{d{t}^{\alpha }} & = & \delta (a+v-bw)={f}_{2}(v,w,y),\\ + \frac{{d}^{\alpha }y}{d{t}^{\alpha }} & = & \mu (c-v-dy)={f}_{3}(v,w,y), + \end{array} + + where :math:`v, w` and :math:`y` represent the membrane voltage, recovery variable + and slow modulation of the current respectively. + :math:`I` measures the constant magnitude of external stimulus current, and :math:`\alpha` + is the fractional exponent which ranges in the interval :math:`(0 < \alpha \le 1)`. + :math:`a, b, c, d, \delta` and :math:`\mu` are the system parameters. + + The system reduces to the original classical order system when :math:`\alpha=1`. + + :math:`\mu` indicates a small parameter that determines the pace of the slow system + variable :math:`y`. The fast subsystem (:math:`v-w`) presents a relaxation oscillator + in the phase plane where :math:`\delta` is a small parameter. + :math:`v` is expressed in mV (millivolt) scale. Time :math:`t` is in ms (millisecond) scale. + It exhibits tonic spiking or quiescent state depending on the parameter sets for a fixed + value of :math:`I`. The parameter :math:`a` in the 2D FHN model corresponds to the + parameter :math:`c` of the FH-R neuron model. If we decrease the value of :math:`a`, + it causes longer intervals between two burstings, however there exists :math:`a` + relatively fixed time of bursting duration. With the increasing of :math:`a`, the + interburst intervals become shorter and periodic bursting changes to tonic spiking. + + Examples:: + + - [(Mondal, et, al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2019_Fractional_order_FHR_model.html) + + + Parameters:: + + size: int, sequence of int + The size of the neuron group. + alpha: float, tensor + The fractional order. + num_memory: int + The total number of the short memory. 
+ + References:: + + .. [1] Mondal, A., Sharma, S.K., Upadhyay, R.K. *et al.* Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. *Sci Rep* **9,** 15721 (2019). https://doi.org/10.1038/s41598-019-52061-4 + """ + + def __init__( + self, + size: Shape, + alpha: Union[float, Sequence[float]], + num_memory: int = 1000, + a: Union[float, ArrayType, Initializer, Callable] = 0.7, + b: Union[float, ArrayType, Initializer, Callable] = 0.8, + c: Union[float, ArrayType, Initializer, Callable] = -0.775, + d: Union[float, ArrayType, Initializer, Callable] = 1., + delta: Union[float, ArrayType, Initializer, Callable] = 0.08, + mu: Union[float, ArrayType, Initializer, Callable] = 0.0001, + Vth: Union[float, ArrayType, Initializer, Callable] = 1.8, + V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(2.5), + w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + y_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + input_var: bool = True, + name: str = None, + keep_size: bool = False, + ): + super(FractionalFHR, self).__init__(size, keep_size=keep_size, name=name) + + assert self.mode.is_one_of(bm.NonBatchingMode, ) + + # fractional order + self.alpha = alpha + is_integer(num_memory, 'num_memory', allow_none=False) + + # parameters + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.d = parameter(d, self.varshape, allow_none=False) + self.mu = parameter(mu, self.varshape, allow_none=False) + self.Vth = parameter(Vth, self.varshape, allow_none=False) + self.delta = parameter(delta, self.varshape, allow_none=False) + self.input_var = input_var + + # initializers + is_initializer(V_initializer, 'V_initializer', allow_none=False) + is_initializer(w_initializer, 'w_initializer', allow_none=False) + is_initializer(y_initializer, 'y_initializer', allow_none=False) + 
self._V_initializer = V_initializer + self._w_initializer = w_initializer + self._y_initializer = y_initializer + + # variables + self.V = bm.Variable(parameter(V_initializer, self.varshape)) + self.w = bm.Variable(parameter(w_initializer, self.varshape)) + self.y = bm.Variable(parameter(y_initializer, self.varshape)) + self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool)) + if self.input_var: + self.input = bm.Variable(jnp.zeros(self.varshape)) + + # integral function + self.integral = GLShortMemory(self.derivative, + alpha=alpha, + num_memory=num_memory, + inits=[self.V, self.w, self.y]) + + def reset_state(self, batch_size=None): + self.V.value = parameter(self._V_initializer, self.varshape) + self.w.value = parameter(self._w_initializer, self.varshape) + self.y.value = parameter(self._y_initializer, self.varshape) + self.spike[:] = False + if self.input_var: + self.input[:] = 0 + # integral function reset + self.integral.reset([self.V, self.w, self.y]) + + def dV(self, V, t, w, y, I): + return V - V ** 3 / 3 - w + y + I + + def dw(self, w, t, V): + return self.delta * (self.a + V - self.b * w) + + def dy(self, y, t, V): + return self.mu * (self.c - V - self.d * y) + + @property + def derivative(self): + return JointEq([self.dV, self.dw, self.dy]) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, w, y = self.integral(self.V, self.w, self.y, t, I=x, dt=dt) + self.spike.value = jnp.logical_and(V >= self.Vth, self.V < self.Vth) + self.V.value = V + self.w.value = w + self.y.value = y + return self.spike.value + + def clear_input(self): + if self.input_var: + self.input[:] = 0. class FractionalIzhikevich(FractionalNeuron): - r"""Fractional-order Izhikevich model [10]_. - - The fractional-order Izhikevich model is given by - - .. 
math:: - - \begin{aligned} - &\tau \frac{d^{\alpha} v}{d t^{\alpha}}=\mathrm{f} v^{2}+g v+h-u+R I \\ - &\tau \frac{d^{\alpha} u}{d t^{\alpha}}=a(b v-u) - \end{aligned} - - where :math:`\alpha` is the fractional order (exponent) such that :math:`0<\alpha\le1`. - It is a commensurate system that reduces to classical Izhikevich model at :math:`\alpha=1`. - - The time :math:`t` is in ms; and the system variable :math:`v` expressed in mV - corresponds to membrane voltage. Moreover, :math:`u` expressed in mV is the - recovery variable that corresponds to the activation of K+ ionic current and - inactivation of Na+ ionic current. - - The parameters :math:`f, g, h` are fixed constants (should not be changed) such - that :math:`f=0.04` (mV)−1, :math:`g=5, h=140` mV; and :math:`a` and :math:`b` are - dimensionless parameters. The time constant :math:`\tau=1` ms; the resistance - :math:`R=1` Ω; and :math:`I` expressed in mA measures the injected (applied) - dc stimulus current to the system. - - When the membrane voltage reaches the spike peak :math:`v_{peak}`, the two variables - are rest as follow: - - .. math:: - - \text { if } v \geq v_{\text {peak }} \text { then }\left\{\begin{array}{l} - v \leftarrow c \\ - u \leftarrow u+d - \end{array}\right. - - we used :math:`v_{peak}=30` mV, and :math:`c` and :math:`d` are parameters expressed - in mV. When the spike reaches its peak value, the membrane voltage :math:`v` and the - recovery variable :math:`u` are reset according to the above condition. - - Examples:: - - - [(Teka, et. al, 2018): Fractional-order Izhikevich neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2018_Fractional_Izhikevich_model.html) - - - References:: - - .. [10] Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. "Spiking and - bursting patterns of fractional-order Izhikevich model." Communications - in Nonlinear Science and Numerical Simulation 56 (2018): 161-176. 
- - """ - - def __init__( - self, - size: Shape, - alpha: Union[float, Sequence[float]], - num_memory: int, - a: Union[float, ArrayType, Initializer, Callable] = 0.02, - b: Union[float, ArrayType, Initializer, Callable] = 0.20, - c: Union[float, ArrayType, Initializer, Callable] = -65., - d: Union[float, ArrayType, Initializer, Callable] = 8., - f: Union[float, ArrayType, Initializer, Callable] = 0.04, - g: Union[float, ArrayType, Initializer, Callable] = 5., - h: Union[float, ArrayType, Initializer, Callable] = 140., - R: Union[float, ArrayType, Initializer, Callable] = 1., - tau: Union[float, ArrayType, Initializer, Callable] = 1., - V_th: Union[float, ArrayType, Initializer, Callable] = 30., - V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-65.), - u_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.20 * -65.), - keep_size: bool = False, - input_var: bool = True, - name: str = None - ): - # initialization - super(FractionalIzhikevich, self).__init__(size=size, keep_size=keep_size, name=name) - assert self.mode.is_a(bm.NonBatchingMode) - - # params - self.alpha = alpha - is_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True) - self.a = parameter(a, self.varshape, allow_none=False) - self.b = parameter(b, self.varshape, allow_none=False) - self.c = parameter(c, self.varshape, allow_none=False) - self.d = parameter(d, self.varshape, allow_none=False) - self.f = parameter(f, self.varshape, allow_none=False) - self.g = parameter(g, self.varshape, allow_none=False) - self.h = parameter(h, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.R = parameter(R, self.varshape, allow_none=False) - self.V_th = parameter(V_th, self.varshape, allow_none=False) - self.input_var = input_var - - # initializers - is_initializer(V_initializer, 'V_initializer', allow_none=False) - is_initializer(u_initializer, 'u_initializer', allow_none=False) - self._V_initializer = 
V_initializer - self._u_initializer = u_initializer - - # variables - self.V = bm.Variable(parameter(V_initializer, self.varshape)) - self.u = bm.Variable(parameter(u_initializer, self.varshape)) - self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool)) - if self.input_var: - self.input = bm.Variable(jnp.zeros(self.varshape)) - - # functions - is_integer(num_memory, 'num_memory', allow_none=False) - self.integral = CaputoL1Schema(f=self.derivative, - alpha=alpha, - num_memory=num_memory, - inits=[self.V, self.u]) - - def reset_state(self, batch_size=None): - self.V.value = parameter(self._V_initializer, self.varshape) - self.u.value = parameter(self._u_initializer, self.varshape) - self.spike[:] = False - if self.input_var: - self.input[:] = 0 - # integral function reset - self.integral.reset([self.V, self.u]) - - def dV(self, V, t, u, I_ext): - dVdt = self.f * V * V + self.g * V + self.h - u + self.R * I_ext - return dVdt / self.tau - - def du(self, u, t, V): - dudt = self.a * (self.b * V - u) - return dudt / self.tau - - @property - def derivative(self): - return JointEq(self.dV, self.du) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - V, u = self.integral(self.V, self.u, t=share['t'], I_ext=x, dt=share['dt']) - spikes = V >= self.V_th - self.V.value = jnp.where(spikes, self.c, V) - self.u.value = jnp.where(spikes, u + self.d, u) - self.spike.value = spikes - return spikes - - def clear_input(self): - if self.input_var: - self.input[:] = 0. + r"""Fractional-order Izhikevich model [10]_. + + The fractional-order Izhikevich model is given by + + .. math:: + + \begin{aligned} + &\tau \frac{d^{\alpha} v}{d t^{\alpha}}=\mathrm{f} v^{2}+g v+h-u+R I \\ + &\tau \frac{d^{\alpha} u}{d t^{\alpha}}=a(b v-u) + \end{aligned} + + where :math:`\alpha` is the fractional order (exponent) such that :math:`0<\alpha\le1`. 
+ It is a commensurate system that reduces to classical Izhikevich model at :math:`\alpha=1`. + + The time :math:`t` is in ms; and the system variable :math:`v` expressed in mV + corresponds to membrane voltage. Moreover, :math:`u` expressed in mV is the + recovery variable that corresponds to the activation of K+ ionic current and + inactivation of Na+ ionic current. + + The parameters :math:`f, g, h` are fixed constants (should not be changed) such + that :math:`f=0.04` (mV)−1, :math:`g=5, h=140` mV; and :math:`a` and :math:`b` are + dimensionless parameters. The time constant :math:`\tau=1` ms; the resistance + :math:`R=1` Ω; and :math:`I` expressed in mA measures the injected (applied) + dc stimulus current to the system. + + When the membrane voltage reaches the spike peak :math:`v_{peak}`, the two variables + are reset as follows: + + .. math:: + + \text { if } v \geq v_{\text {peak }} \text { then }\left\{\begin{array}{l} + v \leftarrow c \\ + u \leftarrow u+d + \end{array}\right. + + we used :math:`v_{peak}=30` mV, and :math:`c` and :math:`d` are parameters expressed + in mV. When the spike reaches its peak value, the membrane voltage :math:`v` and the + recovery variable :math:`u` are reset according to the above condition. + + Examples:: + + - [(Teka et al., 2018): Fractional-order Izhikevich neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2018_Fractional_Izhikevich_model.html) + + + References:: + + .. [10] Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. "Spiking and + bursting patterns of fractional-order Izhikevich model." Communications + in Nonlinear Science and Numerical Simulation 56 (2018): 161-176. 
+ + """ + + def __init__( + self, + size: Shape, + alpha: Union[float, Sequence[float]], + num_memory: int, + a: Union[float, ArrayType, Initializer, Callable] = 0.02, + b: Union[float, ArrayType, Initializer, Callable] = 0.20, + c: Union[float, ArrayType, Initializer, Callable] = -65., + d: Union[float, ArrayType, Initializer, Callable] = 8., + f: Union[float, ArrayType, Initializer, Callable] = 0.04, + g: Union[float, ArrayType, Initializer, Callable] = 5., + h: Union[float, ArrayType, Initializer, Callable] = 140., + R: Union[float, ArrayType, Initializer, Callable] = 1., + tau: Union[float, ArrayType, Initializer, Callable] = 1., + V_th: Union[float, ArrayType, Initializer, Callable] = 30., + V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-65.), + u_initializer: Union[Initializer, Callable, ArrayType] = OneInit(0.20 * -65.), + keep_size: bool = False, + input_var: bool = True, + name: str = None + ): + # initialization + super(FractionalIzhikevich, self).__init__(size=size, keep_size=keep_size, name=name) + assert self.mode.is_a(bm.NonBatchingMode) + + # params + self.alpha = alpha + is_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True) + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.c = parameter(c, self.varshape, allow_none=False) + self.d = parameter(d, self.varshape, allow_none=False) + self.f = parameter(f, self.varshape, allow_none=False) + self.g = parameter(g, self.varshape, allow_none=False) + self.h = parameter(h, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.input_var = input_var + + # initializers + is_initializer(V_initializer, 'V_initializer', allow_none=False) + is_initializer(u_initializer, 'u_initializer', allow_none=False) + self._V_initializer = 
V_initializer + self._u_initializer = u_initializer + + # variables + self.V = bm.Variable(parameter(V_initializer, self.varshape)) + self.u = bm.Variable(parameter(u_initializer, self.varshape)) + self.spike = bm.Variable(jnp.zeros(self.varshape, dtype=bool)) + if self.input_var: + self.input = bm.Variable(jnp.zeros(self.varshape)) + + # functions + is_integer(num_memory, 'num_memory', allow_none=False) + self.integral = CaputoL1Schema(f=self.derivative, + alpha=alpha, + num_memory=num_memory, + inits=[self.V, self.u]) + + def reset_state(self, batch_size=None): + self.V.value = parameter(self._V_initializer, self.varshape) + self.u.value = parameter(self._u_initializer, self.varshape) + self.spike[:] = False + if self.input_var: + self.input[:] = 0 + # integral function reset + self.integral.reset([self.V, self.u]) + + def dV(self, V, t, u, I_ext): + dVdt = self.f * V * V + self.g * V + self.h - u + self.R * I_ext + return dVdt / self.tau + + def du(self, u, t, V): + dudt = self.a * (self.b * V - u) + return dudt / self.tau + + @property + def derivative(self): + return JointEq(self.dV, self.du) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, u = self.integral(self.V, self.u, t=share['t'], I_ext=x, dt=share['dt']) + spikes = V >= self.V_th + self.V.value = jnp.where(spikes, self.c, V) + self.u.value = jnp.where(spikes, u + self.d, u) + self.spike.value = spikes + return spikes + + def clear_input(self): + if self.input_var: + self.input[:] = 0. 
diff --git a/brainpy/_src/dynold/neurons/reduced_models.py b/brainpy/_src/dynold/neurons/reduced_models.py index 6f74e9b6c..181845fd7 100644 --- a/brainpy/_src/dynold/neurons/reduced_models.py +++ b/brainpy/_src/dynold/neurons/reduced_models.py @@ -1,13 +1,13 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable, Optional +from typing import Union, Callable from jax.lax import stop_gradient import brainpy.math as bm from brainpy._src.context import share -from brainpy._src.dyn.neurons import lif from brainpy._src.dyn.base import NeuDyn +from brainpy._src.dyn.neurons import lif from brainpy._src.initialize import (ZeroInit, OneInit, Initializer, @@ -19,1519 +19,1519 @@ from brainpy.types import Shape, ArrayType __all__ = [ - 'LeakyIntegrator', - 'LIF', - 'ExpIF', - 'AdExIF', - 'QuaIF', - 'AdQuaIF', - 'GIF', - 'ALIFBellec2020', - 'Izhikevich', - 'HindmarshRose', - 'FHN', + 'LeakyIntegrator', + 'LIF', + 'ExpIF', + 'AdExIF', + 'QuaIF', + 'AdQuaIF', + 'GIF', + 'ALIFBellec2020', + 'Izhikevich', + 'HindmarshRose', + 'FHN', ] class LeakyIntegrator(NeuDyn): - r"""Leaky Integrator Model. - - **Model Descriptions** - - This class implements a leaky integrator model, in which its dynamics is - given by: - - .. math:: - - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) - - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`\tau` is the time constant, and :math:`R` is the - resistance. - - Parameters:: - - size: sequence of int, int - The size of the neuron group. - V_rest: float, ArrayType, Initializer, callable - Resting membrane potential. - R: float, ArrayType, Initializer, callable - Membrane resistance. - tau: float, ArrayType, Initializer, callable - Membrane time constant. - V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - noise: ArrayType, Initializer, callable - The noise added onto the membrane potential - method: str - The numerical integration method. 
- name: str - The group name. - """ - - def __init__( - self, - - # neuron group size - size: Shape, - keep_size: bool = False, - - # neuron parameters - V_rest: Union[float, ArrayType, Initializer, Callable] = 0., - R: Union[float, ArrayType, Initializer, Callable] = 1., - tau: Union[float, ArrayType, Initializer, Callable] = 10., - V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), - noise: Union[float, ArrayType, Initializer, Callable] = None, - - # other parameter - input_var: bool = True, - name: str = None, - mode: bm.Mode = None, - method: str = 'exp_auto', - ): - super().__init__(size=size, - mode=mode, - keep_size=keep_size, - name=name) - is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) - - # parameters - self.V_rest = parameter(V_rest, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.R = parameter(R, self.varshape, allow_none=False) - self.noise = init_noise(noise, self.varshape) - self.input_var = input_var - - # initializers - self._V_initializer = is_initializer(V_initializer) - - # integral - if self.noise is None: - self.integral = odeint(method=method, f=self.derivative) - else: - self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - - # variables - self.reset_state(self.mode) - - def derivative(self, V, t, I_ext): - return (-V + self.V_rest + self.R * I_ext) / self.tau - - def reset_state(self, batch_size=None): - self.V = variable_(self._V_initializer, self.varshape, batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - self.V.value = self.integral(self.V.value, t, x, dt) - return self.V.value - - def clear_input(self): - if self.input_var: - self.input[:] = 0. + r"""Leaky Integrator Model. 
+ + **Model Descriptions** + + This class implements a leaky integrator model, in which its dynamics is + given by: + + .. math:: + + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) + + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`\tau` is the time constant, and :math:`R` is the + resistance. + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + V_rest: float, ArrayType, Initializer, callable + Resting membrane potential. + R: float, ArrayType, Initializer, callable + Membrane resistance. + tau: float, ArrayType, Initializer, callable + Membrane time constant. + V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + noise: ArrayType, Initializer, callable + The noise added onto the membrane potential + method: str + The numerical integration method. + name: str + The group name. + """ + + def __init__( + self, + + # neuron group size + size: Shape, + keep_size: bool = False, + + # neuron parameters + V_rest: Union[float, ArrayType, Initializer, Callable] = 0., + R: Union[float, ArrayType, Initializer, Callable] = 1., + tau: Union[float, ArrayType, Initializer, Callable] = 10., + V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + noise: Union[float, ArrayType, Initializer, Callable] = None, + + # other parameter + input_var: bool = True, + name: str = None, + mode: bm.Mode = None, + method: str = 'exp_auto', + ): + super().__init__(size=size, + mode=mode, + keep_size=keep_size, + name=name) + is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) + + # parameters + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape) + self.input_var = input_var + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + if 
self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + + # variables + self.reset_state(self.mode) + + def derivative(self, V, t, I_ext): + return (-V + self.V_rest + self.R * I_ext) / self.tau + + def reset_state(self, batch_size=None): + self.V = variable_(self._V_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + self.V.value = self.integral(self.V.value, t, x, dt) + return self.V.value + + def clear_input(self): + if self.input_var: + self.input[:] = 0. class LIF(lif.LifRef): - r"""Leaky integrate-and-fire neuron model. - - **Model Descriptions** - - The formal equations of a LIF model [1]_ is given by: - - .. math:: - - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ - \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad - \text{last} \quad \tau_{ref} \quad \text{ms} - - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`V_{reset}` is the reset membrane potential, - :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, - :math:`\tau_{ref}` is the refractory time period, - and :math:`I` is the time-variant synaptic inputs. - - **Model Examples** - - - `(Brette, Romain. 2004) LIF phase locking `_ - - - Parameters:: - - size: sequence of int, int - The size of the neuron group. - V_rest: float, ArrayType, Initializer, callable - Resting membrane potential. - V_reset: float, ArrayType, Initializer, callable - Reset potential after spike. - V_th: float, ArrayType, Initializer, callable - Threshold potential of spike. - R: float, ArrayType, Initializer, callable - Membrane resistance. 
- tau: float, ArrayType, Initializer, callable - Membrane time constant. - tau_ref: float, ArrayType, Initializer, callable - Refractory period length.(ms) - V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - noise: ArrayType, Initializer, callable - The noise added onto the membrane potential - method: str - The numerical integration method. - name: str - The group name. - - References:: - - .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model - neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. - """ - - def __init__( - self, - *args, - input_var: bool = True, - spike_fun: Callable = None, - **kwargs, - ): - self.input_var = input_var - if spike_fun is not None: - kwargs['spk_fun'] = spike_fun - super().__init__(*args, **kwargs, init_var=False) - - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - super().reset_state(batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - return super().update(x) - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) + r"""Leaky integrate-and-fire neuron model. + + **Model Descriptions** + + The formal equations of a LIF model [1]_ is given by: + + .. math:: + + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad + \text{last} \quad \tau_{ref} \quad \text{ms} + + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + :math:`\tau_{ref}` is the refractory time period, + and :math:`I` is the time-variant synaptic inputs. + + **Model Examples** + + - `(Brette, Romain. 
2004) LIF phase locking `_ + + + Parameters:: + + size: sequence of int, int + The size of the neuron group. + V_rest: float, ArrayType, Initializer, callable + Resting membrane potential. + V_reset: float, ArrayType, Initializer, callable + Reset potential after spike. + V_th: float, ArrayType, Initializer, callable + Threshold potential of spike. + R: float, ArrayType, Initializer, callable + Membrane resistance. + tau: float, ArrayType, Initializer, callable + Membrane time constant. + tau_ref: float, ArrayType, Initializer, callable + Refractory period length.(ms) + V_initializer: ArrayType, Initializer, callable + The initializer of membrane potential. + noise: ArrayType, Initializer, callable + The noise added onto the membrane potential + method: str + The numerical integration method. + name: str + The group name. + + References:: + + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. + """ + + def __init__( + self, + *args, + input_var: bool = True, + spike_fun: Callable = None, + **kwargs, + ): + self.input_var = input_var + if spike_fun is not None: + kwargs['spk_fun'] = spike_fun + super().__init__(*args, **kwargs, init_var=False) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + return super().update(x) + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) class ExpIF(lif.ExpIFRef): - r"""Exponential integrate-and-fire neuron model. - - **Model Descriptions** - - In the exponential integrate-and-fire model [1]_, the differential - equation for the membrane potential is given by - - .. 
math:: - - \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ - \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} - - This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` - and "threshold" :math:`\vartheta_{rh}`. - - The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` - defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to - :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, - where :math:`\tau_{\rm ref}` is an absolute refractory time. - If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, - its exact value does not play any role. The reason is that the upswing of the action - potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in - an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical - convenience. For a formal mathematical analysis of the model, the threshold can be pushed - to infinity. - - The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk - and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. - It is one of the prominent examples of a precise theoretical prediction in computational - neuroscience that was later confirmed by experimental neuroscience. - - Two important remarks: - - - (i) The right-hand side of the above equation contains a nonlinearity - that can be directly extracted from experimental data [3]_. In this sense the exponential - nonlinearity is not an arbitrary choice but directly supported by experimental evidence. - - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing - rate for constant input, and the linear response to fluctuations, even in the presence - of input noise [4]_. 
- - **Model Examples** - - >>> import brainpy as bp - >>> group = bp.neurons.ExpIF(1) - >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 10.)) - >>> runner.run(300., ) - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) - - - **Model Parameters** - - ============= ============== ======== =================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- --------------------------------------------------- - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - R 1 \ Membrane resistance. - tau 10 \ Membrane time constant. Compute by R * C. - tau_ref 1.7 \ Refractory period length. - ============= ============== ======== =================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - **References** - - .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). - Neuronal dynamics: From single neurons to networks and models - of cognition. Cambridge University Press. - .. 
[3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, - Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves - are reliable predictors of naturalistic pyramidal-neuron voltage - traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. - .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear - integrate-and-fire neurons to modulated current-based and - conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. - .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire - """ - - def __init__( - self, - *args, - input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, - spike_fun: Callable = None, - **kwargs, - ): - self.input_var = input_var - if spike_fun is not None: - kwargs['spk_fun'] = spike_fun - super().__init__(*args, **kwargs, init_var=False) - - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - super().reset_state(batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - return super().update(x) - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) + r"""Exponential integrate-and-fire neuron model. + + **Model Descriptions** + + In the exponential integrate-and-fire model [1]_, the differential + equation for the membrane potential is given by + + .. math:: + + \tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ + \text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} + + This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` + and "threshold" :math:`\vartheta_{rh}`. 
+ + The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` + defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to + :math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, + where :math:`\tau_{\rm ref}` is an absolute refractory time. + If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, + its exact value does not play any role. The reason is that the upswing of the action + potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in + an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical + convenience. For a formal mathematical analysis of the model, the threshold can be pushed + to infinity. + + The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk + and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. + It is one of the prominent examples of a precise theoretical prediction in computational + neuroscience that was later confirmed by experimental neuroscience. + + Two important remarks: + + - (i) The right-hand side of the above equation contains a nonlinearity + that can be directly extracted from experimental data [3]_. In this sense the exponential + nonlinearity is not an arbitrary choice but directly supported by experimental evidence. + - (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing + rate for constant input, and the linear response to fluctuations, even in the presence + of input noise [4]_. 
+ + **Model Examples** + + >>> import brainpy as bp + >>> group = bp.neurons.ExpIF(1) + >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 10.)) + >>> runner.run(300., ) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) + + + **Model Parameters** + + ============= ============== ======== =================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- --------------------------------------------------- + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. + R 1 \ Membrane resistance. + tau 10 \ Membrane time constant. Compute by R * C. + tau_ref 1.7 \ Refractory period length. + ============= ============== ======== =================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + **References** + + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). + Neuronal dynamics: From single neurons to networks and models + of cognition. Cambridge University Press. + .. 
[3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, + Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves + are reliable predictors of naturalistic pyramidal-neuron voltage + traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. + .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear + integrate-and-fire neurons to modulated current-based and + conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. + .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire + """ + + def __init__( + self, + *args, + input_var: bool = True, + noise: Union[float, ArrayType, Initializer, Callable] = None, + spike_fun: Callable = None, + **kwargs, + ): + self.input_var = input_var + if spike_fun is not None: + kwargs['spk_fun'] = spike_fun + super().__init__(*args, **kwargs, init_var=False) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + return super().update(x) + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) class AdExIF(lif.AdExIFRef): - r"""Adaptive exponential integrate-and-fire neuron model. - - **Model Descriptions** - - The **adaptive exponential integrate-and-fire model**, also called AdEx, is a - spiking neuron model with two variables [1]_ [2]_. - - .. math:: - - \begin{aligned} - \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ - \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w - \end{aligned} - - once the membrane potential reaches the spike threshold, - - .. math:: - - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. 
- - The first equation describes the dynamics of the membrane potential and includes - an activation term with an exponential voltage dependence. Voltage is coupled to - a second equation which describes adaptation. Both variables are reset if an action - potential has been triggered. The combination of adaptation and exponential voltage - dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. - - The adaptive exponential integrate-and-fire model is capable of describing known - neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, - initial bursting, fast spiking, and regular spiking. - - **Model Examples** - - - `Examples for different firing patterns `_ - - **Model Parameters** - - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_T -59.9 mV Threshold potential of generating action potential. - delta_T 3.48 \ Spike slope factor. - a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` - b 1 \ The increment of :math:`w` produced by a spike. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_w 30 ms Time constant of the adaptation current. - tau_ref 0. ms Refractory time. 
- ============= ============== ======== ======================================================================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - w 0 Adaptation current. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - **References** - - .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation - mechanisms determine the neuronal response to fluctuating - inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. - .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model - """ - - def __init__( - self, - *args, - input_var: bool = True, - spike_fun: Callable = None, - **kwargs, - ): - self.input_var = input_var - if spike_fun is not None: - kwargs['spk_fun'] = spike_fun - super().__init__(*args, **kwargs, init_var=False) - - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - super().reset_state(batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - return super().update(x) - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) + r"""Adaptive exponential integrate-and-fire neuron model. 
+ + **Model Descriptions** + + The **adaptive exponential integrate-and-fire model**, also called AdEx, is a + spiking neuron model with two variables [1]_ [2]_. + + .. math:: + + \begin{aligned} + \tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ + \tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w + \end{aligned} + + once the membrane potential reaches the spike threshold, + + .. math:: + + V \rightarrow V_{reset}, \\ + w \rightarrow w+b. + + The first equation describes the dynamics of the membrane potential and includes + an activation term with an exponential voltage dependence. Voltage is coupled to + a second equation which describes adaptation. Both variables are reset if an action + potential has been triggered. The combination of adaptation and exponential voltage + dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. + + The adaptive exponential integrate-and-fire model is capable of describing known + neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, + initial bursting, fast spiking, and regular spiking. + + **Model Examples** + + - `Examples for different firing patterns `_ + + **Model Parameters** + + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_T -59.9 mV Threshold potential of generating action potential. + delta_T 3.48 \ Spike slope factor. 
+ a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` + b 1 \ The increment of :math:`w` produced by a spike. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_w 30 ms Time constant of the adaptation current. + tau_ref 0. ms Refractory time. + ============= ============== ======== ======================================================================================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + w 0 Adaptation current. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + **References** + + .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation + mechanisms determine the neuronal response to fluctuating + inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. + .. 
[2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model + """ + + def __init__( + self, + *args, + input_var: bool = True, + spike_fun: Callable = None, + **kwargs, + ): + self.input_var = input_var + if spike_fun is not None: + kwargs['spk_fun'] = spike_fun + super().__init__(*args, **kwargs, init_var=False) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + return super().update(x) + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) class QuaIF(lif.QuaIFRef): - r"""Quadratic Integrate-and-Fire neuron model. - - **Model Descriptions** - - In contrast to physiologically accurate but computationally expensive - neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only - to produce **action potential-like patterns** and ignores subtleties - like gating variables, which play an important role in generating action - potentials in a real neuron. However, the QIF model is incredibly easy - to implement and compute, and relatively straightforward to study and - understand, thus has found ubiquitous use in computational neuroscience. - - .. math:: - - \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) - - where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). - - **Model Examples** - - >>> import brainpy as bp - >>> - >>> group = bp.neurons.QuaIF(1,) - >>> - >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 20.)) - >>> runner.run(duration=200.) 
- >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) - - - **Model Parameters** - - ============= ============== ======== ======================================================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ - V_rest -65 mV Resting potential. - V_reset -68 mV Reset potential after spike. - V_th -30 mV Threshold potential of spike and reset. - V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. - c .07 \ Coefficient describes membrane potential update. Larger than 0. - R 1 \ Membrane resistance. - tau 10 ms Membrane time constant. Compute by R * C. - tau_ref 0 ms Refractory period length. - ============= ============== ======== ======================================================================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - refractory False Flag to mark whether the neuron is in refractory period. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - **References** - - .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg - (2000) Intrinsic dynamics in neuronal networks. I. Theory. - J. Neurophysiology 83, pp. 808–827. 
- """ - - def __init__( - self, - *args, - input_var: bool = True, - spike_fun: Callable = None, - **kwargs, - ): - self.input_var = input_var - if spike_fun is not None: - kwargs['spk_fun'] = spike_fun - super().__init__(*args, **kwargs, init_var=False) - self.reset_state(self.mode) - - def reset_state(self, batch_size=None): - super().reset_state(batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - return super().update(x) - - def clear_input(self): - if self.input_var: - self.input.value = bm.zeros_like(self.input) + r"""Quadratic Integrate-and-Fire neuron model. + + **Model Descriptions** + + In contrast to physiologically accurate but computationally expensive + neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only + to produce **action potential-like patterns** and ignores subtleties + like gating variables, which play an important role in generating action + potentials in a real neuron. However, the QIF model is incredibly easy + to implement and compute, and relatively straightforward to study and + understand, thus has found ubiquitous use in computational neuroscience. + + .. math:: + + \tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) + + where the parameters are taken to be :math:`c` =0.07, and :math:`V_c = -50 mV` (Latham et al., 2000). + + **Model Examples** + + >>> import brainpy as bp + >>> + >>> group = bp.neurons.QuaIF(1,) + >>> + >>> runner = bp.DSRunner(group, monitors=['V'], inputs=('input', 20.)) + >>> runner.run(duration=200.) 
+ >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) + + + **Model Parameters** + + ============= ============== ======== ======================================================================================================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ + V_rest -65 mV Resting potential. + V_reset -68 mV Reset potential after spike. + V_th -30 mV Threshold potential of spike and reset. + V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. + c .07 \ Coefficient describes membrane potential update. Larger than 0. + R 1 \ Membrane resistance. + tau 10 ms Membrane time constant. Compute by R * C. + tau_ref 0 ms Refractory period length. + ============= ============== ======== ======================================================================================================================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + refractory False Flag to mark whether the neuron is in refractory period. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + **References** + + .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg + (2000) Intrinsic dynamics in neuronal networks. I. Theory. + J. Neurophysiology 83, pp. 808–827. 
+ """ + + def __init__( + self, + *args, + input_var: bool = True, + spike_fun: Callable = None, + **kwargs, + ): + self.input_var = input_var + if spike_fun is not None: + kwargs['spk_fun'] = spike_fun + super().__init__(*args, **kwargs, init_var=False) + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + return super().update(x) + + def clear_input(self): + if self.input_var: + self.input.value = bm.zeros_like(self.input) class AdQuaIF(lif.AdQuaIFRef): - r"""Adaptive quadratic integrate-and-fire neuron model. - - **Model Descriptions** - - The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: - - .. math:: - - \begin{aligned} - \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ - \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, - \end{aligned} - - once the membrane potential reaches the spike threshold, - - .. math:: - - V \rightarrow V_{reset}, \\ - w \rightarrow w+b. - - **Model Examples** - - >>> import brainpy as bp - >>> group = bp.neurons.AdQuaIF(1, ) - >>> runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 30.)) - >>> runner.run(300) - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V') - >>> fig.add_subplot(gs[1, 0]) - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True) - - **Model Parameters** - - ============= ============== ======== ======================================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------------------- - V_rest -65 mV Resting potential. 
class AdQuaIF(lif.AdQuaIFRef):
    r"""Adaptive quadratic integrate-and-fire neuron model.

    Thin wrapper around ``lif.AdQuaIFRef``. It optionally owns an ``input``
    variable that accumulates the current handed to :meth:`update`, and it
    accepts the keyword ``spike_fun``, which is forwarded to the parent
    constructor under the name ``spk_fun``.

    **Model Descriptions**

    The adaptive quadratic integrate-and-fire neuron model [1]_ is given by:

    .. math::

        \begin{aligned}
        \tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\
        \tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w,
        \end{aligned}

    once the membrane potential reaches the spike threshold,

    .. math::

        V \rightarrow V_{reset}, \\
        w \rightarrow w+b.

    **Model Examples**

    >>> import brainpy as bp
    >>> group = bp.neurons.AdQuaIF(1, )
    >>> runner = bp.DSRunner(group, monitors=['V', 'w'], inputs=('input', 30.))
    >>> runner.run(300)
    >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
    >>> fig.add_subplot(gs[0, 0])
    >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V')
    >>> fig.add_subplot(gs[1, 0])
    >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)

    **Model Parameters**

    ============= ============== ======== =======================================================
    **Parameter** **Init Value** **Unit** **Explanation**
    ------------- -------------- -------- -------------------------------------------------------
    V_rest        -65            mV       Resting potential.
    V_reset       -68            mV       Reset potential after spike.
    V_th          -30            mV       Threshold potential of spike and reset.
    V_c           -50            mV       Critical voltage for spike initiation. Must be larger
                                          than :math:`V_{rest}`.
    a             1              \        The sensitivity of the adaptation variable :math:`w`
                                          to the sub-threshold fluctuations of the membrane
                                          potential :math:`V`.
    b             .1             \        The increment of :math:`w` produced by a spike.
    c             .07            \        Coefficient describes membrane potential update.
                                          Larger than 0.
    tau           10             ms       Membrane time constant.
    tau_w         10             ms       Time constant of the adaptation current.
    ============= ============== ======== =======================================================

    **Model Variables**

    ================== ================= ==========================================================
    **Variables name** **Initial Value** **Explanation**
    ------------------ ----------------- ----------------------------------------------------------
    V                  0                 Membrane potential.
    w                  0                 Adaptation current.
    input              0                 External and synaptic input current
                                         (only created when ``input_var=True``).
    spike              False             Flag to mark whether the neuron is spiking.
    t_last_spike       -1e7              Last spike time stamp.
    ================== ================= ==========================================================

    **References**

    .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking
           neurons?. IEEE transactions on neural networks, 15(5), 1063-1070.
    .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of
           nonlinear integrate-and-fire neurons." SIAM Journal on Applied
           Mathematics 68, no. 4 (2008): 1045-1079.
    """

    def __init__(
        self,
        *args,
        input_var: bool = True,
        spike_fun: Callable = None,
        **kwargs,
    ):
        """Build the neuron group.

        Args:
          *args: Positional arguments forwarded to the parent constructor.
          input_var: When true, an ``input`` variable is created and used to
            accumulate the current given to :meth:`update`.
          spike_fun: Optional callable; when given it is forwarded to the
            parent as the keyword ``spk_fun``.
          **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        # ``reset_state`` (called below) reads this flag, so set it first.
        self.input_var = input_var
        if spike_fun is not None:
            # Translate this wrapper's keyword into the parent's keyword.
            kwargs['spk_fun'] = spike_fun
        # NOTE(review): presumably ``init_var=False`` keeps the parent from
        # creating its variables so the call below does it once - confirm.
        super().__init__(*args, init_var=False, **kwargs)
        # The mode object is passed in the ``batch_size`` position.
        self.reset_state(self.mode)

    def reset_state(self, batch_size=None):
        """Reset the parent's state and, if enabled, re-create ``input``."""
        super().reset_state(batch_size)
        if not self.input_var:
            return
        self.input = variable_(bm.zeros, self.varshape, batch_size)

    def update(self, x=None):
        """Advance the model by one step.

        With ``input_var`` enabled, ``x`` (if not ``None``) is added to
        ``self.input`` and the accumulated value drives the parent update.
        Otherwise ``x`` is used directly, with ``None`` treated as ``0.``.

        Returns:
          Whatever the parent's ``update`` returns.
        """
        if not self.input_var:
            return super().update(0. if x is None else x)
        if x is not None:
            self.input += x
        return super().update(self.input.value)

    def clear_input(self):
        """Zero the accumulated input; a no-op when ``input_var`` is false."""
        if not self.input_var:
            return
        self.input.value = bm.zeros_like(self.input)
class GIF(lif.GifRef):
    r"""Generalized Integrate-and-Fire model.

    Thin wrapper around ``lif.GifRef``. It optionally owns an ``input``
    variable that accumulates the current handed to :meth:`update`, and it
    accepts the keyword ``spike_fun``, which is forwarded to the parent
    constructor under the name ``spk_fun``.

    **Model Descriptions**

    The generalized integrate-and-fire model [1]_ is given by

    .. math::

        &\frac{d I_j}{d t} = - k_j I_j

        &\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau

        &\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty})

    When :math:`V` meets :math:`V_{th}`, the generalized IF neuron fires:

    .. math::

        &I_j \leftarrow R_j I_j + A_j

        &V \leftarrow V_{reset}

        &V_{th} \leftarrow max(V_{th_{reset}}, V_{th})

    Note that :math:`I_j` refers to an arbitrary number of internal currents.

    **Model Examples**

    - Detailed examples to reproduce different firing patterns
      (link target missing from this docstring; TODO restore the URL).

    **Model Parameters**

    ============= ============== ======== ====================================================================
    **Parameter** **Init Value** **Unit** **Explanation**
    ------------- -------------- -------- --------------------------------------------------------------------
    V_rest        -70            mV       Resting potential.
    V_reset       -70            mV       Reset potential after spike.
    V_th_inf      -50            mV       Target value of threshold potential :math:`V_{th}` updating.
    V_th_reset    -60            mV       Free parameter, should be larger than :math:`V_{reset}`.
    R             20             \        Membrane resistance.
    tau           20             ms       Membrane time constant. Compute by :math:`R * C`.
    a             0              \        Coefficient describes the dependence of
                                          :math:`V_{th}` on membrane potential.
    b             0.01           \        Coefficient describes :math:`V_{th}` update.
    k1            0.2            \        Constant of :math:`I1`.
    k2            0.02           \        Constant of :math:`I2`.
    R1            0              \        Free parameter.
                                          Describes dependence of :math:`I_1` reset value on
                                          :math:`I_1` value before spiking.
    R2            1              \        Free parameter.
                                          Describes dependence of :math:`I_2` reset value on
                                          :math:`I_2` value before spiking.
    A1            0              \        Free parameter.
    A2            0              \        Free parameter.
    ============= ============== ======== ====================================================================

    **Model Variables**

    ================== ================= =========================================================
    **Variables name** **Initial Value** **Explanation**
    ------------------ ----------------- ---------------------------------------------------------
    V                  -70               Membrane potential.
    input              0                 External and synaptic input current
                                         (only created when ``input_var=True``).
    spike              False             Flag to mark whether the neuron is spiking.
    V_th               -50               Spiking threshold potential.
    I1                 0                 Internal current 1.
    I2                 0                 Internal current 2.
    t_last_spike       -1e7              Last spike time stamp.
    ================== ================= =========================================================

    **References**

    .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear
           integrate-and-fire neural model produces diverse spiking
           behaviors." Neural computation 21.3 (2009): 704-718.
    .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan
           Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized
           leaky integrate-and-fire models classify multiple neuron types."
           Nature communications 9, no. 1 (2018): 1-15.
    """

    def __init__(
        self,
        *args,
        input_var: bool = True,
        spike_fun: Callable = None,
        **kwargs,
    ):
        """Build the neuron group.

        Args:
          *args: Positional arguments forwarded to the parent constructor.
          input_var: When true, an ``input`` variable is created and used to
            accumulate the current given to :meth:`update`.
          spike_fun: Optional callable; when given it is forwarded to the
            parent as the keyword ``spk_fun``.
          **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        # ``reset_state`` (called below) reads this flag, so set it first.
        self.input_var = input_var
        if spike_fun is not None:
            # Translate this wrapper's keyword into the parent's keyword.
            kwargs['spk_fun'] = spike_fun
        # NOTE(review): presumably ``init_var=False`` keeps the parent from
        # creating its variables so the call below does it once - confirm.
        super().__init__(*args, init_var=False, **kwargs)
        # The mode object is passed in the ``batch_size`` position.
        self.reset_state(self.mode)

    def reset_state(self, batch_size=None):
        """Reset the parent's state and, if enabled, re-create ``input``."""
        super().reset_state(batch_size)
        if not self.input_var:
            return
        self.input = variable_(bm.zeros, self.varshape, batch_size)

    def update(self, x=None):
        """Advance the model by one step.

        With ``input_var`` enabled, ``x`` (if not ``None``) is added to
        ``self.input`` and the accumulated value drives the parent update.
        Otherwise ``x`` is used directly, with ``None`` treated as ``0.``.

        Returns:
          Whatever the parent's ``update`` returns.
        """
        if not self.input_var:
            return super().update(0. if x is None else x)
        if x is not None:
            self.input += x
        return super().update(self.input.value)

    def clear_input(self):
        """Zero the accumulated input; a no-op when ``input_var`` is false."""
        if not self.input_var:
            return
        self.input.value = bm.zeros_like(self.input)
class Izhikevich(lif.IzhikevichRef):
    r"""The Izhikevich neuron model.

    Thin wrapper around ``lif.IzhikevichRef``. It optionally owns an
    ``input`` variable that accumulates the current handed to
    :meth:`update`, and it accepts the keyword ``spike_fun``, which is
    forwarded to the parent constructor under the name ``spk_fun``.

    **Model Descriptions**

    The dynamics of the Izhikevich neuron model [1]_ [2]_ is given by:

    .. math::

        \frac{d V}{d t} &= 0.04 V^{2}+5 V+140-u+I

        \frac{d u}{d t} &=a(b V-u)

    .. math::

        \text{if } V \geq 30 \text{ mV}, \text{ then }
        \begin{cases} V \leftarrow c \\
        u \leftarrow u+d \end{cases}

    **Model Examples**

    - Detailed examples to reproduce different firing patterns
      (link target missing from this docstring; TODO restore the URL).

    **Model Parameters**

    ============= ============== ======== ================================================================================
    **Parameter** **Init Value** **Unit** **Explanation**
    ------------- -------------- -------- --------------------------------------------------------------------------------
    a             0.02           \        It determines the time scale of
                                          the recovery variable :math:`u`.
    b             0.2            \        It describes the sensitivity of the
                                          recovery variable :math:`u` to
                                          the sub-threshold fluctuations of the
                                          membrane potential :math:`V`.
    c             -65            \        It describes the after-spike reset value
                                          of the membrane potential :math:`V` caused by
                                          the fast high-threshold :math:`K^{+}`
                                          conductance.
    d             8              \        It describes after-spike reset of the
                                          recovery variable :math:`u`
                                          caused by slow high-threshold
                                          :math:`Na^{+}` and :math:`K^{+}` conductance.
    tau_ref       0              ms       Refractory period length.
    V_th          30             mV       The membrane potential threshold.
    ============= ============== ======== ================================================================================

    **Model Variables**

    ================== ================= =========================================================
    **Variables name** **Initial Value** **Explanation**
    ------------------ ----------------- ---------------------------------------------------------
    V                  -65               Membrane potential.
    u                  1                 Recovery variable.
    input              0                 External and synaptic input current
                                         (only created when ``input_var=True``).
    spike              False             Flag to mark whether the neuron is spiking.
    refractory         False             Flag to mark whether the neuron is in refractory period.
    t_last_spike       -1e7              Last spike time stamp.
    ================== ================= =========================================================

    **References**

    .. [1] Izhikevich, Eugene M. "Simple model of spiking neurons." IEEE
           Transactions on neural networks 14.6 (2003): 1569-1572.

    .. [2] Izhikevich, Eugene M. "Which model to use for cortical spiking neurons?."
           IEEE transactions on neural networks 15.5 (2004): 1063-1070.
    """

    def __init__(
        self,
        *args,
        input_var: bool = True,
        spike_fun: Callable = None,
        **kwargs,
    ):
        """Build the neuron group.

        Args:
          *args: Positional arguments forwarded to the parent constructor.
          input_var: When true, an ``input`` variable is created and used to
            accumulate the current given to :meth:`update`.
          spike_fun: Optional callable; when given it is forwarded to the
            parent as the keyword ``spk_fun``.
          **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        # ``reset_state`` (called below) reads this flag, so set it first.
        self.input_var = input_var
        if spike_fun is not None:
            # Translate this wrapper's keyword into the parent's keyword.
            kwargs['spk_fun'] = spike_fun
        # NOTE(review): presumably ``init_var=False`` keeps the parent from
        # creating its variables so the call below does it once - confirm.
        super().__init__(*args, init_var=False, **kwargs)
        # The mode object is passed in the ``batch_size`` position.
        self.reset_state(self.mode)

    def reset_state(self, batch_size=None, **kwargs):
        """Reset the parent's state and, if enabled, re-create ``input``.

        Extra keyword arguments are accepted but not used; only
        ``batch_size`` is passed on to the parent.
        """
        super().reset_state(batch_size)
        if not self.input_var:
            return
        self.input = variable_(bm.zeros, self.varshape, batch_size)

    def update(self, x=None):
        """Advance the model by one step.

        With ``input_var`` enabled, ``x`` (if not ``None``) is added to
        ``self.input`` and the accumulated value drives the parent update.
        Otherwise ``x`` is used directly, with ``None`` treated as ``0.``.

        Returns:
          Whatever the parent's ``update`` returns.
        """
        if not self.input_var:
            return super().update(0. if x is None else x)
        if x is not None:
            self.input += x
        return super().update(self.input.value)

    def clear_input(self):
        """Zero the accumulated input; a no-op when ``input_var`` is false."""
        if not self.input_var:
            return
        self.input.value = bm.zeros_like(self.input)
class HindmarshRose(NeuDyn):
    r"""Hindmarsh-Rose neuron model.

    **Model Descriptions**

    The Hindmarsh–Rose model [1]_ [2]_ of neuronal activity is aimed to study the
    spiking-bursting behavior of the membrane potential observed in experiments
    made with a single neuron.

    The model has the mathematical form of a system of three nonlinear ordinary
    differential equations on the dimensionless dynamical variables :math:`V(t)`,
    :math:`y(t)`, and :math:`z(t)`. They read:

    .. math::

        \begin{aligned}
        \frac{d V}{d t} &= y - a V^3 + b V^2 - z + I \\
        \frac{d y}{d t} &= c - d V^2 - y \\
        \frac{d z}{d t} &= r (s (V - V_{rest}) - z)
        \end{aligned}

    where :math:`a, b, c, d` model the working of the fast ion channels,
    :math:`r` sets the speed of the slow variable :math:`z`, and :math:`I`
    is the external input current.

    **Model Examples**

    >>> import brainpy.math as bm
    >>> import brainpy as bp
    >>> import matplotlib.pyplot as plt
    >>>
    >>> bp.math.set_dt(dt=0.01)
    >>> bp.ode.set_default_odeint('rk4')
    >>>
    >>> types = ['quiescence', 'spiking', 'bursting', 'irregular_spiking', 'irregular_bursting']
    >>> bs = bm.array([1.0, 3.5, 2.5, 2.95, 2.8])
    >>> Is = bm.array([2.0, 5.0, 3.0, 3.3, 3.7])
    >>>
    >>> # define neuron type
    >>> group = bp.neurons.HindmarshRose(len(types), b=bs)
    >>> runner = bp.DSRunner(group, monitors=['V'], inputs=['input', Is],)
    >>> runner.run(1e3)
    >>>
    >>> fig, gs = bp.visualize.get_figure(row_num=3, col_num=2, row_len=3, col_len=5)
    >>> for i, mode in enumerate(types):
    ...     fig.add_subplot(gs[i // 2, i % 2])
    ...     plt.plot(runner.mon.ts, runner.mon.V[:, i])
    ...     plt.title(mode)
    ...     plt.xlabel('Time [ms]')
    >>> plt.show()

    **Model Parameters**

    ============= ============== ========= ============================================================
    **Parameter** **Init Value** **Unit**  **Explanation**
    ------------- -------------- --------- ------------------------------------------------------------
    a             1              \         Model parameter.
                                           Fixed to a value best fit neuron activity.
    b             3              \         Model parameter.
                                           Allows the model to switch between bursting
                                           and spiking, controls the spiking frequency.
    c             1              \         Model parameter.
                                           Fixed to a value best fit neuron activity.
    d             5              \         Model parameter.
                                           Fixed to a value best fit neuron activity.
    r             0.01           \         Model parameter.
                                           Controls slow variable z's variation speed.
                                           Governs spiking frequency when spiking, and affects the
                                           number of spikes per burst when bursting.
    s             4              \         Model parameter. Governs adaptation.
    V_rest        -1.6           \         Reference potential in the :math:`z` equation.
    V_th          1.0            \         Threshold of :math:`V` used to detect a spike.
    ============= ============== ========= ============================================================

    **Model Variables**

    =============== ================= =====================================
    **Member name** **Initial Value** **Explanation**
    --------------- ----------------- -------------------------------------
    V               0                 Membrane potential.
    y               -10               Gating variable.
    z               0                 Gating variable.
    spike           False             Whether generate the spikes.
    input           0                 External and synaptic input current
                                      (only created when ``input_var=True``).
    =============== ================= =====================================

    **References**

    .. [1] Hindmarsh, James L., and R. M. Rose. "A model of neuronal bursting using
           three coupled first order differential equations." Proceedings of the
           Royal society of London. Series B. Biological sciences 221.1222 (1984):
           87-102.
    .. [2] Storace, Marco, Daniele Linaro, and Enno de Lange. "The Hindmarsh–Rose
           neuron model: bifurcation analysis and piecewise-linear approximations."
           Chaos: An Interdisciplinary Journal of Nonlinear Science 18.3 (2008):
           033128.
    """

    def __init__(
        self,
        size: Shape,
        a: Union[float, ArrayType, Initializer, Callable] = 1.,
        b: Union[float, ArrayType, Initializer, Callable] = 3.,
        c: Union[float, ArrayType, Initializer, Callable] = 1.,
        d: Union[float, ArrayType, Initializer, Callable] = 5.,
        r: Union[float, ArrayType, Initializer, Callable] = 0.01,
        s: Union[float, ArrayType, Initializer, Callable] = 4.,
        V_rest: Union[float, ArrayType, Initializer, Callable] = -1.6,
        V_th: Union[float, ArrayType, Initializer, Callable] = 1.0,
        V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
        y_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-10.),
        z_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(),
        noise: Union[float, ArrayType, Initializer, Callable] = None,
        method: str = 'exp_auto',
        keep_size: bool = False,
        input_var: bool = True,
        name: str = None,

        # parameters for training
        mode: bm.Mode = None,
        spike_fun: Callable = bm.surrogate.inv_square_grad2,
    ):
        """Build the neuron group.

        Args:
          size: Size of the neuron group.
          a, b, c, d, r, s: Coefficients of the model equations.
          V_rest: Reference potential used in the ``z`` equation.
          V_th: Threshold on ``V`` used for spike detection.
          V_initializer: Initializer of ``V``.
          y_initializer: Initializer of ``y``.
          z_initializer: Initializer of ``z``.
          noise: Optional noise term; when not ``None`` the model is
            integrated with ``sdeint`` instead of ``odeint``.
          method: Name of the numerical integration method.
          keep_size: Forwarded to the parent constructor.
          input_var: When true, an ``input`` variable is created and used to
            accumulate the current given to :meth:`update`.
          name: Forwarded to the parent constructor.
          mode: Computing mode, forwarded to the parent constructor.
          spike_fun: Callable used to produce the spike value in training mode.
        """
        # initialization
        super().__init__(size=size, keep_size=keep_size, name=name, mode=mode)
        is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))

        # parameters
        self.a = parameter(a, self.varshape, allow_none=False)
        self.b = parameter(b, self.varshape, allow_none=False)
        self.c = parameter(c, self.varshape, allow_none=False)
        self.d = parameter(d, self.varshape, allow_none=False)
        self.r = parameter(r, self.varshape, allow_none=False)
        self.s = parameter(s, self.varshape, allow_none=False)
        self.V_th = parameter(V_th, self.varshape, allow_none=False)
        self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
        self.noise = init_noise(noise, self.varshape, num_vars=3)
        self.spike_fun = is_callable(spike_fun, 'spike_fun')
        self.input_var = input_var

        # initializers, validated now and applied in ``reset_state``
        is_initializer(V_initializer, 'V_initializer', allow_none=False)
        is_initializer(y_initializer, 'y_initializer', allow_none=False)
        is_initializer(z_initializer, 'z_initializer', allow_none=False)
        self._V_initializer = V_initializer
        self._y_initializer = y_initializer
        self._z_initializer = z_initializer

        # variables (the mode object is passed in the ``batch_size`` position)
        self.reset_state(self.mode)

        # integral: deterministic solver without noise, stochastic otherwise
        if self.noise is None:
            self.integral = odeint(method=method, f=self.derivative)
        else:
            self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

    def reset_state(self, batch_size=None):
        """(Re-)create ``V``, ``y``, ``z``, ``spike`` and, if enabled, ``input``."""
        self.V = variable_(self._V_initializer, self.varshape, batch_size)
        self.y = variable_(self._y_initializer, self.varshape, batch_size)
        self.z = variable_(self._z_initializer, self.varshape, batch_size)
        if self.input_var:
            self.input = variable_(bm.zeros, self.varshape, batch_size)
        # Spikes are floats in training mode and booleans otherwise.
        sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool
        self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)

    def dV(self, V, t, y, z, I_ext):
        """Right-hand side of the ``V`` equation."""
        return y - self.a * V * V * V + self.b * V * V - z + I_ext

    def dy(self, y, t, V):
        """Right-hand side of the ``y`` equation."""
        return self.c - self.d * V * V - y

    def dz(self, z, t, V):
        """Right-hand side of the ``z`` equation."""
        return self.r * (self.s * (V - self.V_rest) - z)

    @property
    def derivative(self):
        """The three equations joined into one ``JointEq`` system."""
        return JointEq([self.dV, self.dy, self.dz])

    def update(self, x=None):
        """Advance the model by one step and return the spike value.

        With ``input_var`` enabled, ``x`` (if not ``None``) is added to
        ``self.input`` and the accumulated value is the driving current.
        Otherwise ``x`` is used directly, with ``None`` treated as ``0.``.

        Returns:
          The value of ``self.spike`` after the step.
        """
        t = share.load('t')
        dt = share.load('dt')
        if self.input_var:
            if x is not None:
                self.input += x
            x = self.input.value
        else:
            x = 0. if x is None else x
        V, y, z = self.integral(self.V.value, self.y.value, self.z.value, t, x, dt=dt)
        # ``self.V`` still holds the previous step here: a spike is an upward
        # crossing of ``V_th`` (new value at/above, previous value below).
        if isinstance(self.mode, bm.TrainingMode):
            self.spike.value = self.spike_fun(V - self.V_th, self.V - self.V_th)
        else:
            self.spike.value = bm.logical_and(V >= self.V_th, self.V < self.V_th)
        self.V.value = V
        self.y.value = y
        self.z.value = z
        return self.spike.value

    def clear_input(self):
        """Zero the accumulated input; a no-op when ``input_var`` is false."""
        if self.input_var:
            # Same reset form as the other neuron classes in this module.
            self.input.value = bm.zeros_like(self.input)
- - **Model Descriptions** - - The FitzHugh–Nagumo model (FHN), named after Richard FitzHugh (1922–2007) - who suggested the system in 1961 [1]_ and J. Nagumo et al. who created the - equivalent circuit the following year, describes a prototype of an excitable - system (e.g., a neuron). - - The motivation for the FitzHugh-Nagumo model was to isolate conceptually - the essentially mathematical properties of excitation and propagation from - the electrochemical properties of sodium and potassium ion flow. The model - consists of - - - a *voltage-like variable* having cubic nonlinearity that allows regenerative - self-excitation via a positive feedback, and - - a *recovery variable* having a linear dynamics that provides a slower negative feedback. - - .. math:: - - \begin{aligned} - {\dot {v}} &=v-{\frac {v^{3}}{3}}-w+RI_{\rm {ext}}, \\ - \tau {\dot {w}}&=v+a-bw. - \end{aligned} - - The FHN Model is an example of a relaxation oscillator - because, if the external stimulus :math:`I_{\text{ext}}` - exceeds a certain threshold value, the system will exhibit - a characteristic excursion in phase space, before the - variables :math:`v` and :math:`w` relax back to their rest values. - This behaviour is typical for spike generations (a short, - nonlinear elevation of membrane voltage :math:`v`, - diminished over time by a slower, linear recovery variable - :math:`w`) in a neuron after stimulation by an external - input current. - - **Model Examples** - - >>> import brainpy as bp - >>> fhn = bp.neurons.FHN(1) - >>> runner = bp.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w']) - >>> runner.run(100.) 
- >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w') - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) - - **Model Parameters** - - ============= ============== ======== ======================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------ - a 1 \ Positive constant - b 1 \ Positive constant - tau 10 ms Membrane time constant. - V_th 1.8 mV Threshold potential of spike. - ============= ============== ======== ======================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - V 0 Membrane potential. - w 0 A recovery variable which represents - the combined effects of sodium channel - de-inactivation and potassium channel - deactivation. - input 0 External and synaptic input current. - spike False Flag to mark whether the neuron is spiking. - t_last_spike -1e7 Last spike time stamp. - ================== ================= ========================================================= - - **References** - - .. [1] FitzHugh, Richard. "Impulses and physiological states in theoretical models of nerve membrane." Biophysical journal 1.6 (1961): 445-466. - .. [2] https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model - .. 
[3] http://www.scholarpedia.org/article/FitzHugh-Nagumo_model - - """ - - def __init__( - self, - size: Shape, - a: Union[float, ArrayType, Initializer, Callable] = 0.7, - b: Union[float, ArrayType, Initializer, Callable] = 0.8, - tau: Union[float, ArrayType, Initializer, Callable] = 12.5, - Vth: Union[float, ArrayType, Initializer, Callable] = 1.8, - V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), - w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), - noise: Union[float, ArrayType, Initializer, Callable] = None, - method: str = 'exp_auto', - keep_size: bool = False, - input_var: bool = True, - name: str = None, - - # parameters for training - mode: bm.Mode = None, - ): - # initialization - super(FHN, self).__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) - is_subclass(self.mode, (bm.NonBatchingMode,)) - - # parameters - self.a = parameter(a, self.varshape, allow_none=False) - self.b = parameter(b, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.Vth = parameter(Vth, self.varshape, allow_none=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - self.input_var = input_var - - # initializers - is_initializer(V_initializer, 'V_initializer') - is_initializer(w_initializer, 'w_initializer') - self._V_initializer = V_initializer - self._w_initializer = w_initializer - - # variables - self.reset_state(self.mode) - - # integral - if self.noise is None: - self.integral = odeint(method=method, f=self.derivative) - else: - self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - - def reset_state(self, batch_size=None): - self.V = variable_(self._V_initializer, self.varshape, batch_size) - self.w = variable_(self._w_initializer, self.varshape, batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - - def 
dV(self, V, t, w, I_ext): - return V - V * V * V / 3 - w + I_ext - - def dw(self, w, t, V): - return (V + self.a - self.b * w) / self.tau - - @property - def derivative(self): - return JointEq([self.dV, self.dw]) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - V, w = self.integral(self.V.value, self.w.value, t, x, dt=dt) - self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) - self.V.value = V - self.w.value = w - return self.spike.value - - def clear_input(self): - if self.input_var: - self.input[:] = 0. + r"""FitzHugh-Nagumo neuron model. + + **Model Descriptions** + + The FitzHugh–Nagumo model (FHN), named after Richard FitzHugh (1922–2007) + who suggested the system in 1961 [1]_ and J. Nagumo et al. who created the + equivalent circuit the following year, describes a prototype of an excitable + system (e.g., a neuron). + + The motivation for the FitzHugh-Nagumo model was to isolate conceptually + the essentially mathematical properties of excitation and propagation from + the electrochemical properties of sodium and potassium ion flow. The model + consists of + + - a *voltage-like variable* having cubic nonlinearity that allows regenerative + self-excitation via a positive feedback, and + - a *recovery variable* having a linear dynamics that provides a slower negative feedback. + + .. math:: + + \begin{aligned} + {\dot {v}} &=v-{\frac {v^{3}}{3}}-w+RI_{\rm {ext}}, \\ + \tau {\dot {w}}&=v+a-bw. + \end{aligned} + + The FHN Model is an example of a relaxation oscillator + because, if the external stimulus :math:`I_{\text{ext}}` + exceeds a certain threshold value, the system will exhibit + a characteristic excursion in phase space, before the + variables :math:`v` and :math:`w` relax back to their rest values. 
+ This behaviour is typical for spike generations (a short, + nonlinear elevation of membrane voltage :math:`v`, + diminished over time by a slower, linear recovery variable + :math:`w`) in a neuron after stimulation by an external + input current. + + **Model Examples** + + >>> import brainpy as bp + >>> fhn = bp.neurons.FHN(1) + >>> runner = bp.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w']) + >>> runner.run(100.) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w') + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) + + **Model Parameters** + + ============= ============== ======== ======================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------ + a 1 \ Positive constant + b 1 \ Positive constant + tau 10 ms Membrane time constant. + V_th 1.8 mV Threshold potential of spike. + ============= ============== ======== ======================== + + **Model Variables** + + ================== ================= ========================================================= + **Variables name** **Initial Value** **Explanation** + ------------------ ----------------- --------------------------------------------------------- + V 0 Membrane potential. + w 0 A recovery variable which represents + the combined effects of sodium channel + de-inactivation and potassium channel + deactivation. + input 0 External and synaptic input current. + spike False Flag to mark whether the neuron is spiking. + t_last_spike -1e7 Last spike time stamp. + ================== ================= ========================================================= + + **References** + + .. [1] FitzHugh, Richard. "Impulses and physiological states in theoretical models of nerve membrane." Biophysical journal 1.6 (1961): 445-466. + .. [2] https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model + .. 
[3] http://www.scholarpedia.org/article/FitzHugh-Nagumo_model + + """ + + def __init__( + self, + size: Shape, + a: Union[float, ArrayType, Initializer, Callable] = 0.7, + b: Union[float, ArrayType, Initializer, Callable] = 0.8, + tau: Union[float, ArrayType, Initializer, Callable] = 12.5, + Vth: Union[float, ArrayType, Initializer, Callable] = 1.8, + V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + w_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + noise: Union[float, ArrayType, Initializer, Callable] = None, + method: str = 'exp_auto', + keep_size: bool = False, + input_var: bool = True, + name: str = None, + + # parameters for training + mode: bm.Mode = None, + ): + # initialization + super(FHN, self).__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) + is_subclass(self.mode, (bm.NonBatchingMode,)) + + # parameters + self.a = parameter(a, self.varshape, allow_none=False) + self.b = parameter(b, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.Vth = parameter(Vth, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape, num_vars=2) + self.input_var = input_var + + # initializers + is_initializer(V_initializer, 'V_initializer') + is_initializer(w_initializer, 'w_initializer') + self._V_initializer = V_initializer + self._w_initializer = w_initializer + + # variables + self.reset_state(self.mode) + + # integral + if self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + + def reset_state(self, batch_size=None): + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.w = variable_(self._w_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + + def 
dV(self, V, t, w, I_ext): + return V - V * V * V / 3 - w + I_ext + + def dw(self, w, t, V): + return (V + self.a - self.b * w) / self.tau + + @property + def derivative(self): + return JointEq([self.dV, self.dw]) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, w = self.integral(self.V.value, self.w.value, t, x, dt=dt) + self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) + self.V.value = V + self.w.value = w + return self.spike.value + + def clear_input(self): + if self.input_var: + self.input[:] = 0. class ALIFBellec2020(NeuDyn): - r"""Leaky Integrate-and-Fire model with SFA [1]_. - - This model is similar to the GLIF2 model in the Technical White Paper - on generalized LIF (GLIF) models from AllenInstitute [2]_. - - Formally, this model is given by: - - .. math:: - - \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\ - \tau_a \dot{a} = -a - - Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then - - .. math:: - - V \gets V - V_{\mathrm{th}} \\ - a \gets a + 1 - - - References:: - - .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for - recurrent networks of spiking neurons." - Nature communications 11.1 (2020): 1-15. - .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for - Brain Science. Allen Cell Types Database, cell feature search. - Available from: celltypes.brain-map.org/data (2018). 
- """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # model parameters - V_rest: Union[float, ArrayType, Initializer, Callable] = -70., - V_th: Union[float, ArrayType, Initializer, Callable] = -60., - R: Union[float, ArrayType, Initializer, Callable] = 1., - beta: Union[float, ArrayType, Initializer, Callable] = 1.6, - tau: Union[float, ArrayType, Initializer, Callable] = 20., - tau_a: Union[float, ArrayType, Initializer, Callable] = 2000., - tau_ref: Union[float, ArrayType, Initializer, Callable] = None, - noise: Union[float, ArrayType, Initializer, Callable] = None, - - # initializers - V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.), - a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.), - - # parameter for training - spike_fun: Callable = bm.surrogate.relu_grad, - input_var: bool = True, - - # other parameters - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - eprop: bool = False - ): - super().__init__(name=name, - size=size, - keep_size=keep_size, - mode=mode) - is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) - - # parameters - self.V_rest = parameter(V_rest, self.varshape, allow_none=False) - self.V_th = parameter(V_th, self.varshape, allow_none=False) - self.R = parameter(R, self.varshape, allow_none=False) - self.beta = parameter(beta, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.tau_a = parameter(tau_a, self.varshape, allow_none=False) - self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) - self.noise = init_noise(noise, self.varshape, num_vars=2) - self.spike_fun = is_callable(spike_fun, 'spike_fun') - self.eprop = eprop - self.input_var = input_var - - # initializers - self._V_initializer = is_initializer(V_initializer, 'V_initializer') - self._a_initializer = is_initializer(a_initializer, 'a_initializer') - - # variables - self.reset_state(self.mode) - - # integral - if self.noise 
is None: - self.integral = odeint(method=method, f=self.derivative) - else: - self.integral = sdeint(method=method, f=self.derivative, g=self.noise) - - def da(self, a, t): - return -a / self.tau_a - - def dV(self, V, t, I_ext): - return (- (V - self.V_rest) + self.R * I_ext) / self.tau - - @property - def derivative(self): - return JointEq([self.dV, self.da]) - - def reset_state(self, batch_size=None): - self.a = variable_(self._a_initializer, self.varshape, batch_size) - self.V = variable_(self._V_initializer, self.varshape, batch_size) - if self.input_var: - self.input = variable_(bm.zeros, self.varshape, batch_size) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) - if self.tau_ref is not None: - self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) - self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - if self.input_var: - if x is not None: - self.input += x - x = self.input.value - else: - x = 0. if x is None else x - V, a = self.integral(self.V.value, self.a.value, t, x, dt) - - if self.tau_ref is not None: - # refractory - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) - V = bm.where(refractory, self.V.value, V) - # spike and reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) - V -= self.V_th * (stop_gradient(spike) if self.eprop else spike) - # will be used in other place, like Delta Synapse, so stop its gradient - spike_ = spike > 0. 
- refractory = stop_gradient(bm.logical_or(refractory, spike_)) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) - else: - spike = V >= (self.V_th + self.beta * self.a) - refractory = bm.logical_or(refractory, spike) - t_last_spike = bm.where(spike, t, self.t_last_spike.value) - V -= self.V_th * spike - self.refractory.value = refractory - self.t_last_spike.value = t_last_spike - - else: - # spike and reset - if isinstance(self.mode, bm.TrainingMode): - spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) - V -= self.V_th * (stop_gradient(spike) if self.eprop else spike) - else: - spike = V >= (self.V_th + self.beta * self.a) - V -= self.V_th * spike - self.spike.value = spike - self.V.value = V - self.a.value = a + spike - return spike - - def clear_input(self): - if self.input_var: - self.input[:] = 0. + r"""Leaky Integrate-and-Fire model with SFA [1]_. + + This model is similar to the GLIF2 model in the Technical White Paper + on generalized LIF (GLIF) models from AllenInstitute [2]_. + + Formally, this model is given by: + + .. math:: + + \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\ + \tau_a \dot{a} = -a + + Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then + + .. math:: + + V \gets V - V_{\mathrm{th}} \\ + a \gets a + 1 + + + References:: + + .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for + recurrent networks of spiking neurons." + Nature communications 11.1 (2020): 1-15. + .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for + Brain Science. Allen Cell Types Database, cell feature search. + Available from: celltypes.brain-map.org/data (2018). 
+ """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # model parameters + V_rest: Union[float, ArrayType, Initializer, Callable] = -70., + V_th: Union[float, ArrayType, Initializer, Callable] = -60., + R: Union[float, ArrayType, Initializer, Callable] = 1., + beta: Union[float, ArrayType, Initializer, Callable] = 1.6, + tau: Union[float, ArrayType, Initializer, Callable] = 20., + tau_a: Union[float, ArrayType, Initializer, Callable] = 2000., + tau_ref: Union[float, ArrayType, Initializer, Callable] = None, + noise: Union[float, ArrayType, Initializer, Callable] = None, + + # initializers + V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.), + a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.), + + # parameter for training + spike_fun: Callable = bm.surrogate.relu_grad, + input_var: bool = True, + + # other parameters + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + eprop: bool = False + ): + super().__init__(name=name, + size=size, + keep_size=keep_size, + mode=mode) + is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) + + # parameters + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.tau_a = parameter(tau_a, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.noise = init_noise(noise, self.varshape, num_vars=2) + self.spike_fun = is_callable(spike_fun, 'spike_fun') + self.eprop = eprop + self.input_var = input_var + + # initializers + self._V_initializer = is_initializer(V_initializer, 'V_initializer') + self._a_initializer = is_initializer(a_initializer, 'a_initializer') + + # variables + self.reset_state(self.mode) + + # integral + if self.noise 
is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + + def da(self, a, t): + return -a / self.tau_a + + def dV(self, V, t, I_ext): + return (- (V - self.V_rest) + self.R * I_ext) / self.tau + + @property + def derivative(self): + return JointEq([self.dV, self.da]) + + def reset_state(self, batch_size=None): + self.a = variable_(self._a_initializer, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + if self.tau_ref is not None: + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, a = self.integral(self.V.value, self.a.value, t, x, dt) + + if self.tau_ref is not None: + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + # spike and reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * (stop_gradient(spike) if self.eprop else spike) + # will be used in other place, like Delta Synapse, so stop its gradient + spike_ = spike > 0. 
+ refractory = stop_gradient(bm.logical_or(refractory, spike_)) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + else: + spike = V >= (self.V_th + self.beta * self.a) + refractory = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + V -= self.V_th * spike + self.refractory.value = refractory + self.t_last_spike.value = t_last_spike + + else: + # spike and reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * (stop_gradient(spike) if self.eprop else spike) + else: + spike = V >= (self.V_th + self.beta * self.a) + V -= self.V_th * spike + self.spike.value = spike + self.V.value = V + self.a.value = a + spike + return spike + + def clear_input(self): + if self.input_var: + self.input[:] = 0. class LIF_SFA_Bellec2020(NeuDyn): - r"""Leaky Integrate-and-Fire model with SFA [1]_. + r"""Leaky Integrate-and-Fire model with SFA [1]_. - This model is similar to the GLIF2 model in the Technical White Paper - on generalized LIF (GLIF) models from AllenInstitute [2]_. + This model is similar to the GLIF2 model in the Technical White Paper + on generalized LIF (GLIF) models from AllenInstitute [2]_. - Formally, this model is given by: + Formally, this model is given by: - .. math:: - - \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\ - \tau_a \dot{a} = -a - - Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then - - .. math:: - - V \gets V - V_{\mathrm{th}} \\ - a \gets a + 1 - - - References:: - - .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for - recurrent networks of spiking neurons." - Nature communications 11.1 (2020): 1-15. - .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for - Brain Science. Allen Cell Types Database, cell feature search. - Available from: celltypes.brain-map.org/data (2018). 
- """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # model parameters - V_rest: Union[float, ArrayType, Initializer, Callable] = -70., - V_th: Union[float, ArrayType, Initializer, Callable] = -60., - R: Union[float, ArrayType, Initializer, Callable] = 1., - beta: Union[float, ArrayType, Initializer, Callable] = 1.6, - tau: Union[float, ArrayType, Initializer, Callable] = 20., - tau_a: Union[float, ArrayType, Initializer, Callable] = 2000., - tau_ref: Union[float, ArrayType, Initializer, Callable] = None, - - # initializers - V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.), - a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.), - - # parameter for training - spike_fun: Callable = bm.surrogate.relu_grad, - - # other parameters - method: str = 'exp_auto', - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(name=name, size=size, keep_size=keep_size, mode=mode) - is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) - - # parameters - self.V_rest = parameter(V_rest, self.varshape, allow_none=False) - self.V_th = parameter(V_th, self.varshape, allow_none=False) - self.R = parameter(R, self.varshape, allow_none=False) - self.beta = parameter(beta, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.tau_a = parameter(tau_a, self.varshape, allow_none=False) - self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) - self.spike_fun = is_callable(spike_fun, 'spike_fun') - - # initializers - self._V_initializer = is_initializer(V_initializer) - self._a_initializer = is_initializer(a_initializer) - - # variables - self.reset_state(self.mode) - - # integral - self.integral = odeint(method=method, f=self.derivative) - - def da(self, a, t): - return -a / self.tau_a - - def dV(self, V, t, I_ext): - return (- (V - self.V_rest) + self.R * I_ext) / self.tau - - @property - def derivative(self): - return JointEq([self.dV, self.da]) - - 
def reset_state(self, batch_size=None): - self.a = variable_(self._a_initializer, self.varshape, batch_size) - self.V = variable_(self._V_initializer, self.varshape, batch_size) - self.spike = variable_(bm.zeros, self.varshape, batch_size) - if self.tau_ref is not None: - self.t_last_spike = variable_(OneInit(-1e7), self.varshape, batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - x = 0. if x is None else x - - # integral - V, a = self.integral(self.V.value, self.a.value, t, x, dt) - - if self.tau_ref is not None: - # refractory - refractory = stop_gradient((t - self.t_last_spike) <= self.tau_ref) - V = bm.where(refractory, self.V.value, V) - # spike and reset - spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) - V -= self.V_th * spike - t_last_spike = stop_gradient(bm.where(spike, t, self.t_last_spike.value)) - self.t_last_spike.value = t_last_spike - - else: - # spike and reset - spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) - V -= self.V_th * spike - self.spike.value = spike - self.V.value = V - self.a.value = a + spike - return spike + .. math:: + + \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\ + \tau_a \dot{a} = -a + + Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then + + .. math:: + + V \gets V - V_{\mathrm{th}} \\ + a \gets a + 1 + + + References:: + + .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for + recurrent networks of spiking neurons." + Nature communications 11.1 (2020): 1-15. + .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for + Brain Science. Allen Cell Types Database, cell feature search. + Available from: celltypes.brain-map.org/data (2018). 
+ """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # model parameters + V_rest: Union[float, ArrayType, Initializer, Callable] = -70., + V_th: Union[float, ArrayType, Initializer, Callable] = -60., + R: Union[float, ArrayType, Initializer, Callable] = 1., + beta: Union[float, ArrayType, Initializer, Callable] = 1.6, + tau: Union[float, ArrayType, Initializer, Callable] = 20., + tau_a: Union[float, ArrayType, Initializer, Callable] = 2000., + tau_ref: Union[float, ArrayType, Initializer, Callable] = None, + + # initializers + V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.), + a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.), + + # parameter for training + spike_fun: Callable = bm.surrogate.relu_grad, + + # other parameters + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(name=name, size=size, keep_size=keep_size, mode=mode) + is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) + + # parameters + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.tau_a = parameter(tau_a, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.spike_fun = is_callable(spike_fun, 'spike_fun') + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._a_initializer = is_initializer(a_initializer) + + # variables + self.reset_state(self.mode) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + def da(self, a, t): + return -a / self.tau_a + + def dV(self, V, t, I_ext): + return (- (V - self.V_rest) + self.R * I_ext) / self.tau + + @property + def derivative(self): + return JointEq([self.dV, self.da]) + + 
def reset_state(self, batch_size=None): + self.a = variable_(self._a_initializer, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.spike = variable_(bm.zeros, self.varshape, batch_size) + if self.tau_ref is not None: + self.t_last_spike = variable_(OneInit(-1e7), self.varshape, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integral + V, a = self.integral(self.V.value, self.a.value, t, x, dt) + + if self.tau_ref is not None: + # refractory + refractory = stop_gradient((t - self.t_last_spike) <= self.tau_ref) + V = bm.where(refractory, self.V.value, V) + # spike and reset + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * spike + t_last_spike = stop_gradient(bm.where(spike, t, self.t_last_spike.value)) + self.t_last_spike.value = t_last_spike + + else: + # spike and reset + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * spike + self.spike.value = spike + self.V.value = V + self.a.value = a + spike + return spike diff --git a/brainpy/_src/dynold/neurons/tests/test_biological_neurons.py b/brainpy/_src/dynold/neurons/tests/test_biological_neurons.py index 5cb267103..1e4206e3c 100644 --- a/brainpy/_src/dynold/neurons/tests/test_biological_neurons.py +++ b/brainpy/_src/dynold/neurons/tests/test_biological_neurons.py @@ -1,159 +1,148 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized from brainpy._src.dynold.neurons import biological_models class Test_Biological(parameterized.TestCase): - def test_HH(self): - bm.random.seed() - model = biological_models.HH(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'm', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - - def test_HH_with_noise(self): - bm.random.seed() - model = biological_models.HH(size=1, noise=0.1) - runner = bp.DSRunner(model, - monitors=['V', 'm', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - - def test_HH_batching_mode(self): - bm.random.seed() - model = biological_models.HH(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'm', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - - - def test_MorrisLecar(self): - bm.random.seed() - model = biological_models.MorrisLecar(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'W', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - - def test_MorrisLecar_with_noise(self): - bm.random.seed() - model = biological_models.MorrisLecar(size=1, noise=0.1) - runner = bp.DSRunner(model, - monitors=['V', 'W', 'spike'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - - def test_MorrisLecar_batching_mode(self): - bm.random.seed() - model = biological_models.MorrisLecar(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'W', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - - - def test_PinskyRinzelModel(self): - bm.random.seed() - model = biological_models.PinskyRinzelModel(size=1) - runner = bp.DSRunner(model, - monitors=['Vs', 'Vd'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['Vs'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Vd'].shape, (100, 1)) - - - def test_PinskyRinzelModel_with_noise(self): - bm.random.seed() - model = biological_models.PinskyRinzelModel(size=1, noise=0.1) - runner = bp.DSRunner(model, - monitors=['Vs', 'Vd'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['Vs'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['Vd'].shape, (100, 1)) - - - def test_PinskyRinzelModel_batching_mode(self): - bm.random.seed() - model = biological_models.PinskyRinzelModel(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['Vs', 'Vd'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['Vs'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['Vd'].shape, (1, 100, 10)) - - - def test_WangBuzsakiModel(self): - bm.random.seed() - model = biological_models.WangBuzsakiModel(size=1) - runner = bp.DSRunner(model, - monitors=['V', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) 
- self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - - def test_WangBuzsakiModel_with_noise(self): - bm.random.seed() - model = biological_models.WangBuzsakiModel(size=1, noise=0.1) - runner = bp.DSRunner(model, - monitors=['V', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - - - def test_WangBuzsakiModel_batching_mode(self): - bm.random.seed() - model = biological_models.WangBuzsakiModel(size=10, mode=bm.batching_mode) - runner = bp.DSRunner(model, - monitors=['V', 'n', 'h', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) - self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) - + def test_HH(self): + bm.random.seed() + model = biological_models.HH(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'm', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_HH_with_noise(self): + bm.random.seed() + model = biological_models.HH(size=1, noise=0.1) + runner = bp.DSRunner(model, + monitors=['V', 'm', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['m'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_HH_batching_mode(self): + bm.random.seed() + model = biological_models.HH(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'm', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['m'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_MorrisLecar(self): + bm.random.seed() + model = biological_models.MorrisLecar(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_MorrisLecar_with_noise(self): + bm.random.seed() + model = biological_models.MorrisLecar(size=1, noise=0.1) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['W'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_MorrisLecar_batching_mode(self): + bm.random.seed() + model = biological_models.MorrisLecar(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'W', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['W'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) + + def test_PinskyRinzelModel(self): + bm.random.seed() + model = biological_models.PinskyRinzelModel(size=1) + runner = bp.DSRunner(model, + monitors=['Vs', 'Vd'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['Vs'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Vd'].shape, (100, 1)) + + def test_PinskyRinzelModel_with_noise(self): + bm.random.seed() + model = biological_models.PinskyRinzelModel(size=1, noise=0.1) + runner = bp.DSRunner(model, + monitors=['Vs', 'Vd'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['Vs'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['Vd'].shape, (100, 1)) + + def test_PinskyRinzelModel_batching_mode(self): + bm.random.seed() + model = biological_models.PinskyRinzelModel(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['Vs', 'Vd'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['Vs'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['Vd'].shape, (1, 100, 10)) + + def test_WangBuzsakiModel(self): + bm.random.seed() + model = biological_models.WangBuzsakiModel(size=1) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_WangBuzsakiModel_with_noise(self): + bm.random.seed() + model = biological_models.WangBuzsakiModel(size=1, noise=0.1) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['n'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['h'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + + def test_WangBuzsakiModel_batching_mode(self): + bm.random.seed() + model = biological_models.WangBuzsakiModel(size=10, mode=bm.batching_mode) + runner = bp.DSRunner(model, + monitors=['V', 'n', 'h', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['n'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['h'].shape, (1, 100, 10)) + self.assertTupleEqual(runner.mon['spike'].shape, (1, 100, 10)) diff --git a/brainpy/_src/dynold/neurons/tests/test_fractional_neurons.py b/brainpy/_src/dynold/neurons/tests/test_fractional_neurons.py index 686056d37..213ddfae0 100644 --- a/brainpy/_src/dynold/neurons/tests/test_fractional_neurons.py +++ b/brainpy/_src/dynold/neurons/tests/test_fractional_neurons.py @@ -1,31 +1,32 @@ # -*- coding: utf-8 -*- -import brainpy as bp from absl.testing import parameterized + +import brainpy as bp from brainpy._src.dynold.neurons import fractional_models class Test_Fractional(parameterized.TestCase): - def test_FractionalFHR(self): - bp.math.random.seed() - model = fractional_models.FractionalFHR(size=1, alpha=0.5) - runner = bp.DSRunner(model, - monitors=['V', 'w', 'y', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['w'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['y'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + def test_FractionalFHR(self): + bp.math.random.seed() + model = fractional_models.FractionalFHR(size=1, alpha=0.5) + runner = bp.DSRunner(model, + monitors=['V', 'w', 'y', 'spike'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['w'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['y'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - def test_FractionalIzhikevich(self): - bp.math.random.seed() - model = fractional_models.FractionalIzhikevich(size=1, alpha=0.5, num_memory=1000) - runner = bp.DSRunner(model, - monitors=['V', 'u', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['u'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) + def test_FractionalIzhikevich(self): + bp.math.random.seed() + model = fractional_models.FractionalIzhikevich(size=1, alpha=0.5, num_memory=1000) + runner = bp.DSRunner(model, + monitors=['V', 'u', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['u'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) diff --git a/brainpy/_src/dynold/neurons/tests/test_reduced_neurons.py b/brainpy/_src/dynold/neurons/tests/test_reduced_neurons.py index cde3dc24e..a21dad9b0 100644 --- a/brainpy/_src/dynold/neurons/tests/test_reduced_neurons.py +++ b/brainpy/_src/dynold/neurons/tests/test_reduced_neurons.py @@ -1,75 +1,73 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized from brainpy._src.dynold.neurons import reduced_models class Test_Reduced(parameterized.TestCase): - @parameterized.named_parameters( - {'testcase_name': f'noise_of_{name}', 'neuron': name} - for name in reduced_models.__all__ - ) - def test_run_shape(self, neuron): - bm.random.seed() - model = getattr(reduced_models, neuron)(size=1) - if neuron == 'LeakyIntegrator': - runner = bp.DSRunner(model, - monitors=['V'], - progress_bar=False) - 
runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - else: - runner = bp.DSRunner(model, - monitors=['V', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - + @parameterized.named_parameters( + {'testcase_name': f'noise_of_{name}', 'neuron': name} + for name in reduced_models.__all__ + ) + def test_run_shape(self, neuron): + bm.random.seed() + model = getattr(reduced_models, neuron)(size=1) + if neuron == 'LeakyIntegrator': + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + else: + runner = bp.DSRunner(model, + monitors=['V', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - @parameterized.named_parameters( - {'testcase_name': f'noise_of_{name}', 'neuron': name} - for name in reduced_models.__all__ - ) - def test_noise_shape(self, neuron): - bm.random.seed() - model = getattr(reduced_models, neuron)(size=1, noise=0.1) - if neuron == 'LeakyIntegrator': - runner = bp.DSRunner(model, - monitors=['V'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - else: - runner = bp.DSRunner(model, - monitors=['V', 'spike'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) - self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - + @parameterized.named_parameters( + {'testcase_name': f'noise_of_{name}', 'neuron': name} + for name in reduced_models.__all__ + ) + def test_noise_shape(self, neuron): + bm.random.seed() + model = getattr(reduced_models, neuron)(size=1, noise=0.1) + if neuron == 'LeakyIntegrator': + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + else: + runner = bp.DSRunner(model, + monitors=['V', 'spike'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 1)) + self.assertTupleEqual(runner.mon['spike'].shape, (100, 1)) - @parameterized.named_parameters( - {'testcase_name': f'noise_of_{name}', 'neuron': name} - for name in reduced_models.__all__ - ) - def test_training_shape(self, neuron): - bm.random.seed() - if neuron == 'FHN': - model = getattr(reduced_models, neuron)(size=10) - runner = bp.DSRunner(model, - monitors=['V'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (100, 10)) - else: - model = getattr(reduced_models, neuron)(size=10, mode=bm.training_mode) - runner = bp.DSRunner(model, - monitors=['V'], - progress_bar=False) - runner.run(10.) - self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) - + @parameterized.named_parameters( + {'testcase_name': f'noise_of_{name}', 'neuron': name} + for name in reduced_models.__all__ + ) + def test_training_shape(self, neuron): + bm.random.seed() + if neuron == 'FHN': + model = getattr(reduced_models, neuron)(size=10) + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) + self.assertTupleEqual(runner.mon['V'].shape, (100, 10)) + else: + model = getattr(reduced_models, neuron)(size=10, mode=bm.training_mode) + runner = bp.DSRunner(model, + monitors=['V'], + progress_bar=False) + runner.run(10.) 
+ self.assertTupleEqual(runner.mon['V'].shape, (1, 100, 10)) diff --git a/brainpy/_src/dynold/synapses/__init__.py b/brainpy/_src/dynold/synapses/__init__.py index 233535ff5..01fd3605f 100644 --- a/brainpy/_src/dynold/synapses/__init__.py +++ b/brainpy/_src/dynold/synapses/__init__.py @@ -1,8 +1,7 @@ # -*- coding: utf-8 -*- -from .base import * from .abstract_models import * +from .base import * from .biological_models import * -from .learning_rules import * from .compat import * - +from .learning_rules import * diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py index 52ad7163d..ba9472252 100644 --- a/brainpy/_src/dynold/synapses/abstract_models.py +++ b/brainpy/_src/dynold/synapses/abstract_models.py @@ -16,711 +16,711 @@ from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre __all__ = [ - 'Delta', - 'Exponential', - 'DualExponential', - 'Alpha', - 'NMDA', + 'Delta', + 'Exponential', + 'DualExponential', + 'Alpha', + 'NMDA', ] class Delta(TwoEndConn): - r"""Voltage Jump Synapse Model, or alias of Delta Synapse Model. - - **Model Descriptions** - - .. math:: - - I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \mathrm{STP} * \delta(t-t_j-D) - - where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, - :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, - :math:`C` the set of neurons connected to the post-synaptic neuron, - :math:`D` the transmission delay of chemical synapses, - and :math:`\mathrm{STP}` the short-term plasticity effect. - For simplicity, the rise and decay phases of post-synaptic currents are - omitted in this model. - - **Model Examples** - - >>> import brainpy as bp - >>> from brainpy import synapses, neurons - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = neurons.LIF(1) - >>> neu2 = neurons.LIF(1) - >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), g_max=5.) 
- >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.), ('post.input', 10.)], monitors=['pre.V', 'post.V', 'pre.spike']) - >>> runner.run(150.) - >>> - >>> fig, gs = bp.visualize.get_figure(1, 1, 3, 8) - >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') - >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') - >>> plt.xlim(40, 150) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - pre: NeuDyn - The pre-synaptic neuron group. - post: NeuDyn - The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Initializer, Callable - The synaptic strength. Default is 1. - post_ref_key: str - Whether the post-synaptic group has refractory period. 
- """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - output: _SynOut = CUBA(target_var='V'), - stp: Optional[_SynSTP] = None, - comp_method: str = 'sparse', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - delay_step: Union[float, ArrayType, Initializer, Callable] = None, - post_ref_key: str = None, - name: str = None, - mode: bm.Mode = None, - stop_spike_gradient: bool = False, - ): - super().__init__(name=name, - pre=pre, - post=post, - conn=conn, - output=output, - stp=stp, - mode=mode) - - # parameters - self.stop_spike_gradient = stop_spike_gradient - self.post_ref_key = post_ref_key - if post_ref_key: - self.check_post_attrs(post_ref_key) - self.comp_method = comp_method - - # connections and weights - self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr') - - # register delay - self.pre.register_local_delay("spike", self.name, delay_step=delay_step) - - def update(self, pre_spike=None): - # pre-synaptic spikes - if pre_spike is None: - pre_spike = self.pre.get_local_delay("spike", self.name) - pre_spike = bm.as_jax(pre_spike) - if self.stop_spike_gradient: - pre_spike = jax.lax.stop_gradient(pre_spike) - - # update sub-components - if self.stp is not None: - self.stp.update(pre_spike) - - # synaptic values onto the post - if isinstance(self.conn, All2All): - syn_value = bm.asarray(pre_spike, dtype=bm.float_) - if self.stp is not None: - syn_value = self.stp(syn_value) - post_vs = self._syn2post_with_all2all(syn_value, self.g_max) - elif isinstance(self.conn, One2One): - syn_value = bm.asarray(pre_spike, dtype=bm.float_) - if self.stp is not None: - syn_value = self.stp(syn_value) - post_vs = self._syn2post_with_one2one(syn_value, self.g_max) - else: - if self.comp_method == 'sparse': + r"""Voltage Jump Synapse Model, or alias of Delta Synapse Model. + + **Model Descriptions** + + .. 
math:: + + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \mathrm{STP} * \delta(t-t_j-D) + + where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + :math:`D` the transmission delay of chemical synapses, + and :math:`\mathrm{STP}` the short-term plasticity effect. + For simplicity, the rise and decay phases of post-synaptic currents are + omitted in this model. + + **Model Examples** + + >>> import brainpy as bp + >>> from brainpy import synapses, neurons + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.LIF(1) + >>> neu2 = neurons.LIF(1) + >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), g_max=5.) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.), ('post.input', 10.)], monitors=['pre.V', 'post.V', 'pre.spike']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(1, 1, 3, 8) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.xlim(40, 150) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + pre: NeuDyn + The pre-synaptic neuron group. + post: NeuDyn + The post-synaptic neuron group. + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ArrayType, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + g_max: float, ArrayType, Initializer, Callable + The synaptic strength. Default is 1. + post_ref_key: str + Whether the post-synaptic group has refractory period. 
+ """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + output: _SynOut = CUBA(target_var='V'), + stp: Optional[_SynSTP] = None, + comp_method: str = 'sparse', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + delay_step: Union[float, ArrayType, Initializer, Callable] = None, + post_ref_key: str = None, + name: str = None, + mode: bm.Mode = None, + stop_spike_gradient: bool = False, + ): + super().__init__(name=name, + pre=pre, + post=post, + conn=conn, + output=output, + stp=stp, + mode=mode) + + # parameters + self.stop_spike_gradient = stop_spike_gradient + self.post_ref_key = post_ref_key + if post_ref_key: + self.check_post_attrs(post_ref_key) + self.comp_method = comp_method + + # connections and weights + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr') + + # register delay + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) + + def update(self, pre_spike=None): + # pre-synaptic spikes + if pre_spike is None: + pre_spike = self.pre.get_local_delay("spike", self.name) + pre_spike = bm.as_jax(pre_spike) + if self.stop_spike_gradient: + pre_spike = jax.lax.stop_gradient(pre_spike) + + # update sub-components if self.stp is not None: - syn_value = self.stp(pre_spike) - f = lambda s: bm.sparse.csrmv( - self.g_max, self.conn_mask[0], self.conn_mask[1], s, - shape=(self.pre.num, self.post.num), transpose=True - ) + self.stp.update(pre_spike) + + # synaptic values onto the post + if isinstance(self.conn, All2All): + syn_value = bm.asarray(pre_spike, dtype=bm.float_) + if self.stp is not None: + syn_value = self.stp(syn_value) + post_vs = self._syn2post_with_all2all(syn_value, self.g_max) + elif isinstance(self.conn, One2One): + syn_value = bm.asarray(pre_spike, dtype=bm.float_) + if self.stp is not None: + syn_value = self.stp(syn_value) + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: - 
syn_value = pre_spike - f = lambda s: bm.event.csrmv( - self.g_max, self.conn_mask[0], self.conn_mask[1], s, - shape=(self.pre.num, self.post.num), transpose=True - ) - if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f) - post_vs = f(syn_value) - else: - syn_value = bm.asarray(pre_spike, dtype=bm.float_) - if self.stp is not None: - syn_value = self.stp(syn_value) - post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) - if self.post_ref_key: - post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key)) - - # update outputs - return self.output(post_vs) + if self.comp_method == 'sparse': + if self.stp is not None: + syn_value = self.stp(pre_spike) + f = lambda s: bm.sparse.csrmv( + self.g_max, self.conn_mask[0], self.conn_mask[1], s, + shape=(self.pre.num, self.post.num), transpose=True + ) + else: + syn_value = pre_spike + f = lambda s: bm.event.csrmv( + self.g_max, self.conn_mask[0], self.conn_mask[1], s, + shape=(self.pre.num, self.post.num), transpose=True + ) + if isinstance(self.mode, bm.BatchingMode): f = jax.vmap(f) + post_vs = f(syn_value) + else: + syn_value = bm.asarray(pre_spike, dtype=bm.float_) + if self.stp is not None: + syn_value = self.stp(syn_value) + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + if self.post_ref_key: + post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key)) + + # update outputs + return self.output(post_vs) class Exponential(TwoEndConn): - r"""Exponential decay synapse model. 
- - %s - - **Model Examples** - - - `(Brunel & Hakim, 1999) Fast Global Oscillation `_ - - `(Vreeswijk & Sompolinsky, 1996) E/I balanced network `_ - - `(Brette, et, al., 2007) CUBA `_ - - `(Tian, et al., 2020) E/I Net for fast response `_ - - - >>> import brainpy as bp - >>> from brainpy import neurons, synapses, synouts - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = neurons.LIF(1) - >>> neu2 = neurons.LIF(1) - >>> syn1 = synapses.Exponential(neu1, neu2, bp.conn.All2All(), - >>> g_max=5., output=synouts.CUBA()) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g']) - >>> runner.run(150.) - >>> - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') - >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') - >>> plt.legend() - >>> - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - pre: NeuGroup - The pre-synaptic neuron group. - post: NeuGroup - The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau: float, ArrayType - The time constant of decay. [ms] - g_max: float, ArrayType, Initializer, Callable - The synaptic strength (the maximum conductance). Default is 1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - output: Optional[_SynOut] = CUBA(), - stp: Optional[_SynSTP] = None, - comp_method: str = 'sparse', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - tau: Union[float, ArrayType] = 8.0, - method: str = 'exp_auto', - - # other parameters - name: str = None, - mode: bm.Mode = None, - stop_spike_gradient: bool = False, - ): - super().__init__(pre=pre, - post=post, - conn=conn, - output=output, - stp=stp, - name=name, - mode=mode) - # parameters - self.stop_spike_gradient = stop_spike_gradient - - # synapse dynamics - self.syn = synapses.Expon(post.varshape, tau=tau, method=method) - - # Projection - if isinstance(conn, All2All): - self.comm = linear.AllToAll(pre.num, post.num, g_max) - elif isinstance(conn, One2One): - assert post.num == pre.num - self.comm = linear.OneToOne(pre.num, g_max) - else: - if comp_method == 'dense': - self.comm = linear.MaskedLinear(conn, g_max) - elif comp_method == 'sparse': - if self.stp is None: - self.comm = linear.EventCSRLinear(conn, g_max) + r"""Exponential decay synapse model. + + %s + + **Model Examples** + + - `(Brunel & Hakim, 1999) Fast Global Oscillation `_ + - `(Vreeswijk & Sompolinsky, 1996) E/I balanced network `_ + - `(Brette, et, al., 2007) CUBA `_ + - `(Tian, et al., 2020) E/I Net for fast response `_ + + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses, synouts + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.LIF(1) + >>> neu2 = neurons.LIF(1) + >>> syn1 = synapses.Exponential(neu1, neu2, bp.conn.All2All(), + >>> g_max=5., output=synouts.CUBA()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g']) + >>> runner.run(150.) 
+ >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ArrayType, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau: float, ArrayType + The time constant of decay. [ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. 
+ + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + output: Optional[_SynOut] = CUBA(), + stp: Optional[_SynSTP] = None, + comp_method: str = 'sparse', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau: Union[float, ArrayType] = 8.0, + method: str = 'exp_auto', + + # other parameters + name: str = None, + mode: bm.Mode = None, + stop_spike_gradient: bool = False, + ): + super().__init__(pre=pre, + post=post, + conn=conn, + output=output, + stp=stp, + name=name, + mode=mode) + # parameters + self.stop_spike_gradient = stop_spike_gradient + + # synapse dynamics + self.syn = synapses.Expon(post.varshape, tau=tau, method=method) + + # Projection + if isinstance(conn, All2All): + self.comm = linear.AllToAll(pre.num, post.num, g_max) + elif isinstance(conn, One2One): + assert post.num == pre.num + self.comm = linear.OneToOne(pre.num, g_max) else: - self.comm = linear.CSRLinear(conn, g_max) - else: - raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".') - - # delay - self.pre.register_local_delay("spike", self.name, delay_step=delay_step) + if comp_method == 'dense': + self.comm = linear.MaskedLinear(conn, g_max) + elif comp_method == 'sparse': + if self.stp is None: + self.comm = linear.EventCSRLinear(conn, g_max) + else: + self.comm = linear.CSRLinear(conn, g_max) + else: + raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".') + + # delay + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) + + @property + def g(self): + return self.syn.g + + @g.setter + def g(self, value): + self.syn.g = value + + def update(self, pre_spike=None): + # delays + if pre_spike is None: + pre_spike = self.pre.get_local_delay("spike", self.name) + pre_spike = bm.as_jax(pre_spike) + if self.stop_spike_gradient: + pre_spike = jax.lax.stop_gradient(pre_spike) + 
+ # update sub-components + self.output.update() + if self.stp is not None: + self.stp.update(pre_spike) + pre_spike = self.stp(pre_spike) - @property - def g(self): - return self.syn.g + # post values + g = self.syn(self.comm(pre_spike)) - @g.setter - def g(self, value): - self.syn.g = value + # output + return self.output(g) - def update(self, pre_spike=None): - # delays - if pre_spike is None: - pre_spike = self.pre.get_local_delay("spike", self.name) - pre_spike = bm.as_jax(pre_spike) - if self.stop_spike_gradient: - pre_spike = jax.lax.stop_gradient(pre_spike) - # update sub-components - self.output.update() - if self.stp is not None: - self.stp.update(pre_spike) - pre_spike = self.stp(pre_spike) +Exponential.__doc__ = Exponential.__doc__ % (_docs.exp_syn_doc,) - # post values - g = self.syn(self.comm(pre_spike)) - # output - return self.output(g) +class DualExponential(_TwoEndConnAlignPre): + r"""Dual exponential synapse model. + + %s + + **Model Examples** + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses, synouts + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.LIF(1) + >>> neu2 = neurons.LIF(1) + >>> syn1 = synapses.DualExponential(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + pre: NeuDyn + The pre-synaptic neuron group. 
+ post: NeuDyn + The post-synaptic neuron group. + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ArrayType, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau_decay: float, ArrayArray, ndarray + The time constant of the synaptic decay phase. [ms] + tau_rise: float, ArrayArray, ndarray + The time constant of the synaptic rise phase. [ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + stp: Optional[_SynSTP] = None, + output: _SynOut = None, # CUBA(), + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau_decay: Union[float, ArrayType] = 10.0, + tau_rise: Union[float, ArrayType] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + A: Optional[Union[float, ArrayType, Callable]] = None, + method: str = 'exp_auto', + + # other parameters + name: str = None, + mode: bm.Mode = None, + stop_spike_gradient: bool = False, + ): + + # parameters + self.stop_spike_gradient = stop_spike_gradient + self.comp_method = comp_method + self.tau_rise = tau_rise + self.tau_decay = tau_decay + if bm.size(self.tau_rise) != 1: + raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. ' + f'But we got {self.tau_rise}') + if bm.size(self.tau_decay) != 1: + raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. 
' + f'But we got {self.tau_decay}') + + syn = synapses.DualExpon(pre.size, + pre.keep_size, + A=A, + mode=mode, + tau_decay=tau_decay, + tau_rise=tau_rise, + method=method, ) + + super().__init__(pre=pre, + post=post, + syn=syn, + conn=conn, + output=output, + stp=stp, + comp_method=comp_method, + g_max=g_max, + delay_step=delay_step, + name=name, + mode=mode) + + self.check_post_attrs('input') + # copy the references + self.g = syn.g + self.h = syn.h + + def update(self, pre_spike=None): + return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) -Exponential.__doc__ = Exponential.__doc__ % (_docs.exp_syn_doc,) +DualExponential.__doc__ = DualExponential.__doc__ % (_docs.dual_exp_syn_doc,) -class DualExponential(_TwoEndConnAlignPre): - r"""Dual exponential synapse model. - - %s - - **Model Examples** - - >>> import brainpy as bp - >>> from brainpy import neurons, synapses, synouts - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = neurons.LIF(1) - >>> neu2 = neurons.LIF(1) - >>> syn1 = synapses.DualExponential(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) - >>> runner.run(150.) - >>> - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') - >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') - >>> plt.legend() - >>> - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') - >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - pre: NeuDyn - The pre-synaptic neuron group. - post: NeuDyn - The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. 
- comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau_decay: float, ArrayArray, ndarray - The time constant of the synaptic decay phase. [ms] - tau_rise: float, ArrayArray, ndarray - The time constant of the synaptic rise phase. [ms] - g_max: float, ArrayType, Initializer, Callable - The synaptic strength (the maximum conductance). Default is 1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. - - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - stp: Optional[_SynSTP] = None, - output: _SynOut = None, # CUBA(), - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - tau_decay: Union[float, ArrayType] = 10.0, - tau_rise: Union[float, ArrayType] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - A: Optional[Union[float, ArrayType, Callable]] = None, - method: str = 'exp_auto', - - # other parameters - name: str = None, - mode: bm.Mode = None, - stop_spike_gradient: bool = False, - ): - - # parameters - self.stop_spike_gradient = stop_spike_gradient - self.comp_method = comp_method - self.tau_rise = tau_rise - self.tau_decay = tau_decay - if bm.size(self.tau_rise) != 1: - raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. ' - f'But we got {self.tau_rise}') - if bm.size(self.tau_decay) != 1: - raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. ' - f'But we got {self.tau_decay}') - - syn = synapses.DualExpon(pre.size, +class Alpha(_TwoEndConnAlignPre): + r"""Alpha synapse model. 
+ + %s + + **Model Examples** + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses, synouts + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.LIF(1) + >>> neu2 = neurons.LIF(1) + >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + pre: NeuDyn + The pre-synaptic neuron group. + post: NeuDyn + The post-synaptic neuron group. + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ArrayType, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau_decay: float, ArrayType + The time constant of the synaptic decay phase. [ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. 
+ + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + output: _SynOut = None, # CUBA(), + stp: Optional[_SynSTP] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau_decay: Union[float, ArrayType] = 10.0, + method: str = 'exp_auto', + + # other parameters + name: str = None, + mode: bm.Mode = None, + stop_spike_gradient: bool = False, + ): + # parameters + self.stop_spike_gradient = stop_spike_gradient + self.comp_method = comp_method + self.tau_decay = tau_decay + if bm.size(self.tau_decay) != 1: + raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. ' + f'But we got {self.tau_decay}') + + syn = synapses.Alpha(pre.size, pre.keep_size, - A=A, mode=mode, tau_decay=tau_decay, - tau_rise=tau_rise, - method=method, ) - - super().__init__(pre=pre, - post=post, - syn=syn, - conn=conn, - output=output, - stp=stp, - comp_method=comp_method, - g_max=g_max, - delay_step=delay_step, - name=name, - mode=mode) - - self.check_post_attrs('input') - # copy the references - self.g = syn.g - self.h = syn.h - - def update(self, pre_spike=None): - return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) + method=method) + super().__init__(pre=pre, + post=post, + syn=syn, + conn=conn, + comp_method=comp_method, + delay_step=delay_step, + g_max=g_max, + output=output, + stp=stp, + name=name, + mode=mode, ) -DualExponential.__doc__ = DualExponential.__doc__ % (_docs.dual_exp_syn_doc,) - + self.check_post_attrs('input') + # copy the references + self.g = syn.g + self.h = syn.h -class Alpha(_TwoEndConnAlignPre): - r"""Alpha synapse model. 
- - %s - - **Model Examples** - - >>> import brainpy as bp - >>> from brainpy import neurons, synapses, synouts - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = neurons.LIF(1) - >>> neu2 = neurons.LIF(1) - >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) - >>> runner.run(150.) - >>> - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') - >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') - >>> plt.legend() - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') - >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - pre: NeuDyn - The pre-synaptic neuron group. - post: NeuDyn - The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau_decay: float, ArrayType - The time constant of the synaptic decay phase. [ms] - g_max: float, ArrayType, Initializer, Callable - The synaptic strength (the maximum conductance). Default is 1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - output: _SynOut = None, # CUBA(), - stp: Optional[_SynSTP] = None, - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - tau_decay: Union[float, ArrayType] = 10.0, - method: str = 'exp_auto', - - # other parameters - name: str = None, - mode: bm.Mode = None, - stop_spike_gradient: bool = False, - ): - # parameters - self.stop_spike_gradient = stop_spike_gradient - self.comp_method = comp_method - self.tau_decay = tau_decay - if bm.size(self.tau_decay) != 1: - raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. ' - f'But we got {self.tau_decay}') - - syn = synapses.Alpha(pre.size, - pre.keep_size, - mode=mode, - tau_decay=tau_decay, - method=method) - - super().__init__(pre=pre, - post=post, - syn=syn, - conn=conn, - comp_method=comp_method, - delay_step=delay_step, - g_max=g_max, - output=output, - stp=stp, - name=name, - mode=mode, ) - - self.check_post_attrs('input') - # copy the references - self.g = syn.g - self.h = syn.h - - def update(self, pre_spike=None): - return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) + def update(self, pre_spike=None): + return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) Alpha.__doc__ = Alpha.__doc__ % (_docs.alpha_syn_doc,) class NMDA(_TwoEndConnAlignPre): - r"""NMDA synapse model. - - **Model Descriptions** - - The NMDA receptor is a glutamate receptor and ion channel found in neurons. - The NMDA receptor is one of three types of ionotropic glutamate receptors, - the other two being AMPA and kainate receptors. - - The NMDA receptor mediated conductance depends on the postsynaptic voltage. 
- The voltage dependence is due to the blocking of the pore of the NMDA receptor - from the outside by a positively charged magnesium ion. The channel is - nearly completely blocked at resting potential, but the magnesium block is - relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` - that are not blocked by magnesium can be fitted to - - .. math:: - - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} - \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, - usually 1 mM. Thus, the channel acts as a - "coincidence detector" and only once both of these conditions are met, the - channel opens and it allows positively charged ions (cations) to flow through - the cell membrane [2]_. - - If we make the approximation that the magnesium block changes - instantaneously with voltage and is independent of the gating of the channel, - the net NMDA receptor-mediated synaptic current is given by - - .. math:: - - I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} - - where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the - reversal potential. - - Simultaneously, the kinetics of synaptic state :math:`g` is given by - - .. math:: - - & g_\mathrm{NMDA} (t) = g_{max} g \\ - & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\ - & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k}) - - where the decay time of NMDA currents is usually taken to be - :math:`\tau_{decay}` =100 ms, :math:`a= 0.5 ms^{-1}`, and :math:`\tau_{rise}` =2 ms. - - The NMDA receptor has been thought to be very important for controlling - synaptic plasticity and mediating learning and memory functions [3]_. 
- - - **Model Examples** - - - `(Wang, 2002) Decision making spiking model `_ - - - >>> import brainpy as bp - >>> from brainpy import synapses, neurons - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = neurons.HH(1) - >>> neu2 = neurons.HH(1) - >>> syn1 = synapses.NMDA(neu1, neu2, bp.connect.All2All()) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x']) - >>> runner.run(150.) - >>> - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') - >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') - >>> plt.legend() - >>> - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') - >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - pre: NeuDyn - The pre-synaptic neuron group. - post: NeuDyn - The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `dense`. - delay_step: int, ArrayType, Initializer, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Initializer, Callable - The synaptic strength (the maximum conductance). Default is 1. - tau_decay: float, ArrayType - The time constant of the synaptic decay phase. Default 100 [ms] - tau_rise: float, ArrayType - The time constant of the synaptic rise phase. Default 2 [ms] - a: float, ArrayType - Default 0.5 ms^-1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. - - References:: - - .. [1] Brunel N, Wang X J. 
Effects of neuromodulation in a - cortical network model of object working memory dominated - by recurrent inhibition[J]. - Journal of computational neuroscience, 2001, 11(1): 63-85. - .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and - Eric Gouaux. "Subunit arrangement and function in NMDA receptors." - Nature 438, no. 7065 (2005): 185-192. - .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New - England journal of medicine, 361(3), p.302. - .. [4] https://en.wikipedia.org/wiki/NMDA_receptor - - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - output: _SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2), - stp: Optional[_SynSTP] = None, - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 0.15, - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - tau_decay: Union[float, ArrayType] = 100., - a: Union[float, ArrayType] = 0.5, - tau_rise: Union[float, ArrayType] = 2., - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - stop_spike_gradient: bool = False, - ): - # parameters - self.tau_decay = tau_decay - self.tau_rise = tau_rise - self.a = a - if bm.size(a) != 1: - raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}') - if bm.size(tau_decay) != 1: - raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. But we got {tau_decay}') - if bm.size(tau_rise) != 1: - raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. 
But we got {tau_rise}') - self.comp_method = comp_method - self.stop_spike_gradient = stop_spike_gradient - - syn = synapses.NMDA(pre.size, - pre.keep_size, - mode=mode, - a=a, - tau_decay=tau_decay, - tau_rise=tau_rise, - method=method, ) - - super().__init__(pre=pre, - post=post, - syn=syn, - conn=conn, - output=output, - stp=stp, - comp_method=comp_method, - g_max=g_max, - delay_step=delay_step, - name=name, - mode=mode) - - # copy the references - self.g = syn.g - self.x = syn.x - - def update(self, pre_spike=None): - return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) + r"""NMDA synapse model. + + **Model Descriptions** + + The NMDA receptor is a glutamate receptor and ion channel found in neurons. + The NMDA receptor is one of three types of ionotropic glutamate receptors, + the other two being AMPA and kainate receptors. + + The NMDA receptor mediated conductance depends on the postsynaptic voltage. + The voltage dependence is due to the blocking of the pore of the NMDA receptor + from the outside by a positively charged magnesium ion. The channel is + nearly completely blocked at resting potential, but the magnesium block is + relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` + that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} + \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, + usually 1 mM. Thus, the channel acts as a + "coincidence detector" and only once both of these conditions are met, the + channel opens and it allows positively charged ions (cations) to flow through + the cell membrane [2]_. + + If we make the approximation that the magnesium block changes + instantaneously with voltage and is independent of the gating of the channel, + the net NMDA receptor-mediated synaptic current is given by + + .. 
math:: + + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} + + where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the + reversal potential. + + Simultaneously, the kinetics of synaptic state :math:`g` is given by + + .. math:: + + & g_\mathrm{NMDA} (t) = g_{max} g \\ + & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\ + & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k}) + + where the decay time of NMDA currents is usually taken to be + :math:`\tau_{decay}` =100 ms, :math:`a= 0.5 ms^{-1}`, and :math:`\tau_{rise}` =2 ms. + + The NMDA receptor has been thought to be very important for controlling + synaptic plasticity and mediating learning and memory functions [3]_. + + + **Model Examples** + + - `(Wang, 2002) Decision making spiking model `_ + + + >>> import brainpy as bp + >>> from brainpy import synapses, neurons + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.HH(1) + >>> neu2 = neurons.HH(1) + >>> syn1 = synapses.NMDA(neu1, neu2, bp.connect.All2All()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + pre: NeuDyn + The pre-synaptic neuron group. + post: NeuDyn + The post-synaptic neuron group. + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. 
The default is `dense`. + delay_step: int, ArrayType, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + tau_decay: float, ArrayType + The time constant of the synaptic decay phase. Default 100 [ms] + tau_rise: float, ArrayType + The time constant of the synaptic rise phase. Default 2 [ms] + a: float, ArrayType + Default 0.5 ms^-1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References:: + + .. [1] Brunel N, Wang X J. Effects of neuromodulation in a + cortical network model of object working memory dominated + by recurrent inhibition[J]. + Journal of computational neuroscience, 2001, 11(1): 63-85. + .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and + Eric Gouaux. "Subunit arrangement and function in NMDA receptors." + Nature 438, no. 7065 (2005): 185-192. + .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New + England journal of medicine, 361(3), p.302. + .. 
[4] https://en.wikipedia.org/wiki/NMDA_receptor + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + output: _SynOut = MgBlock(E=0., alpha=0.062, beta=3.57, cc_Mg=1.2), + stp: Optional[_SynSTP] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 0.15, + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau_decay: Union[float, ArrayType] = 100., + a: Union[float, ArrayType] = 0.5, + tau_rise: Union[float, ArrayType] = 2., + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + stop_spike_gradient: bool = False, + ): + # parameters + self.tau_decay = tau_decay + self.tau_rise = tau_rise + self.a = a + if bm.size(a) != 1: + raise ValueError(f'"a" must be a scalar or a tensor with size of 1. But we got {a}') + if bm.size(tau_decay) != 1: + raise ValueError(f'"tau_decay" must be a scalar or a tensor with size of 1. But we got {tau_decay}') + if bm.size(tau_rise) != 1: + raise ValueError(f'"tau_rise" must be a scalar or a tensor with size of 1. 
But we got {tau_rise}') + self.comp_method = comp_method + self.stop_spike_gradient = stop_spike_gradient + + syn = synapses.NMDA(pre.size, + pre.keep_size, + mode=mode, + a=a, + tau_decay=tau_decay, + tau_rise=tau_rise, + method=method, ) + + super().__init__(pre=pre, + post=post, + syn=syn, + conn=conn, + output=output, + stp=stp, + comp_method=comp_method, + g_max=g_max, + delay_step=delay_step, + name=name, + mode=mode) + + # copy the references + self.g = syn.g + self.x = syn.x + + def update(self, pre_spike=None): + return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py index 27e5795f1..7cdb35cc5 100644 --- a/brainpy/_src/dynold/synapses/base.py +++ b/brainpy/_src/dynold/synapses/base.py @@ -6,8 +6,8 @@ from brainpy import math as bm from brainpy._src.connect import TwoEndConnector, One2One, All2All from brainpy._src.dnn import linear -from brainpy._src.dyn.projections.conn import SynConn from brainpy._src.dyn.base import NeuDyn +from brainpy._src.dyn.projections.conn import SynConn from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import parameter from brainpy._src.mixin import (ParamDesc, JointType, SupportAutoDelay, BindCondData, ReturnInfo) @@ -15,326 +15,323 @@ from brainpy.types import ArrayType __all__ = [ - '_SynSTP', - '_SynOut', - 'TwoEndConn', - '_TwoEndConnAlignPre', + '_SynSTP', + '_SynOut', + 'TwoEndConn', + '_TwoEndConnAlignPre', ] class _SynapseComponent(DynamicalSystem): - """Base class for modeling synaptic components, - including synaptic output, synaptic short-term plasticity, - synaptic long-term plasticity, and others. """ + """Base class for modeling synaptic components, + including synaptic output, synaptic short-term plasticity, + synaptic long-term plasticity, and others. 
""" - '''Master of this component.''' - master: SynConn + '''Master of this component.''' + master: SynConn - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - self._registered = False + self._registered = False - @property - def isregistered(self) -> bool: - """State of the component, representing whether it has been registered.""" - return self._registered + @property + def isregistered(self) -> bool: + """State of the component, representing whether it has been registered.""" + return self._registered - @isregistered.setter - def isregistered(self, val: bool): - if not isinstance(val, bool): - raise ValueError('Must be an instance of bool.') - self._registered = val + @isregistered.setter + def isregistered(self, val: bool): + if not isinstance(val, bool): + raise ValueError('Must be an instance of bool.') + self._registered = val - def register_master(self, master: SynConn): - if not isinstance(master, SynConn): - raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') - if self.isregistered: - raise ValueError(f'master has been registered, but we got another master going to be registered.') - if hasattr(self, 'master') and self.master != master: - raise ValueError(f'master has been registered, but we got another master going to be registered.') - self.master = master - self._registered = True + def register_master(self, master: SynConn): + if not isinstance(master, SynConn): + raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') + if self.isregistered: + raise ValueError(f'master has been registered, but we got another master going to be registered.') + if hasattr(self, 'master') and self.master != master: + raise ValueError(f'master has been registered, but we got another master going to be registered.') + self.master = master + self._registered = True - def __repr__(self): - return 
self.__class__.__name__ + def __repr__(self): + return self.__class__.__name__ - def __call__(self, *args, **kwargs): - return self.filter(*args, **kwargs) + def __call__(self, *args, **kwargs): + return self.filter(*args, **kwargs) - def clone(self) -> '_SynapseComponent': - """The function useful to clone a new object when it has been used.""" - raise NotImplementedError + def clone(self) -> '_SynapseComponent': + """The function useful to clone a new object when it has been used.""" + raise NotImplementedError - def filter(self, g): - raise NotImplementedError + def filter(self, g): + raise NotImplementedError class _SynOut(_SynapseComponent, ParamDesc): - """Base class for synaptic current output.""" - - def __init__( - self, - name: str = None, - target_var: Union[str, bm.Variable] = None, - ): - super().__init__(name=name) - # check target variable - if target_var is not None: - if not isinstance(target_var, (str, bm.Variable)): - raise TypeError('"target_var" must be instance of string or Variable. ' - f'But we got {type(target_var)}') - self.target_var: Optional[bm.Variable] = target_var - - def register_master(self, master: SynConn): - super().register_master(master) - - # initialize target variable to output - if isinstance(self.target_var, str): - if not hasattr(self.master.post, self.target_var): - raise KeyError(f'Post-synaptic group does not have target variable: {self.target_var}') - self.target_var = getattr(self.master.post, self.target_var) - - def filter(self, g): - if self.target_var is None: - return g - else: - self.target_var += g - - def update(self): - pass + """Base class for synaptic current output.""" + + def __init__( + self, + name: str = None, + target_var: Union[str, bm.Variable] = None, + ): + super().__init__(name=name) + # check target variable + if target_var is not None: + if not isinstance(target_var, (str, bm.Variable)): + raise TypeError('"target_var" must be instance of string or Variable. 
' + f'But we got {type(target_var)}') + self.target_var: Optional[bm.Variable] = target_var + + def register_master(self, master: SynConn): + super().register_master(master) + + # initialize target variable to output + if isinstance(self.target_var, str): + if not hasattr(self.master.post, self.target_var): + raise KeyError(f'Post-synaptic group does not have target variable: {self.target_var}') + self.target_var = getattr(self.master.post, self.target_var) + + def filter(self, g): + if self.target_var is None: + return g + else: + self.target_var += g + + def update(self): + pass class _SynSTP(_SynapseComponent, ParamDesc, SupportAutoDelay): - """Base class for synaptic short-term plasticity.""" + """Base class for synaptic short-term plasticity.""" - def update(self, pre_spike): - pass + def update(self, pre_spike): + pass - def return_info(self): - assert self.isregistered - return ReturnInfo(self.master.pre.varshape, None, self.master.pre.mode, bm.zeros) + def return_info(self): + assert self.isregistered + return ReturnInfo(self.master.pre.varshape, None, self.master.pre.mode, bm.zeros) class _NullSynOut(_SynOut): - def clone(self): - return _NullSynOut() + def clone(self): + return _NullSynOut() class TwoEndConn(SynConn): - """Base class to model synaptic connections. - - Parameters:: - - pre : NeuGroup - Pre-synaptic neuron group. - post : NeuGroup - Post-synaptic neuron group. - conn : optional, ndarray, ArrayType, dict, TwoEndConnector - The connection method between pre- and post-synaptic groups. - output: Optional, SynOutput - The output for the synaptic current. - - .. versionadded:: 2.1.13 - The output component for a two-end connection model. - - stp: Optional, SynSTP - The short-term plasticity model for the synaptic variables. - - .. versionadded:: 2.1.13 - The short-term plasticity component for a two-end connection model. - - ltp: Optional, SynLTP - The long-term plasticity model for the synaptic variables. - - .. 
versionadded:: 2.1.13 - The long-term plasticity component for a two-end connection model. - - name: Optional, str - The name of the dynamic system. - """ - - def __init__( - self, - pre: DynamicalSystem, - post: DynamicalSystem, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None, - output: _SynOut = _NullSynOut(), - stp: Optional[_SynSTP] = None, - ltp: Optional = None, - mode: bm.Mode = None, - name: str = None, - init_stp: bool = True - ): - super().__init__(pre=pre, - post=post, - conn=conn, - name=name, - mode=mode) - - # synaptic output - output = _NullSynOut() if output is None else output - if output.isregistered: - output = output.clone() - if not isinstance(output, _SynOut): - raise TypeError(f'output must be instance of {_SynOut.__name__}, ' - f'but we got {type(output)}') - output.register_master(master=self) - self.output: _SynOut = output - - # short-term synaptic plasticity - if init_stp: - stp = _init_stp(stp, self) - self.stp: Optional[_SynSTP] = stp - - def _init_weights( - self, - weight: Union[float, ArrayType, Callable], - comp_method: str, - sparse_data: str = 'csr' - ) -> Tuple[Union[float, ArrayType], ArrayType]: - if comp_method not in ['sparse', 'dense']: - raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') - if sparse_data not in ['csr', 'ij', 'coo']: - raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}') - if self.conn is None: - raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - - # connections and weights - if isinstance(self.conn, One2One): - weight = parameter(weight, (self.pre.num,), allow_none=False) - conn_mask = None - - elif isinstance(self.conn, All2All): - weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False) - conn_mask = None - - else: - if comp_method == 'sparse': - if sparse_data == 'csr': - conn_mask = self.conn.require('pre2post') - elif sparse_data in ['ij', 'coo']: - 
conn_mask = self.conn.require('post_ids', 'pre_ids') + """Base class to model synaptic connections. + + Parameters:: + + pre : NeuGroup + Pre-synaptic neuron group. + post : NeuGroup + Post-synaptic neuron group. + conn : optional, ndarray, ArrayType, dict, TwoEndConnector + The connection method between pre- and post-synaptic groups. + output: Optional, SynOutput + The output for the synaptic current. + + .. versionadded:: 2.1.13 + The output component for a two-end connection model. + + stp: Optional, SynSTP + The short-term plasticity model for the synaptic variables. + + .. versionadded:: 2.1.13 + The short-term plasticity component for a two-end connection model. + + ltp: Optional, SynLTP + The long-term plasticity model for the synaptic variables. + + .. versionadded:: 2.1.13 + The long-term plasticity component for a two-end connection model. + + name: Optional, str + The name of the dynamic system. + """ + + def __init__( + self, + pre: DynamicalSystem, + post: DynamicalSystem, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]] = None, + output: _SynOut = _NullSynOut(), + stp: Optional[_SynSTP] = None, + ltp: Optional = None, + mode: bm.Mode = None, + name: str = None, + init_stp: bool = True + ): + super().__init__(pre=pre, + post=post, + conn=conn, + name=name, + mode=mode) + + # synaptic output + output = _NullSynOut() if output is None else output + if output.isregistered: + output = output.clone() + if not isinstance(output, _SynOut): + raise TypeError(f'output must be instance of {_SynOut.__name__}, ' + f'but we got {type(output)}') + output.register_master(master=self) + self.output: _SynOut = output + + # short-term synaptic plasticity + if init_stp: + stp = _init_stp(stp, self) + self.stp: Optional[_SynSTP] = stp + + def _init_weights( + self, + weight: Union[float, ArrayType, Callable], + comp_method: str, + sparse_data: str = 'csr' + ) -> Tuple[Union[float, ArrayType], ArrayType]: + if comp_method not in ['sparse', 'dense']: + raise 
ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') + if sparse_data not in ['csr', 'ij', 'coo']: + raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}') + if self.conn is None: + raise ValueError(f'Must provide "conn" when initialize the model {self.name}') + + # connections and weights + if isinstance(self.conn, One2One): + weight = parameter(weight, (self.pre.num,), allow_none=False) + conn_mask = None + + elif isinstance(self.conn, All2All): + weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False) + conn_mask = None + + else: + if comp_method == 'sparse': + if sparse_data == 'csr': + conn_mask = self.conn.require('pre2post') + elif sparse_data in ['ij', 'coo']: + conn_mask = self.conn.require('post_ids', 'pre_ids') + else: + ValueError(f'Unknown sparse data type: {sparse_data}') + weight = parameter(weight, conn_mask[0].shape, allow_none=False) + elif comp_method == 'dense': + weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False) + conn_mask = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {comp_method}') + + # training weights + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + return weight, conn_mask + + def _syn2post_with_all2all(self, syn_value, syn_weight): + if bm.ndim(syn_weight) == 0: + if isinstance(self.mode, bm.BatchingMode): + post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) + else: + post_vs = bm.sum(syn_value) + if not self.conn.include_self: + post_vs = post_vs - syn_value + post_vs = syn_weight * post_vs else: - ValueError(f'Unknown sparse data type: {sparse_data}') - weight = parameter(weight, conn_mask[0].shape, allow_none=False) - elif comp_method == 'dense': - weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False) - conn_mask = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: 
{comp_method}') - - # training weights - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - return weight, conn_mask - - def _syn2post_with_all2all(self, syn_value, syn_weight): - if bm.ndim(syn_weight) == 0: - if isinstance(self.mode, bm.BatchingMode): - post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) - else: - post_vs = bm.sum(syn_value) - if not self.conn.include_self: - post_vs = post_vs - syn_value - post_vs = syn_weight * post_vs - else: - post_vs = syn_value @ syn_weight - return post_vs - - def _syn2post_with_one2one(self, syn_value, syn_weight): - return syn_value * syn_weight - - def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): - if bm.ndim(syn_weight) == 0: - post_vs = (syn_weight * syn_value) @ conn_mat - else: - post_vs = syn_value @ (syn_weight * conn_mat) - return post_vs + post_vs = syn_value @ syn_weight + return post_vs + def _syn2post_with_one2one(self, syn_value, syn_weight): + return syn_value * syn_weight -def _init_stp(stp, master): - if stp is not None: - if stp.isregistered: - stp = stp.clone() - if not isinstance(stp, _SynSTP): - raise TypeError(f'Short-term plasticity must be instance of {_SynSTP.__name__}, ' - f'but we got {type(stp)}') - stp.register_master(master=master) - return stp - + def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): + if bm.ndim(syn_weight) == 0: + post_vs = (syn_weight * syn_value) @ conn_mat + else: + post_vs = syn_value @ (syn_weight * conn_mat) + return post_vs -class _TwoEndConnAlignPre(TwoEndConn): - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - syn: DynamicalSystem, - conn: TwoEndConnector, - g_max: Union[float, ArrayType, Callable], - output: JointType[DynamicalSystem, BindCondData] = _NullSynOut(), - stp: Optional[_SynSTP] = None, - comp_method: str = 'dense', - delay_step: Union[int, ArrayType, Callable] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - assert isinstance(pre, 
NeuDyn) - assert isinstance(post, NeuDyn) - assert isinstance(syn, DynamicalSystem) - - super().__init__(pre=pre, - post=post, - conn=conn, - output=output, - stp=stp, - name=name, - mode=mode) - - # delay - self.pre.register_local_delay("spike", self.name, delay_step=delay_step) - - # synaptic dynamics - self.syn = syn - - # synaptic communications - if isinstance(conn, All2All): - self.comm = linear.AllToAll(pre.num, post.num, g_max, include_self=conn.include_self) - elif isinstance(conn, One2One): - assert post.num == pre.num - self.comm = linear.OneToOne(pre.num, g_max) - else: - if comp_method == 'dense': - self.comm = linear.MaskedLinear(conn, g_max) - elif comp_method == 'sparse': - self.comm = linear.CSRLinear(conn, g_max) - else: - raise UnsupportedError(f'Does not support {comp_method}, only "sparse" or "dense".') - - def update(self, pre_spike=None, stop_spike_gradient: bool = False): - if pre_spike is None: - pre_spike = self.pre.get_local_delay("spike", self.name) - if stop_spike_gradient: - pre_spike = jax.lax.stop_gradient(pre_spike) - if self.stp is not None: - self.stp.update(pre_spike) - pre_spike = self.stp(pre_spike) - current = self.comm(self.syn(pre_spike)) - return self.output(current) - - @property - def g_max(self): - warnings.warn('".g_max" is deprecated. ' - 'Use ".comm.weight" instead.', - UserWarning) - return self.comm.weight - - @g_max.setter - def g_max(self, v): - warnings.warn('Updating ".g_max" is deprecated. 
' - 'Updating ".comm.weight" instead.', - UserWarning) - self.comm.weight = v +def _init_stp(stp, master): + if stp is not None: + if stp.isregistered: + stp = stp.clone() + if not isinstance(stp, _SynSTP): + raise TypeError(f'Short-term plasticity must be instance of {_SynSTP.__name__}, ' + f'but we got {type(stp)}') + stp.register_master(master=master) + return stp +class _TwoEndConnAlignPre(TwoEndConn): + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + syn: DynamicalSystem, + conn: TwoEndConnector, + g_max: Union[float, ArrayType, Callable], + output: JointType[DynamicalSystem, BindCondData] = _NullSynOut(), + stp: Optional[_SynSTP] = None, + comp_method: str = 'dense', + delay_step: Union[int, ArrayType, Callable] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + assert isinstance(pre, NeuDyn) + assert isinstance(post, NeuDyn) + assert isinstance(syn, DynamicalSystem) + + super().__init__(pre=pre, + post=post, + conn=conn, + output=output, + stp=stp, + name=name, + mode=mode) + + # delay + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) + + # synaptic dynamics + self.syn = syn + + # synaptic communications + if isinstance(conn, All2All): + self.comm = linear.AllToAll(pre.num, post.num, g_max, include_self=conn.include_self) + elif isinstance(conn, One2One): + assert post.num == pre.num + self.comm = linear.OneToOne(pre.num, g_max) + else: + if comp_method == 'dense': + self.comm = linear.MaskedLinear(conn, g_max) + elif comp_method == 'sparse': + self.comm = linear.CSRLinear(conn, g_max) + else: + raise UnsupportedError(f'Does not support {comp_method}, only "sparse" or "dense".') + + def update(self, pre_spike=None, stop_spike_gradient: bool = False): + if pre_spike is None: + pre_spike = self.pre.get_local_delay("spike", self.name) + if stop_spike_gradient: + pre_spike = jax.lax.stop_gradient(pre_spike) + if self.stp is not None: + self.stp.update(pre_spike) + pre_spike = self.stp(pre_spike) + 
current = self.comm(self.syn(pre_spike)) + return self.output(current) + + @property + def g_max(self): + warnings.warn('".g_max" is deprecated. ' + 'Use ".comm.weight" instead.', + UserWarning) + return self.comm.weight + + @g_max.setter + def g_max(self, v): + warnings.warn('Updating ".g_max" is deprecated. ' + 'Updating ".comm.weight" instead.', + UserWarning) + self.comm.weight = v diff --git a/brainpy/_src/dynold/synapses/biological_models.py b/brainpy/_src/dynold/synapses/biological_models.py index 4d3ec02b0..21a09fd32 100644 --- a/brainpy/_src/dynold/synapses/biological_models.py +++ b/brainpy/_src/dynold/synapses/biological_models.py @@ -5,365 +5,365 @@ import brainpy.math as bm from brainpy._src.connect import TwoEndConnector from brainpy._src.dyn import synapses +from brainpy._src.dyn.base import NeuDyn from brainpy._src.dynold.synapses import _SynSTP, _SynOut, _TwoEndConnAlignPre from brainpy._src.dynold.synouts import COBA, MgBlock -from brainpy._src.dyn.base import NeuDyn from brainpy.types import ArrayType __all__ = [ - 'AMPA', - 'GABAa', - 'BioNMDA', + 'AMPA', + 'GABAa', + 'BioNMDA', ] class AMPA(_TwoEndConnAlignPre): - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - output: _SynOut = COBA(E=0.), - stp: Optional[_SynSTP] = None, - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Callable] = 0.42, - delay_step: Union[int, ArrayType, Callable] = None, - alpha: float = 0.98, - beta: float = 0.18, - T: float = 0.5, - T_duration: float = 0.5, - method: str = 'exp_auto', - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - stop_spike_gradient: bool = False, - ): - # parameters - self.stop_spike_gradient = stop_spike_gradient - self.comp_method = comp_method - self.alpha = alpha - self.beta = beta - self.T = T - self.T_duration = T_duration - if bm.size(alpha) != 1: - raise ValueError(f'"alpha" must be a scalar or a tensor with size of 1. 
But we got {alpha}') - if bm.size(beta) != 1: - raise ValueError(f'"beta" must be a scalar or a tensor with size of 1. But we got {beta}') - if bm.size(T) != 1: - raise ValueError(f'"T" must be a scalar or a tensor with size of 1. But we got {T}') - if bm.size(T_duration) != 1: - raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}') - - # AMPA - syn = synapses.AMPA(pre.size, pre.keep_size, mode=mode, alpha=alpha, beta=beta, - T=T, T_dur=T_duration, method=method) - - super().__init__(pre=pre, - post=post, - syn=syn, - conn=conn, - output=output, - stp=stp, - comp_method=comp_method, - g_max=g_max, - delay_step=delay_step, - name=name, - mode=mode) - - # copy the references - self.g = syn.g - self.spike_arrival_time = syn.spike_arrival_time - - def update(self, pre_spike=None): - return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + output: _SynOut = COBA(E=0.), + stp: Optional[_SynSTP] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Callable] = 0.42, + delay_step: Union[int, ArrayType, Callable] = None, + alpha: float = 0.98, + beta: float = 0.18, + T: float = 0.5, + T_duration: float = 0.5, + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + stop_spike_gradient: bool = False, + ): + # parameters + self.stop_spike_gradient = stop_spike_gradient + self.comp_method = comp_method + self.alpha = alpha + self.beta = beta + self.T = T + self.T_duration = T_duration + if bm.size(alpha) != 1: + raise ValueError(f'"alpha" must be a scalar or a tensor with size of 1. But we got {alpha}') + if bm.size(beta) != 1: + raise ValueError(f'"beta" must be a scalar or a tensor with size of 1. But we got {beta}') + if bm.size(T) != 1: + raise ValueError(f'"T" must be a scalar or a tensor with size of 1. 
But we got {T}') + if bm.size(T_duration) != 1: + raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}') + + # AMPA + syn = synapses.AMPA(pre.size, pre.keep_size, mode=mode, alpha=alpha, beta=beta, + T=T, T_dur=T_duration, method=method) + + super().__init__(pre=pre, + post=post, + syn=syn, + conn=conn, + output=output, + stp=stp, + comp_method=comp_method, + g_max=g_max, + delay_step=delay_step, + name=name, + mode=mode) + + # copy the references + self.g = syn.g + self.spike_arrival_time = syn.spike_arrival_time + + def update(self, pre_spike=None): + return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) class GABAa(AMPA): - r"""GABAa synapse model. - - **Model Descriptions** - - GABAa synapse model has the same equation with the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_, - - .. math:: - - \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\ - I_{syn}&= - g_{max} g (V - E) - - but with the difference of: - - - Reversal potential of synapse :math:`E` is usually low, typically -80. mV - - Activating rate constant :math:`\alpha=0.53` - - De-activating rate constant :math:`\beta=0.18` - - Transmitter concentration :math:`[T]=1\,\mu ho(\mu S)` when synapse is - triggered by a pre-synaptic spike, with the duration of 1. ms. - - **Model Examples** - - - `Gamma oscillation network model `_ - - - Parameters:: - - pre: NeuDyn - The pre-synaptic neuron group. - post: NeuDyn - The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `dense`. - delay_step: int, ArrayType, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Callable - The synaptic strength (the maximum conductance). Default is 1. - alpha: float, ArrayType - Binding constant. 
Default 0.062 - beta: float, ArrayType - Unbinding constant. Default 3.57 - T: float, ArrayType - Transmitter concentration when synapse is triggered by - a pre-synaptic spike.. Default 1 [mM]. - T_duration: float, ArrayType - Transmitter concentration duration time after being triggered. Default 1 [ms] - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. - - References:: - - .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity - on the integrative properties of neocortical pyramidal neurons - in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547. - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - output: _SynOut = COBA(E=-80.), - stp: Optional[_SynSTP] = None, - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Callable] = 0.04, - delay_step: Union[int, ArrayType, Callable] = None, - alpha: Union[float, ArrayType] = 0.53, - beta: Union[float, ArrayType] = 0.18, - T: Union[float, ArrayType] = 1., - T_duration: Union[float, ArrayType] = 1., - method: str = 'exp_auto', - - # other parameters - name: str = None, - mode: bm.Mode = None, - stop_spike_gradient: bool = False, - ): - super().__init__(pre=pre, - post=post, - conn=conn, - output=output, - stp=stp, - comp_method=comp_method, - delay_step=delay_step, - g_max=g_max, - alpha=alpha, - beta=beta, - T=T, - T_duration=T_duration, - method=method, - name=name, - mode=mode, - stop_spike_gradient=stop_spike_gradient, ) + r"""GABAa synapse model. + + **Model Descriptions** + + GABAa synapse model has the same equation with the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_, + + .. math:: + + \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\ + I_{syn}&= - g_{max} g (V - E) + + but with the difference of: + + - Reversal potential of synapse :math:`E` is usually low, typically -80. 
mV + - Activating rate constant :math:`\alpha=0.53` + - De-activating rate constant :math:`\beta=0.18` + - Transmitter concentration :math:`[T]=1\,\mu ho(\mu S)` when synapse is + triggered by a pre-synaptic spike, with the duration of 1. ms. + + **Model Examples** + + - `Gamma oscillation network model `_ + + + Parameters:: + + pre: NeuDyn + The pre-synaptic neuron group. + post: NeuDyn + The post-synaptic neuron group. + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `dense`. + delay_step: int, ArrayType, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + g_max: float, ArrayType, Callable + The synaptic strength (the maximum conductance). Default is 1. + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType + Unbinding constant. Default 3.57 + T: float, ArrayType + Transmitter concentration when synapse is triggered by + a pre-synaptic spike.. Default 1 [mM]. + T_duration: float, ArrayType + Transmitter concentration duration time after being triggered. Default 1 [ms] + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References:: + + .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity + on the integrative properties of neocortical pyramidal neurons + in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547. 
+ """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + output: _SynOut = COBA(E=-80.), + stp: Optional[_SynSTP] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Callable] = 0.04, + delay_step: Union[int, ArrayType, Callable] = None, + alpha: Union[float, ArrayType] = 0.53, + beta: Union[float, ArrayType] = 0.18, + T: Union[float, ArrayType] = 1., + T_duration: Union[float, ArrayType] = 1., + method: str = 'exp_auto', + + # other parameters + name: str = None, + mode: bm.Mode = None, + stop_spike_gradient: bool = False, + ): + super().__init__(pre=pre, + post=post, + conn=conn, + output=output, + stp=stp, + comp_method=comp_method, + delay_step=delay_step, + g_max=g_max, + alpha=alpha, + beta=beta, + T=T, + T_duration=T_duration, + method=method, + name=name, + mode=mode, + stop_spike_gradient=stop_spike_gradient, ) class BioNMDA(_TwoEndConnAlignPre): - r"""Biological NMDA synapse model. - - **Model Descriptions** - - The NMDA receptor is a glutamate receptor and ion channel found in neurons. - The NMDA receptor is one of three types of ionotropic glutamate receptors, - the other two being AMPA and kainate receptors. - - The NMDA receptor mediated conductance depends on the postsynaptic voltage. - The voltage dependence is due to the blocking of the pore of the NMDA receptor - from the outside by a positively charged magnesium ion. The channel is - nearly completely blocked at resting potential, but the magnesium block is - relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` - that are not blocked by magnesium can be fitted to - - .. math:: - - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V} - \frac{[{Mg}^{2+}]_{o}} {b})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, - usually 1 mM. 
Thus, the channel acts as a - "coincidence detector" and only once both of these conditions are met, the - channel opens and it allows positively charged ions (cations) to flow through - the cell membrane [2]_. - - If we make the approximation that the magnesium block changes - instantaneously with voltage and is independent of the gating of the channel, - the net NMDA receptor-mediated synaptic current is given by - - .. math:: - - I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} - - where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the - reversal potential. - - Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_: - - .. math:: - - & g_\mathrm{NMDA} (t) = g_{max} g \\ - & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\ - & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x - - where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and - :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x. - - The NMDA receptor has been thought to be very important for controlling - synaptic plasticity and mediating learning and memory functions [3]_. - - >>> import brainpy as bp - >>> from brainpy import neurons, synapses - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = neurons.HH(1) - >>> neu2 = neurons.HH(1) - >>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All()) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x']) - >>> runner.run(150.) 
- >>> - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') - >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') - >>> plt.legend() - >>> - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') - >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - pre: NeuDyn - The pre-synaptic neuron group. - post: NeuDyn - The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector - The synaptic connections. - comp_method: str - The connection type used for model speed optimization. It can be - `sparse` and `dense`. The default is `dense`. - delay_step: int, ArrayType, Callable - The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Callable - The synaptic strength (the maximum conductance). Default is 1. - alpha1: float, ArrayType - The conversion rate of g from inactive to active. Default 2 ms^-1. - beta1: float, ArrayType - The conversion rate of g from active to inactive. Default 0.01 ms^-1. - alpha2: float, ArrayType - The conversion rate of x from inactive to active. Default 1 ms^-1. - beta2: float, ArrayType - The conversion rate of x from active to inactive. Default 0.5 ms^-1. - name: str - The name of this synaptic projection. - method: str - The numerical integration methods. - - References:: - - .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M]. - Springer New York, 2010: 162. - .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and - Eric Gouaux. "Subunit arrangement and function in NMDA receptors." - Nature 438, no. 7065 (2005): 185-192. - .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New - England journal of medicine, 361(3), p.302. - .. 
[4] https://en.wikipedia.org/wiki/NMDA_receptor - - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - output: _SynOut = MgBlock(E=0.), - stp: Optional[_SynSTP] = None, - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Callable] = 0.15, - delay_step: Union[int, ArrayType, Callable] = None, - alpha1: Union[float, ArrayType] = 2., - beta1: Union[float, ArrayType] = 0.01, - alpha2: Union[float, ArrayType] = 1., - beta2: Union[float, ArrayType] = 0.5, - T_0: Union[float, ArrayType] = 1., - T_dur: Union[float, ArrayType] = 0.5, - method: str = 'exp_auto', - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - stop_spike_gradient: bool = False, - ): - - # parameters - self.beta1 = beta1 - self.beta2 = beta2 - self.alpha1 = alpha1 - self.alpha2 = alpha2 - self.T_0 = T_0 - self.T_dur = T_dur - if bm.size(alpha1) != 1: - raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}') - if bm.size(beta1) != 1: - raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}') - if bm.size(alpha2) != 1: - raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}') - if bm.size(beta2) != 1: - raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}') - if bm.size(T_0) != 1: - raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}') - if bm.size(T_dur) != 1: - raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. 
But we got {T_dur}') - self.comp_method = comp_method - self.stop_spike_gradient = stop_spike_gradient - - syn = synapses.BioNMDA(pre.size, - pre.keep_size, - mode=mode, - alpha1=alpha1, - beta1=beta1, - alpha2=alpha2, - beta2=beta2, - T=T_0, - T_dur=T_dur, - method=method, ) - super().__init__(pre=pre, - post=post, - syn=syn, - conn=conn, - output=output, - stp=stp, - comp_method=comp_method, - g_max=g_max, - delay_step=delay_step, - name=name, - mode=mode) - - # copy the references - self.g = syn.g - self.x = syn.x - self.spike_arrival_time = syn.spike_arrival_time - - def update(self, pre_spike=None): - return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) + r"""Biological NMDA synapse model. + + **Model Descriptions** + + The NMDA receptor is a glutamate receptor and ion channel found in neurons. + The NMDA receptor is one of three types of ionotropic glutamate receptors, + the other two being AMPA and kainate receptors. + + The NMDA receptor mediated conductance depends on the postsynaptic voltage. + The voltage dependence is due to the blocking of the pore of the NMDA receptor + from the outside by a positively charged magnesium ion. The channel is + nearly completely blocked at resting potential, but the magnesium block is + relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` + that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V} + \frac{[{Mg}^{2+}]_{o}} {b})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, + usually 1 mM. Thus, the channel acts as a + "coincidence detector" and only once both of these conditions are met, the + channel opens and it allows positively charged ions (cations) to flow through + the cell membrane [2]_. 
+ + If we make the approximation that the magnesium block changes + instantaneously with voltage and is independent of the gating of the channel, + the net NMDA receptor-mediated synaptic current is given by + + .. math:: + + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} + + where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the + reversal potential. + + Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_: + + .. math:: + + & g_\mathrm{NMDA} (t) = g_{max} g \\ + & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\ + & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x + + where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and + :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x. + + The NMDA receptor has been thought to be very important for controlling + synaptic plasticity and mediating learning and memory functions [3]_. + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.HH(1) + >>> neu2 = neurons.HH(1) + >>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + pre: NeuDyn + The pre-synaptic neuron group. + post: NeuDyn + The post-synaptic neuron group. 
+ conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` or `dense`. The default is `dense`. + delay_step: int, ArrayType, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + g_max: float, ArrayType, Callable + The synaptic strength (the maximum conductance). Default is 0.15. + alpha1: float, ArrayType + The conversion rate of g from inactive to active. Default 2 ms^-1. + beta1: float, ArrayType + The conversion rate of g from active to inactive. Default 0.01 ms^-1. + alpha2: float, ArrayType + The conversion rate of x from inactive to active. Default 1 ms^-1. + beta2: float, ArrayType + The conversion rate of x from active to inactive. Default 0.5 ms^-1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References:: + + .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M]. + Springer New York, 2010: 162. + .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and + Eric Gouaux. "Subunit arrangement and function in NMDA receptors." + Nature 438, no. 7065 (2005): 185-192. + .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New + England journal of medicine, 361(3), p.302. + .. 
[4] https://en.wikipedia.org/wiki/NMDA_receptor + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + output: _SynOut = MgBlock(E=0.), + stp: Optional[_SynSTP] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Callable] = 0.15, + delay_step: Union[int, ArrayType, Callable] = None, + alpha1: Union[float, ArrayType] = 2., + beta1: Union[float, ArrayType] = 0.01, + alpha2: Union[float, ArrayType] = 1., + beta2: Union[float, ArrayType] = 0.5, + T_0: Union[float, ArrayType] = 1., + T_dur: Union[float, ArrayType] = 0.5, + method: str = 'exp_auto', + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + stop_spike_gradient: bool = False, + ): + + # parameters + self.beta1 = beta1 + self.beta2 = beta2 + self.alpha1 = alpha1 + self.alpha2 = alpha2 + self.T_0 = T_0 + self.T_dur = T_dur + if bm.size(alpha1) != 1: + raise ValueError(f'"alpha1" must be a scalar or a tensor with size of 1. But we got {alpha1}') + if bm.size(beta1) != 1: + raise ValueError(f'"beta1" must be a scalar or a tensor with size of 1. But we got {beta1}') + if bm.size(alpha2) != 1: + raise ValueError(f'"alpha2" must be a scalar or a tensor with size of 1. But we got {alpha2}') + if bm.size(beta2) != 1: + raise ValueError(f'"beta2" must be a scalar or a tensor with size of 1. But we got {beta2}') + if bm.size(T_0) != 1: + raise ValueError(f'"T_0" must be a scalar or a tensor with size of 1. But we got {T_0}') + if bm.size(T_dur) != 1: + raise ValueError(f'"T_dur" must be a scalar or a tensor with size of 1. 
But we got {T_dur}') + self.comp_method = comp_method + self.stop_spike_gradient = stop_spike_gradient + + syn = synapses.BioNMDA(pre.size, + pre.keep_size, + mode=mode, + alpha1=alpha1, + beta1=beta1, + alpha2=alpha2, + beta2=beta2, + T=T_0, + T_dur=T_dur, + method=method, ) + super().__init__(pre=pre, + post=post, + syn=syn, + conn=conn, + output=output, + stp=stp, + comp_method=comp_method, + g_max=g_max, + delay_step=delay_step, + name=name, + mode=mode) + + # copy the references + self.g = syn.g + self.x = syn.x + self.spike_arrival_time = syn.spike_arrival_time + + def update(self, pre_spike=None): + return super().update(pre_spike, stop_spike_gradient=self.stop_spike_gradient) diff --git a/brainpy/_src/dynold/synapses/compat.py b/brainpy/_src/dynold/synapses/compat.py index 108f01ad5..f05f51bcb 100644 --- a/brainpy/_src/dynold/synapses/compat.py +++ b/brainpy/_src/dynold/synapses/compat.py @@ -4,254 +4,254 @@ from typing import Union, Dict, Callable from brainpy._src.connect import TwoEndConnector -from brainpy._src.dynold.synouts import COBA, CUBA from brainpy._src.dyn.base import NeuDyn +from brainpy._src.dynold.synouts import COBA, CUBA from brainpy._src.initialize import Initializer from brainpy.types import ArrayType from .abstract_models import Delta, Exponential, DualExponential __all__ = [ - 'DeltaSynapse', - 'ExpCUBA', - 'ExpCOBA', - 'DualExpCUBA', - 'DualExpCOBA', - 'AlphaCUBA', - 'AlphaCOBA', + 'DeltaSynapse', + 'ExpCUBA', + 'ExpCOBA', + 'DualExpCUBA', + 'DualExpCOBA', + 'AlphaCUBA', + 'AlphaCOBA', ] class DeltaSynapse(Delta): - """Delta synapse. - - .. deprecated:: 2.1.13 - Please use "brainpy.synapses.Delta" instead. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - conn_type: str = 'sparse', - weights: Union[float, ArrayType, Initializer, Callable] = 1., - delay_step: Union[float, ArrayType, Initializer, Callable] = None, - post_input_key: str = 'V', - post_has_ref: bool = False, - name: str = None, - ): - warnings.warn('Please use "brainpy.synapses.Delta" instead.', DeprecationWarning) - super().__init__(pre=pre, - post=post, - conn=conn, - output=CUBA(post_input_key), - name=name, - comp_method=conn_type, - g_max=weights, - delay_step=delay_step, - post_ref_key='refractory' if post_has_ref else None) + """Delta synapse. + + .. deprecated:: 2.1.13 + Please use "brainpy.synapses.Delta" instead. + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + conn_type: str = 'sparse', + weights: Union[float, ArrayType, Initializer, Callable] = 1., + delay_step: Union[float, ArrayType, Initializer, Callable] = None, + post_input_key: str = 'V', + post_has_ref: bool = False, + name: str = None, + ): + warnings.warn('Please use "brainpy.synapses.Delta" instead.', DeprecationWarning) + super().__init__(pre=pre, + post=post, + conn=conn, + output=CUBA(post_input_key), + name=name, + comp_method=conn_type, + g_max=weights, + delay_step=delay_step, + post_ref_key='refractory' if post_has_ref else None) class ExpCUBA(Exponential): - r"""Current-based exponential decay synapse model. - - .. deprecated:: 2.1.13 - Please use "brainpy.synapses.Exponential" instead. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - conn_type: str = 'sparse', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - tau: Union[float, ArrayType] = 8.0, - name: str = None, - method: str = 'exp_auto', - ): - super().__init__(pre=pre, - post=post, - conn=conn, - name=name, - comp_method=conn_type, - g_max=g_max, - delay_step=delay_step, - tau=tau, - method=method, - output=CUBA()) + r"""Current-based exponential decay synapse model. + + .. deprecated:: 2.1.13 + Please use "brainpy.synapses.Exponential" instead. + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + conn_type: str = 'sparse', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau: Union[float, ArrayType] = 8.0, + name: str = None, + method: str = 'exp_auto', + ): + super().__init__(pre=pre, + post=post, + conn=conn, + name=name, + comp_method=conn_type, + g_max=g_max, + delay_step=delay_step, + tau=tau, + method=method, + output=CUBA()) class ExpCOBA(Exponential): - """Conductance-based exponential decay synapse model. - - .. deprecated:: 2.1.13 - Please use "brainpy.synapses.Exponential" instead. 
- """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - # connection - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - conn_type: str = 'sparse', - # connection strength - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - # synapse parameter - tau: Union[float, ArrayType] = 8.0, - E: Union[float, ArrayType] = 0., - # synapse delay - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - # others - method: str = 'exp_auto', - name: str = None - ): - super().__init__(pre=pre, - post=post, - conn=conn, - comp_method=conn_type, - g_max=g_max, - delay_step=delay_step, - tau=tau, - method=method, - name=name, - output=COBA(E=E)) + """Conductance-based exponential decay synapse model. + + .. deprecated:: 2.1.13 + Please use "brainpy.synapses.Exponential" instead. + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + # connection + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + conn_type: str = 'sparse', + # connection strength + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + # synapse parameter + tau: Union[float, ArrayType] = 8.0, + E: Union[float, ArrayType] = 0., + # synapse delay + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + # others + method: str = 'exp_auto', + name: str = None + ): + super().__init__(pre=pre, + post=post, + conn=conn, + comp_method=conn_type, + g_max=g_max, + delay_step=delay_step, + tau=tau, + method=method, + name=name, + output=COBA(E=E)) class DualExpCUBA(DualExponential): - r"""Current-based dual exponential synapse model. - - .. deprecated:: 2.1.13 - Please use "brainpy.synapses.DualExponential" instead. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - conn_type: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - tau_decay: Union[float, ArrayType] = 10.0, - tau_rise: Union[float, ArrayType] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - method: str = 'exp_auto', - name: str = None - ): - super().__init__(pre=pre, - post=post, - conn=conn, - comp_method=conn_type, - g_max=g_max, - tau_decay=tau_decay, - tau_rise=tau_rise, - delay_step=delay_step, - method=method, - name=name, - output=CUBA()) + r"""Current-based dual exponential synapse model. + + .. deprecated:: 2.1.13 + Please use "brainpy.synapses.DualExponential" instead. + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + conn_type: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau_decay: Union[float, ArrayType] = 10.0, + tau_rise: Union[float, ArrayType] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + method: str = 'exp_auto', + name: str = None + ): + super().__init__(pre=pre, + post=post, + conn=conn, + comp_method=conn_type, + g_max=g_max, + tau_decay=tau_decay, + tau_rise=tau_rise, + delay_step=delay_step, + method=method, + name=name, + output=CUBA()) class DualExpCOBA(DualExponential): - """Conductance-based dual exponential synapse model. - - - .. deprecated:: 2.1.13 - Please use "brainpy.synapses.DualExponential" instead. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - conn_type: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - tau_decay: Union[float, ArrayType] = 10.0, - tau_rise: Union[float, ArrayType] = 1., - E: Union[float, ArrayType] = 0., - method: str = 'exp_auto', - name: str = None - ): - super().__init__(pre=pre, - post=post, - conn=conn, - comp_method=conn_type, - g_max=g_max, - tau_decay=tau_decay, - tau_rise=tau_rise, - delay_step=delay_step, - method=method, - name=name, - output=COBA(E=E)) + """Conductance-based dual exponential synapse model. + + + .. deprecated:: 2.1.13 + Please use "brainpy.synapses.DualExponential" instead. + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + conn_type: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau_decay: Union[float, ArrayType] = 10.0, + tau_rise: Union[float, ArrayType] = 1., + E: Union[float, ArrayType] = 0., + method: str = 'exp_auto', + name: str = None + ): + super().__init__(pre=pre, + post=post, + conn=conn, + comp_method=conn_type, + g_max=g_max, + tau_decay=tau_decay, + tau_rise=tau_rise, + delay_step=delay_step, + method=method, + name=name, + output=COBA(E=E)) class AlphaCUBA(DualExpCUBA): - r"""Current-based alpha synapse model. - - .. deprecated:: 2.1.13 - Please use "brainpy.synapses.Alpha" instead. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - conn_type: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - tau_decay: Union[float, ArrayType] = 10.0, - method: str = 'exp_auto', - name: str = None - ): - super().__init__(pre=pre, - post=post, - conn=conn, - conn_type=conn_type, - delay_step=delay_step, - g_max=g_max, - tau_decay=tau_decay, - tau_rise=tau_decay, - method=method, - name=name) + r"""Current-based alpha synapse model. + + .. deprecated:: 2.1.13 + Please use "brainpy.synapses.Alpha" instead. + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + conn_type: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau_decay: Union[float, ArrayType] = 10.0, + method: str = 'exp_auto', + name: str = None + ): + super().__init__(pre=pre, + post=post, + conn=conn, + conn_type=conn_type, + delay_step=delay_step, + g_max=g_max, + tau_decay=tau_decay, + tau_rise=tau_decay, + method=method, + name=name) class AlphaCOBA(DualExpCOBA): - """Conductance-based alpha synapse model. - - .. deprecated:: 2.1.13 - Please use "brainpy.synapses.Alpha" instead. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - conn_type: str = 'dense', - g_max: Union[float, ArrayType, Callable, Initializer] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - tau_decay: Union[float, ArrayType] = 10.0, - E: Union[float, ArrayType] = 0., - method: str = 'exp_auto', - name: str = None - ): - super().__init__(pre=pre, - post=post, - conn=conn, - conn_type=conn_type, - delay_step=delay_step, - g_max=g_max, E=E, - tau_decay=tau_decay, - tau_rise=tau_decay, - method=method, - name=name) + """Conductance-based alpha synapse model. + + .. deprecated:: 2.1.13 + Please use "brainpy.synapses.Alpha" instead. + + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + conn_type: str = 'dense', + g_max: Union[float, ArrayType, Callable, Initializer] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau_decay: Union[float, ArrayType] = 10.0, + E: Union[float, ArrayType] = 0., + method: str = 'exp_auto', + name: str = None + ): + super().__init__(pre=pre, + post=post, + conn=conn, + conn_type=conn_type, + delay_step=delay_step, + g_max=g_max, E=E, + tau_decay=tau_decay, + tau_rise=tau_decay, + method=method, + name=name) diff --git a/brainpy/_src/dynold/synapses/gap_junction.py b/brainpy/_src/dynold/synapses/gap_junction.py index ffeb44353..73a5495cc 100644 --- a/brainpy/_src/dynold/synapses/gap_junction.py +++ b/brainpy/_src/dynold/synapses/gap_junction.py @@ -3,60 +3,60 @@ from typing import Union, Dict, Callable import brainpy.math as bm -from brainpy._src.dyn.base import NeuDyn from brainpy._src.connect import TwoEndConnector +from brainpy._src.dyn.base import NeuDyn from brainpy._src.dynold.synapses import TwoEndConn from brainpy._src.initialize import Initializer, parameter from brainpy.types import ArrayType __all__ = [ - 'GapJunction', + 
'GapJunction', ] class GapJunction(TwoEndConn): - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - comp_method: str = 'dense', - g_max: Union[float, ArrayType, Initializer, Callable] = 1., - name: str = None, - ): - super(GapJunction, self).__init__(pre=pre, - post=post, - conn=conn, - name=name) - # checking - self.check_pre_attrs('V') - self.check_post_attrs('V', 'input') - - # assert isinstance(self.output, _NullSynOut) - # assert isinstance(self.stp, _NullSynSTP) - - # connections - self.comp_method = comp_method - if comp_method == 'dense': - self.conn_mat = self.conn.require('conn_mat') - self.weights = parameter(g_max, (pre.num, post.num), allow_none=False) - elif comp_method == 'sparse': - self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') - self.weights = parameter(g_max, self.pre_ids.shape, allow_none=False) - else: - raise ValueError - - def update(self): - if self.comp_method == 'dense': - # pre -> post - diff = (self.pre.V.reshape((-1, 1)) - self.post.V) * self.conn_mat * self.weights - self.post.input += bm.einsum('ij->j', diff) - # post -> pre - self.pre.input += bm.einsum('ij->i', -diff) - else: - diff = (self.pre.V[self.pre_ids] - self.post.V[self.post_ids]) * self.weights - self.post.input += bm.syn2post_sum(diff, self.post_ids, self.post.num) - self.pre.input += bm.syn2post_sum(-diff, self.pre_ids, self.pre.num) - - def reset_state(self, batch_size=None): - pass + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + name: str = None, + ): + super(GapJunction, self).__init__(pre=pre, + post=post, + conn=conn, + name=name) + # checking + self.check_pre_attrs('V') + self.check_post_attrs('V', 'input') + + # assert isinstance(self.output, _NullSynOut) + # assert isinstance(self.stp, _NullSynSTP) + + # 
connections + self.comp_method = comp_method + if comp_method == 'dense': + self.conn_mat = self.conn.require('conn_mat') + self.weights = parameter(g_max, (pre.num, post.num), allow_none=False) + elif comp_method == 'sparse': + self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') + self.weights = parameter(g_max, self.pre_ids.shape, allow_none=False) + else: + raise ValueError + + def update(self): + if self.comp_method == 'dense': + # pre -> post + diff = (self.pre.V.reshape((-1, 1)) - self.post.V) * self.conn_mat * self.weights + self.post.input += bm.einsum('ij->j', diff) + # post -> pre + self.pre.input += bm.einsum('ij->i', -diff) + else: + diff = (self.pre.V[self.pre_ids] - self.post.V[self.post_ids]) * self.weights + self.post.input += bm.syn2post_sum(diff, self.post_ids, self.post.num) + self.pre.input += bm.syn2post_sum(-diff, self.pre_ids, self.pre.num) + + def reset_state(self, batch_size=None): + pass diff --git a/brainpy/_src/dynold/synapses/learning_rules.py b/brainpy/_src/dynold/synapses/learning_rules.py index e10a57ae9..83f60a3c6 100644 --- a/brainpy/_src/dynold/synapses/learning_rules.py +++ b/brainpy/_src/dynold/synapses/learning_rules.py @@ -4,217 +4,217 @@ from brainpy._src.connect import TwoEndConnector from brainpy._src.dyn import synapses -from brainpy._src.dynold.synouts import CUBA +from brainpy._src.dyn.base import NeuDyn from brainpy._src.dynold.synapses import _TwoEndConnAlignPre +from brainpy._src.dynold.synouts import CUBA from brainpy._src.dynsys import Sequential -from brainpy._src.dyn.base import NeuDyn from brainpy._src.initialize import Initializer from brainpy._src.mixin import ParamDesc from brainpy.types import ArrayType __all__ = [ - 'STP' + 'STP' ] class _STPModel(Sequential, ParamDesc): - def __init__(self, size, keep_size, tau, U, tau_f, tau_d, mode=None, method='exp_euler'): - stp = synapses.STP(size, keep_size, U=U, tau_f=tau_f, tau_d=tau_d, method=method, mode=mode) - exp = synapses.Expon(size, 
keep_size, tau=tau, method=method, mode=mode) - super().__init__(stp, exp) + def __init__(self, size, keep_size, tau, U, tau_f, tau_d, mode=None, method='exp_euler'): + stp = synapses.STP(size, keep_size, U=U, tau_f=tau_f, tau_d=tau_d, method=method, mode=mode) + exp = synapses.Expon(size, keep_size, tau=tau, method=method, mode=mode) + super().__init__(stp, exp) class STP(_TwoEndConnAlignPre): - r"""Short-term plasticity model. - - **Model Descriptions** - - Short-term plasticity (STP) [1]_ [2]_ [3]_, also called dynamical synapses, - refers to the changes of synaptic strengths over time in a way that reflects - the history of presynaptic activity. Two types of STP, with opposite effects - on synaptic efficacy, have been observed in experiments. They are known as - Short-Term Depression (STD) and Short-Term Facilitation (STF). - - In the model proposed by Tsodyks and Markram [4]_ [5]_, the STD effect is - modeled by a normalized variable :math:`x (0 \le x \le 1)`, denoting the fraction - of resources that remain available after neurotransmitter depletion. - The STF effect is modeled by a utilization parameter :math:`u`, representing - the fraction of available resources ready for use (release probability). - Following a spike, - - - (i) :math:`u` increases due to spike-induced calcium influx to the presynaptic - terminal, after which - - (ii) a fraction :math:`u` of available resources is consumed to produce the - post-synaptic current. - - Between spikes, :math:`u` decays back to zero with time constant :math:`\tau_f` - and :math:`x` recovers to 1 with time constant :math:`\tau_d`. - - In summary, the dynamics of STP is given by - - .. 
math:: - - \begin{aligned} - \frac{du}{dt} & = -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}),\nonumber \\ - \frac{dx}{dt} & = \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ - \frac{dI}{dt} & = -\frac{I}{\tau_s} + Au^+x^-\delta(t-t_{sp}), - \end{aligned} - - where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment - of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding - variables just before the arrival of the spike, and :math:`u^+` - refers to the moment just after the spike. The synaptic current generated - at the synapse by the spike arriving at :math:`t_{sp}` is then given by - - .. math:: - - \Delta I(t_{spike}) = Au^+x^- - - where :math:`A` denotes the response amplitude that would be produced - by total release of all the neurotransmitter (:math:`u=x=1`), called - absolute synaptic efficacy of the connections. - - **Model Examples** - - - `STP for Working Memory Capacity `_ - - **STD** - - >>> import brainpy as bp - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = bp.neurons.LIF(1) - >>> neu2 = bp.neurons.LIF(1) - >>> syn1 = bp.synapses.STP(neu1, neu2, bp.connect.All2All(), U=0.2, tau_d=150., tau_f=2.) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x']) - >>> runner.run(150.) 
- >>> - >>> - >>> # plot - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 7) - >>> - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.u'][:, 0], label='u') - >>> plt.plot(runner.mon.ts, runner.mon['syn.x'][:, 0], label='x') - >>> plt.legend() - >>> - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.I'][:, 0], label='I') - >>> plt.legend() - >>> - >>> plt.xlabel('Time (ms)') - >>> plt.show() - - **STF** - - >>> import brainpy as bp - >>> import matplotlib.pyplot as plt - >>> - >>> neu1 = bp.neurons.LIF(1) - >>> neu2 = bp.neurons.LIF(1) - >>> syn1 = bp.neurons.STP(neu1, neu2, bp.connect.All2All(), U=0.1, tau_d=10, tau_f=100.) - >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) - >>> - >>> runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x']) - >>> runner.run(150.) - >>> - >>> - >>> # plot - >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 7) - >>> - >>> fig.add_subplot(gs[0, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.u'][:, 0], label='u') - >>> plt.plot(runner.mon.ts, runner.mon['syn.x'][:, 0], label='x') - >>> plt.legend() - >>> - >>> fig.add_subplot(gs[1, 0]) - >>> plt.plot(runner.mon.ts, runner.mon['syn.I'][:, 0], label='I') - >>> plt.legend() - >>> - >>> plt.xlabel('Time (ms)') - >>> plt.show() - - - **Model Parameters** - - ============= ============== ======== =========================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ------------------------------------------- - tau_d 200 ms Time constant of short-term depression. - tau_f 1500 ms Time constant of short-term facilitation. - U .15 \ The increment of :math:`u` produced by a spike. - A 1 \ The response amplitude that would be produced by total release of all the neurotransmitter - delay 0 ms The decay time of the current :math:`I` output onto the post-synaptic neuron groups. 
- ============= ============== ======== =========================================== - - - **Model Variables** - - =============== ================== ===================================================================== - **Member name** **Initial values** **Explanation** - --------------- ------------------ --------------------------------------------------------------------- - u 0 Release probability of the neurotransmitters. - x 1 A Normalized variable denoting the fraction of remain neurotransmitters. - I 0 Synapse current output onto the post-synaptic neurons. - =============== ================== ===================================================================== - - **References** - - .. [1] Stevens, Charles F., and Yanyan Wang. "Facilitation and depression - at single central synapses." Neuron 14, no. 4 (1995): 795-802. - .. [2] Abbott, Larry F., J. A. Varela, Kamal Sen, and S. B. Nelson. "Synaptic - depression and cortical gain control." Science 275, no. 5297 (1997): 221-224. - .. [3] Abbott, L. F., and Wade G. Regehr. "Synaptic computation." - Nature 431, no. 7010 (2004): 796-803. - .. [4] Tsodyks, Misha, Klaus Pawelzik, and Henry Markram. "Neural networks - with dynamic synapses." Neural computation 10.4 (1998): 821-835. - .. [5] Tsodyks, Misha, and Si Wu. "Short-term synaptic plasticity." - Scholarpedia 8, no. 10 (2013): 3153. 
- - """ - - def __init__( - self, - pre: NeuDyn, - post: NeuDyn, - conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - U: Union[float, ArrayType] = 0.15, - tau_f: Union[float, ArrayType] = 1500., - tau_d: Union[float, ArrayType] = 200., - tau: Union[float, ArrayType] = 8., - A: Union[float, ArrayType] = 1., - delay_step: Union[int, ArrayType, Initializer, Callable] = None, - method: str = 'exp_auto', - name: Optional[str] = None - ): - # parameters - self.tau_d = tau_d - self.tau_f = tau_f - self.tau = tau - self.U = U - self.A = A - - syn = _STPModel(pre.size, - pre.keep_size, - tau, - U, - tau_f, - tau_d, - method=method) - - super().__init__(pre=pre, - post=post, - syn=syn, - conn=conn, - g_max=A, - output=CUBA(), - comp_method='sparse', - delay_step=delay_step, - name=name) - - # variables - self.x = self.syn[0].x - self.u = self.syn[0].u - self.I = self.syn[1].g + r"""Short-term plasticity model. + + **Model Descriptions** + + Short-term plasticity (STP) [1]_ [2]_ [3]_, also called dynamical synapses, + refers to the changes of synaptic strengths over time in a way that reflects + the history of presynaptic activity. Two types of STP, with opposite effects + on synaptic efficacy, have been observed in experiments. They are known as + Short-Term Depression (STD) and Short-Term Facilitation (STF). + + In the model proposed by Tsodyks and Markram [4]_ [5]_, the STD effect is + modeled by a normalized variable :math:`x (0 \le x \le 1)`, denoting the fraction + of resources that remain available after neurotransmitter depletion. + The STF effect is modeled by a utilization parameter :math:`u`, representing + the fraction of available resources ready for use (release probability). + Following a spike, + + - (i) :math:`u` increases due to spike-induced calcium influx to the presynaptic + terminal, after which + - (ii) a fraction :math:`u` of available resources is consumed to produce the + post-synaptic current. 
+ + Between spikes, :math:`u` decays back to zero with time constant :math:`\tau_f` + and :math:`x` recovers to 1 with time constant :math:`\tau_d`. + + In summary, the dynamics of STP is given by + + .. math:: + + \begin{aligned} + \frac{du}{dt} & = -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}),\nonumber \\ + \frac{dx}{dt} & = \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ + \frac{dI}{dt} & = -\frac{I}{\tau_s} + Au^+x^-\delta(t-t_{sp}), + \end{aligned} + + where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment + of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding + variables just before the arrival of the spike, and :math:`u^+` + refers to the moment just after the spike. The synaptic current generated + at the synapse by the spike arriving at :math:`t_{sp}` is then given by + + .. math:: + + \Delta I(t_{spike}) = Au^+x^- + + where :math:`A` denotes the response amplitude that would be produced + by total release of all the neurotransmitter (:math:`u=x=1`), called + absolute synaptic efficacy of the connections. + + **Model Examples** + + - `STP for Working Memory Capacity `_ + + **STD** + + >>> import brainpy as bp + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = bp.neurons.LIF(1) + >>> neu2 = bp.neurons.LIF(1) + >>> syn1 = bp.synapses.STP(neu1, neu2, bp.connect.All2All(), U=0.2, tau_d=150., tau_f=2.) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x']) + >>> runner.run(150.) 
+ >>> + >>> + >>> # plot + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 7) + >>> + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.u'][:, 0], label='u') + >>> plt.plot(runner.mon.ts, runner.mon['syn.x'][:, 0], label='x') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.I'][:, 0], label='I') + >>> plt.legend() + >>> + >>> plt.xlabel('Time (ms)') + >>> plt.show() + + **STF** + + >>> import brainpy as bp + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = bp.neurons.LIF(1) + >>> neu2 = bp.neurons.LIF(1) + >>> syn1 = bp.synapses.STP(neu1, neu2, bp.connect.All2All(), U=0.1, tau_d=10, tau_f=100.) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x']) + >>> runner.run(150.) + >>> + >>> + >>> # plot + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 7) + >>> + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.u'][:, 0], label='u') + >>> plt.plot(runner.mon.ts, runner.mon['syn.x'][:, 0], label='x') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.I'][:, 0], label='I') + >>> plt.legend() + >>> + >>> plt.xlabel('Time (ms)') + >>> plt.show() + + + **Model Parameters** + + ============= ============== ======== =========================================== + **Parameter** **Init Value** **Unit** **Explanation** + ------------- -------------- -------- ------------------------------------------- + tau_d 200 ms Time constant of short-term depression. + tau_f 1500 ms Time constant of short-term facilitation. + U .15 \ The increment of :math:`u` produced by a spike. + A 1 \ The response amplitude that would be produced by total release of all the neurotransmitter + delay 0 ms The delay time of the current :math:`I` output onto the post-synaptic neuron groups. 
+ ============= ============== ======== =========================================== + + + **Model Variables** + + =============== ================== ===================================================================== + **Member name** **Initial values** **Explanation** + --------------- ------------------ --------------------------------------------------------------------- + u 0 Release probability of the neurotransmitters. + x 1 A normalized variable denoting the fraction of remaining neurotransmitters. + I 0 Synapse current output onto the post-synaptic neurons. + =============== ================== ===================================================================== + + **References** + + .. [1] Stevens, Charles F., and Yanyan Wang. "Facilitation and depression + at single central synapses." Neuron 14, no. 4 (1995): 795-802. + .. [2] Abbott, Larry F., J. A. Varela, Kamal Sen, and S. B. Nelson. "Synaptic + depression and cortical gain control." Science 275, no. 5297 (1997): 221-224. + .. [3] Abbott, L. F., and Wade G. Regehr. "Synaptic computation." + Nature 431, no. 7010 (2004): 796-803. + .. [4] Tsodyks, Misha, Klaus Pawelzik, and Henry Markram. "Neural networks + with dynamic synapses." Neural computation 10.4 (1998): 821-835. + .. [5] Tsodyks, Misha, and Si Wu. "Short-term synaptic plasticity." + Scholarpedia 8, no. 10 (2013): 3153. 
+ + """ + + def __init__( + self, + pre: NeuDyn, + post: NeuDyn, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + U: Union[float, ArrayType] = 0.15, + tau_f: Union[float, ArrayType] = 1500., + tau_d: Union[float, ArrayType] = 200., + tau: Union[float, ArrayType] = 8., + A: Union[float, ArrayType] = 1., + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + method: str = 'exp_auto', + name: Optional[str] = None + ): + # parameters + self.tau_d = tau_d + self.tau_f = tau_f + self.tau = tau + self.U = U + self.A = A + + syn = _STPModel(pre.size, + pre.keep_size, + tau, + U, + tau_f, + tau_d, + method=method) + + super().__init__(pre=pre, + post=post, + syn=syn, + conn=conn, + g_max=A, + output=CUBA(), + comp_method='sparse', + delay_step=delay_step, + name=name) + + # variables + self.x = self.syn[0].x + self.u = self.syn[0].u + self.I = self.syn[1].g diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index f6605cb44..4f192b885 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -3,123 +3,118 @@ from absl.testing import parameterized -import pytest import brainpy as bp import brainpy.math as bm from brainpy._src.dynold.synapses import abstract_models class Test_Abstract_Synapse(parameterized.TestCase): - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_all2all_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(bp.synapses, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = 
bp.DSRunner(net, monitors=['pre.V', 'syn.g', 'post.V'], inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - - - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_one2one_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - - - @parameterized.product( - comp_type=['sparse', 'dense'], - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_sparse_synapse(self, comp_type, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - - - @parameterized.product( - post_ref_key=[None, 'refractory'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_delta_synapse(self, post_ref_key, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5, ref_var=True) - post_neu = bp.neurons.LIF(3, ref_var=True) - syn = bp.synapses.Delta(pre_neu, post_neu, - conn=bp.conn.All2All(), - post_ref_key=post_ref_key, - stp=stp, ) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - pre_expected_shape = (100, 5) - post_expected_shape = (100, 3) - if isinstance(mode, bm.BatchingMode): - pre_expected_shape = (mode.batch_size,) + pre_expected_shape - post_expected_shape = (mode.batch_size,) + post_expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) - \ No newline at end of file + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_all2all_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(bp.synapses, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, 
monitors=['pre.V', 'syn.g', 'post.V'], inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_one2one_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + + @parameterized.product( + comp_type=['sparse', 'dense'], + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_sparse_synapse(self, comp_type, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + + @parameterized.product( + post_ref_key=[None, 'refractory'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_delta_synapse(self, post_ref_key, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5, ref_var=True) + post_neu = bp.neurons.LIF(3, ref_var=True) + syn = bp.synapses.Delta(pre_neu, post_neu, + conn=bp.conn.All2All(), + post_ref_key=post_ref_key, + stp=stp, ) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + pre_expected_shape = (100, 5) + post_expected_shape = (100, 3) + if isinstance(mode, bm.BatchingMode): + pre_expected_shape = (mode.batch_size,) + pre_expected_shape + post_expected_shape = (mode.batch_size,) + post_expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) diff --git a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py index 4171d106e..5f3747a71 100644 --- a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py @@ -1,104 +1,99 @@ # -*- coding: utf-8 -*- -import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm - biological_models = [ - bp.synapses.AMPA, - bp.synapses.GABAa, - bp.synapses.BioNMDA, + bp.synapses.AMPA, + bp.synapses.GABAa, + bp.synapses.BioNMDA, ] class 
Test_Biological_Synapse(parameterized.TestCase): - @parameterized.product( - synapse=biological_models, - delay_step=[None, 5, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_all2all_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + @parameterized.product( + synapse=biological_models, + delay_step=[None, 5, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_all2all_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
- expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - @parameterized.product( - synapse=biological_models, - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_one2one_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + @parameterized.product( + synapse=biological_models, + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_one2one_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
- expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - @parameterized.product( - synapse=biological_models, - comp_method=['sparse', 'dense'], - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(10) - post_neu = bp.neurons.LIF(10) - syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), - comp_method=comp_method, delay_step=delay_step, - stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + @parameterized.product( + synapse=biological_models, + comp_method=['sparse', 'dense'], + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(10) + post_neu = bp.neurons.LIF(10) + syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), + comp_method=comp_method, delay_step=delay_step, + stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
+ # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) - expected_shape = (100, 10) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - + expected_shape = (100, 10) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) diff --git a/brainpy/_src/dynold/synapses/tests/test_dynold_base_synapse.py b/brainpy/_src/dynold/synapses/tests/test_dynold_base_synapse.py index 9dc755586..ae89b94a4 100644 --- a/brainpy/_src/dynold/synapses/tests/test_dynold_base_synapse.py +++ b/brainpy/_src/dynold/synapses/tests/test_dynold_base_synapse.py @@ -1,12 +1,10 @@ - import unittest + import brainpy as bp class Test_TwoEndConnAlignPre(unittest.TestCase): - def test1(self): - E = bp.neurons.HH(size=4) - syn = bp.synapses.AMPA(E, E, bp.conn.All2All(include_self=False)) - self.assertTrue(syn.conn.include_self == syn.comm.include_self) - - + def test1(self): + E = bp.neurons.HH(size=4) + syn = bp.synapses.AMPA(E, E, bp.conn.All2All(include_self=False)) + self.assertTrue(syn.conn.include_self == syn.comm.include_self) diff --git a/brainpy/_src/dynold/synapses/tests/test_gap_junction.py b/brainpy/_src/dynold/synapses/tests/test_gap_junction.py index 6ec851edd..b4615d092 100644 --- a/brainpy/_src/dynold/synapses/tests/test_gap_junction.py +++ b/brainpy/_src/dynold/synapses/tests/test_gap_junction.py @@ -1,23 +1,23 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing 
import parameterized from brainpy._src.dynold.synapses import gap_junction class Test_gap_junction(parameterized.TestCase): - def test_gap_junction(self): - bm.random.seed() - neu = bp.neurons.HH(2, V_initializer=bp.init.Constant(-70.68)) - syn = gap_junction.GapJunction(neu, neu, conn=bp.connect.All2All(include_self=False)) - net = bp.Network(syn=syn, neu=neu) + def test_gap_junction(self): + bm.random.seed() + neu = bp.neurons.HH(2, V_initializer=bp.init.Constant(-70.68)) + syn = gap_junction.GapJunction(neu, neu, conn=bp.connect.All2All(include_self=False)) + net = bp.Network(syn=syn, neu=neu) - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['neu.V'], - inputs=('neu.input', 35.)) - runner(10.) - self.assertTupleEqual(runner.mon['neu.V'].shape, (100, 2)) - + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['neu.V'], + inputs=('neu.input', 35.)) + runner(10.) + self.assertTupleEqual(runner.mon['neu.V'].shape, (100, 2)) diff --git a/brainpy/_src/dynold/synapses/tests/test_learning_rule.py b/brainpy/_src/dynold/synapses/tests/test_learning_rule.py index 6722063f7..3b7af7f4e 100644 --- a/brainpy/_src/dynold/synapses/tests/test_learning_rule.py +++ b/brainpy/_src/dynold/synapses/tests/test_learning_rule.py @@ -1,33 +1,32 @@ # -*- coding: utf-8 -*- +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized class Test_learning_rule(parameterized.TestCase): - @parameterized.product( - delay_step=[None, 5, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_learning_rule(self, delay_step, mode): - bm.random.seed() - with bm.environment(mode=mode): - neu1 = bp.neurons.LIF(5) - neu2 = bp.neurons.LIF(5) - syn1 = bp.synapses.STP(neu1, neu2, bp.connect.All2All(), U=0.1, tau_d=10, tau_f=100., - delay_step=delay_step) - net = bp.Network(pre=neu1, syn=syn1, post=neu2) - - runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x']) - 
runner.run(10.) + @parameterized.product( + delay_step=[None, 5, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_learning_rule(self, delay_step, mode): + bm.random.seed() + with bm.environment(mode=mode): + neu1 = bp.neurons.LIF(5) + neu2 = bp.neurons.LIF(5) + syn1 = bp.synapses.STP(neu1, neu2, bp.connect.All2All(), U=0.1, tau_d=10, tau_f=100., + delay_step=delay_step) + net = bp.Network(pre=neu1, syn=syn1, post=neu2) - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['syn.I'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.u'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.x'].shape, expected_shape) - + runner = bp.DSRunner(net, inputs=[('pre.input', 28.)], monitors=['syn.I', 'syn.u', 'syn.x']) + runner.run(10.) + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['syn.I'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.u'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.x'].shape, expected_shape) diff --git a/brainpy/_src/dynold/synouts/conductances.py b/brainpy/_src/dynold/synouts/conductances.py index c1f1e2080..2d73ff4a5 100644 --- a/brainpy/_src/dynold/synouts/conductances.py +++ b/brainpy/_src/dynold/synouts/conductances.py @@ -8,100 +8,99 @@ from brainpy.types import ArrayType __all__ = [ - 'COBA', - 'CUBA', + 'COBA', + 'CUBA', ] class CUBA(_SynOut): - r"""Current-based synaptic output. + r"""Current-based synaptic output. - Given the conductance, this model outputs the post-synaptic current with a identity function: + Given the conductance, this model outputs the post-synaptic current with a identity function: - .. math:: + .. 
math:: - I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) + I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) - Parameters:: + Parameters:: - name: str - The model name. + name: str + The model name. - See Also:: + See Also:: - COBA - """ + COBA + """ - def __init__( - self, - target_var: Optional[Union[str, Variable]] = 'input', - name: str = None, - ): - self._target_var = target_var - super().__init__(name=name, target_var=target_var) + def __init__( + self, + target_var: Optional[Union[str, Variable]] = 'input', + name: str = None, + ): + self._target_var = target_var + super().__init__(name=name, target_var=target_var) - def clone(self): - return CUBA(target_var=self._target_var) + def clone(self): + return CUBA(target_var=self._target_var) class COBA(_SynOut): - r"""Conductance-based synaptic output. - - Given the synaptic conductance, the model output the post-synaptic current with - - .. math:: - - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) - - Parameters:: - - E: float, ArrayType, ndarray, callable, Initializer - The reversal potential. - name: str - The model name. 
- - See Also:: - - CUBA - """ - - def __init__( - self, - E: Union[float, ArrayType, Callable, Initializer] = 0., - target_var: Optional[Union[str, Variable]] = 'input', - membrane_var: Union[str, Variable] = 'V', - name: str = None, - ): - super().__init__(name=name, target_var=target_var) - self._E = E - self._target_var = target_var - self._membrane_var = membrane_var - - def clone(self): - return COBA(E=self._E, - target_var=self._target_var, - membrane_var=self._membrane_var) - - def register_master(self, master): - super().register_master(master) - - # reversal potential - self.E = parameter(self._E, self.master.post.num, allow_none=False) - - # membrane potential - if isinstance(self._membrane_var, str): - if not hasattr(self.master.post, self._membrane_var): - raise KeyError(f'Post-synaptic group does not have membrane variable: {self._membrane_var}') - self.membrane_var = getattr(self.master.post, self._membrane_var) - elif isinstance(self._membrane_var, Variable): - self.membrane_var = self._membrane_var - else: - raise TypeError('"membrane_var" must be instance of string or Variable. ' - f'But we got {type(self._membrane_var)}') - - def filter(self, g): - V = self.membrane_var.value - I = g * (self.E - V) - return super(COBA, self).filter(I) - + r"""Conductance-based synaptic output. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) + + Parameters:: + + E: float, ArrayType, ndarray, callable, Initializer + The reversal potential. + name: str + The model name. 
+ + See Also:: + + CUBA + """ + + def __init__( + self, + E: Union[float, ArrayType, Callable, Initializer] = 0., + target_var: Optional[Union[str, Variable]] = 'input', + membrane_var: Union[str, Variable] = 'V', + name: str = None, + ): + super().__init__(name=name, target_var=target_var) + self._E = E + self._target_var = target_var + self._membrane_var = membrane_var + + def clone(self): + return COBA(E=self._E, + target_var=self._target_var, + membrane_var=self._membrane_var) + + def register_master(self, master): + super().register_master(master) + + # reversal potential + self.E = parameter(self._E, self.master.post.num, allow_none=False) + + # membrane potential + if isinstance(self._membrane_var, str): + if not hasattr(self.master.post, self._membrane_var): + raise KeyError(f'Post-synaptic group does not have membrane variable: {self._membrane_var}') + self.membrane_var = getattr(self.master.post, self._membrane_var) + elif isinstance(self._membrane_var, Variable): + self.membrane_var = self._membrane_var + else: + raise TypeError('"membrane_var" must be instance of string or Variable. ' + f'But we got {type(self._membrane_var)}') + + def filter(self, g): + V = self.membrane_var.value + I = g * (self.E - V) + return super(COBA, self).filter(I) diff --git a/brainpy/_src/dynold/synouts/ions.py b/brainpy/_src/dynold/synouts/ions.py index c81d7aac5..72657f63f 100644 --- a/brainpy/_src/dynold/synouts/ions.py +++ b/brainpy/_src/dynold/synouts/ions.py @@ -10,86 +10,85 @@ from brainpy.types import ArrayType __all__ = [ - 'MgBlock', + 'MgBlock', ] class MgBlock(_SynOut): - r"""Synaptic output based on Magnesium blocking. - - Given the synaptic conductance, the model output the post-synaptic current with - - .. math:: - - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) - - where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to - - .. 
math:: - - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. - - Parameters:: - - E: float, ArrayType, callable, Initializer - The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType - Binding constant. Default 0.062 - beta: float, ArrayType, callable, Initializer - Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType, callable, Initializer - Concentration of Magnesium ion. Default 1.2 [mM]. - name: str - The model name. - """ - - def __init__( - self, - E: Union[float, ArrayType, Callable, Initializer] = 0., - cc_Mg: Union[float, ArrayType, Callable, Initializer] = 1.2, - alpha: Union[float, ArrayType, Callable, Initializer] = 0.062, - beta: Union[float, ArrayType, Callable, Initializer] = 3.57, - target_var: Optional[Union[str, bm.Variable]] = 'input', - membrane_var: Union[str, bm.Variable] = 'V', - name: str = None, - ): - super().__init__(name=name, target_var=target_var) - self._E = E - self._cc_Mg = cc_Mg - self._alpha = alpha - self._beta = beta - self._target_var = target_var - self._membrane_var = membrane_var - - def register_master(self, master): - super().register_master(master) - - self.E = parameter(self._E, self.master.post.num, allow_none=False) - self.cc_Mg = parameter(self._cc_Mg, self.master.post.num, allow_none=False) - self.alpha = parameter(self._alpha, self.master.post.num, allow_none=False) - self.beta = parameter(self._beta, self.master.post.num, allow_none=False) - if isinstance(self._membrane_var, str): - if not hasattr(self.master.post, self._membrane_var): - raise KeyError(f'Post-synaptic group does not have membrane variable: {self._membrane_var}') - self.membrane_var = getattr(self.master.post, self._membrane_var) - elif isinstance(self._membrane_var, bm.Variable): - self.membrane_var = self._membrane_var - else: - raise TypeError('"membrane_var" must be instance of string or Variable. 
' - f'But we got {type(self._membrane_var)}') - - def filter(self, g): - V = self.membrane_var.value - I = g * (self.E - V) / (1 + self.cc_Mg / self.beta * jnp.exp(-self.alpha * V)) - return super(MgBlock, self).filter(I) - - def clone(self): - return MgBlock(E=self._E, - cc_Mg=self._cc_Mg, - alpha=self._alpha, - beta=self._beta, - target_var=self._target_var, - membrane_var=self._membrane_var) - + r"""Synaptic output based on Magnesium blocking. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) + + where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. + + Parameters:: + + E: float, ArrayType, callable, Initializer + The reversal potential for the synaptic current. [mV] + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType, callable, Initializer + Unbinding constant. Default 3.57 + cc_Mg: float, ArrayType, callable, Initializer + Concentration of Magnesium ion. Default 1.2 [mM]. + name: str + The model name. 
+ """ + + def __init__( + self, + E: Union[float, ArrayType, Callable, Initializer] = 0., + cc_Mg: Union[float, ArrayType, Callable, Initializer] = 1.2, + alpha: Union[float, ArrayType, Callable, Initializer] = 0.062, + beta: Union[float, ArrayType, Callable, Initializer] = 3.57, + target_var: Optional[Union[str, bm.Variable]] = 'input', + membrane_var: Union[str, bm.Variable] = 'V', + name: str = None, + ): + super().__init__(name=name, target_var=target_var) + self._E = E + self._cc_Mg = cc_Mg + self._alpha = alpha + self._beta = beta + self._target_var = target_var + self._membrane_var = membrane_var + + def register_master(self, master): + super().register_master(master) + + self.E = parameter(self._E, self.master.post.num, allow_none=False) + self.cc_Mg = parameter(self._cc_Mg, self.master.post.num, allow_none=False) + self.alpha = parameter(self._alpha, self.master.post.num, allow_none=False) + self.beta = parameter(self._beta, self.master.post.num, allow_none=False) + if isinstance(self._membrane_var, str): + if not hasattr(self.master.post, self._membrane_var): + raise KeyError(f'Post-synaptic group does not have membrane variable: {self._membrane_var}') + self.membrane_var = getattr(self.master.post, self._membrane_var) + elif isinstance(self._membrane_var, bm.Variable): + self.membrane_var = self._membrane_var + else: + raise TypeError('"membrane_var" must be instance of string or Variable. 
' + f'But we got {type(self._membrane_var)}') + + def filter(self, g): + V = self.membrane_var.value + I = g * (self.E - V) / (1 + self.cc_Mg / self.beta * jnp.exp(-self.alpha * V)) + return super(MgBlock, self).filter(I) + + def clone(self): + return MgBlock(E=self._E, + cc_Mg=self._cc_Mg, + alpha=self._alpha, + beta=self._beta, + target_var=self._target_var, + membrane_var=self._membrane_var) diff --git a/brainpy/_src/dynold/synplast/short_term_plasticity.py b/brainpy/_src/dynold/synplast/short_term_plasticity.py index a660e7af9..7020e2430 100644 --- a/brainpy/_src/dynold/synplast/short_term_plasticity.py +++ b/brainpy/_src/dynold/synplast/short_term_plasticity.py @@ -12,182 +12,181 @@ from brainpy.types import ArrayType __all__ = [ - 'STD', - 'STP', + 'STD', + 'STP', ] class STD(_SynSTP): - r"""Synaptic output with short-term depression. + r"""Synaptic output with short-term depression. - This model filters the synaptic current by the following equation: + This model filters the synaptic current by the following equation: - .. math:: + .. math:: - I_{syn}^+(t) = I_{syn}^-(t) * x + I_{syn}^+(t) = I_{syn}^-(t) * x - where :math:`x` is the normalized variable between 0 and 1, and - :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before - and after STD filtering. + where :math:`x` is the normalized variable between 0 and 1, and + :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STD filtering. - Moreover, :math:`x` is updated according to the dynamics of: + Moreover, :math:`x` is updated according to the dynamics of: - .. math:: + .. math:: - \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike}) + \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike}) - where :math:`U` is the fraction of resources used per action potential, - :math:`\tau` is the time constant of recovery of the synaptic vesicles. 
+ where :math:`U` is the fraction of resources used per action potential, + :math:`\tau` is the time constant of recovery of the synaptic vesicles. - Parameters:: + Parameters:: - tau: float - The time constant of recovery of the synaptic vesicles. - U: float - The fraction of resources used per action potential. + tau: float + The time constant of recovery of the synaptic vesicles. + U: float + The fraction of resources used per action potential. - See Also:: + See Also:: - STP - """ + STP + """ - def __init__( - self, - tau: float = 200., - U: float = 0.07, - method: str = 'exp_auto', - name: str = None - ): - super().__init__(name=name) + def __init__( + self, + tau: float = 200., + U: float = 0.07, + method: str = 'exp_auto', + name: str = None + ): + super().__init__(name=name) - # parameters - is_float(tau, 'tau', min_bound=0, ) - is_float(U, 'U', min_bound=0, ) - self.tau = tau - self.U = U - self.method = method + # parameters + is_float(tau, 'tau', min_bound=0, ) + is_float(U, 'U', min_bound=0, ) + self.tau = tau + self.U = U + self.method = method - # integral function - self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=self.method) + # integral function + self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=self.method) - def clone(self): - return STD(tau=self.tau, U=self.U, method=self.method) + def clone(self): + return STD(tau=self.tau, U=self.U, method=self.method) - def register_master(self, master): - super().register_master(master) - self.x = variable(jnp.ones, self.master.mode, self.master.pre.num) + def register_master(self, master): + super().register_master(master) + self.x = variable(jnp.ones, self.master.mode, self.master.pre.num) - def reset_state(self, batch_size=None): - self.x.value = variable(jnp.ones, batch_size, self.master.pre.num) + def reset_state(self, batch_size=None): + self.x.value = variable(jnp.ones, batch_size, self.master.pre.num) - def update(self, pre_spike): - x = self.integral(self.x.value, 
share['t'], share['dt']) - self.x.value = jnp.where(pre_spike, x - self.U * self.x, x) + def update(self, pre_spike): + x = self.integral(self.x.value, share['t'], share['dt']) + self.x.value = jnp.where(pre_spike, x - self.U * self.x, x) - def filter(self, g): - if jnp.shape(g) != self.x.shape: - raise ValueError('Shape does not match.') - return g * self.x + def filter(self, g): + if jnp.shape(g) != self.x.shape: + raise ValueError('Shape does not match.') + return g * self.x - def __repr__(self): - return f'{self.__class__.__name__}(tau={self.tau}, U={self.U}, method={self.method})' + def __repr__(self): + return f'{self.__class__.__name__}(tau={self.tau}, U={self.U}, method={self.method})' class STP(_SynSTP): - r"""Synaptic output with short-term plasticity. - - This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. - - .. math:: - - I_{syn}^+(t) = I_{syn}^-(t) * x * u - - where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before - and after STP filtering, :math:`x` denotes the fraction of resources that remain available - after neurotransmitter depletion, and :math:`u` represents the fraction of available - resources ready for use (release probability). - - The dynamics of :math:`u` and :math:`x` are governed by - - .. math:: - - \begin{aligned} - \frac{du}{dt} & = & -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\ - \frac{dx}{dt} & = & \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ - \tag{1}\end{aligned} - - where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment - of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding - variables just before the arrival of the spike, and :math:`u^+` - refers to the moment just after the spike. - - Parameters:: - - tau_f: float - The time constant of short-term facilitation. - tau_d: float - The time constant of short-term depression. - U: float - The fraction of resources used per action potential. 
- method: str - The numerical integral method. - - See Also:: - - STD - """ - - def __init__( - self, - U: Union[float, ArrayType] = 0.15, - tau_f: Union[float, ArrayType] = 1500., - tau_d: Union[float, ArrayType] = 200., - method: str = 'exp_auto', - name: str = None - ): - super(STP, self).__init__(name=name) - # parameters - is_float(tau_f, 'tau_f', min_bound=0, ) - is_float(tau_d, 'tau_d', min_bound=0, ) - is_float(U, 'U', min_bound=0, ) - self.tau_f = tau_f - self.tau_d = tau_d - self.U = U - self.method = method - - # integral function - self.integral = odeint(self.derivative, method=self.method) - - def clone(self): - return STP(tau_f=self.tau_f, tau_d=self.tau_d, U=self.U, method=self.method) - - def register_master(self, master): - super().register_master(master) - self.x = variable(jnp.ones, self.master.mode, self.master.pre.num) - self.u = variable(lambda s: jnp.ones(s) * self.U, self.master.mode, self.master.pre.num) - - def reset_state(self, batch_size=None): - self.x.value = variable(jnp.ones, batch_size, self.master.pre.num) - self.u.value = variable(lambda s: jnp.ones(s) * self.U, batch_size, self.master.pre.num) - - @property - def derivative(self): - du = lambda u, t: self.U - u / self.tau_f - dx = lambda x, t: (1 - x) / self.tau_d - return JointEq(du, dx) - - def update(self, pre_spike): - u, x = self.integral(self.u.value, self.x.value, share['t'], share['dt']) - u = jnp.where(pre_spike, u + self.U * (1 - self.u), u) - x = jnp.where(pre_spike, x - u * self.x, x) - self.x.value = x - self.u.value = u - - def filter(self, g): - if jnp.shape(g) != self.x.shape: - raise ValueError('Shape does not match.') - return g * self.x * self.u - - def __repr__(self): - return f'{self.__class__.__name__}(tau_f={self.tau_f}, tau_d={self.tau_d}, U={self.U}, method={self.method})' - + r"""Synaptic output with short-term plasticity. + + This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. + + .. 
math:: + + I_{syn}^+(t) = I_{syn}^-(t) * x * u + + where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STP filtering, :math:`x` denotes the fraction of resources that remain available + after neurotransmitter depletion, and :math:`u` represents the fraction of available + resources ready for use (release probability). + + The dynamics of :math:`u` and :math:`x` are governed by + + .. math:: + + \begin{aligned} + \frac{du}{dt} & = & -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\ + \frac{dx}{dt} & = & \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ + \tag{1}\end{aligned} + + where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment + of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding + variables just before the arrival of the spike, and :math:`u^+` + refers to the moment just after the spike. + + Parameters:: + + tau_f: float + The time constant of short-term facilitation. + tau_d: float + The time constant of short-term depression. + U: float + The fraction of resources used per action potential. + method: str + The numerical integral method. 
+ + See Also:: + + STD + """ + + def __init__( + self, + U: Union[float, ArrayType] = 0.15, + tau_f: Union[float, ArrayType] = 1500., + tau_d: Union[float, ArrayType] = 200., + method: str = 'exp_auto', + name: str = None + ): + super(STP, self).__init__(name=name) + # parameters + is_float(tau_f, 'tau_f', min_bound=0, ) + is_float(tau_d, 'tau_d', min_bound=0, ) + is_float(U, 'U', min_bound=0, ) + self.tau_f = tau_f + self.tau_d = tau_d + self.U = U + self.method = method + + # integral function + self.integral = odeint(self.derivative, method=self.method) + + def clone(self): + return STP(tau_f=self.tau_f, tau_d=self.tau_d, U=self.U, method=self.method) + + def register_master(self, master): + super().register_master(master) + self.x = variable(jnp.ones, self.master.mode, self.master.pre.num) + self.u = variable(lambda s: jnp.ones(s) * self.U, self.master.mode, self.master.pre.num) + + def reset_state(self, batch_size=None): + self.x.value = variable(jnp.ones, batch_size, self.master.pre.num) + self.u.value = variable(lambda s: jnp.ones(s) * self.U, batch_size, self.master.pre.num) + + @property + def derivative(self): + du = lambda u, t: self.U - u / self.tau_f + dx = lambda x, t: (1 - x) / self.tau_d + return JointEq(du, dx) + + def update(self, pre_spike): + u, x = self.integral(self.u.value, self.x.value, share['t'], share['dt']) + u = jnp.where(pre_spike, u + self.U * (1 - self.u), u) + x = jnp.where(pre_spike, x - u * self.x, x) + self.x.value = x + self.u.value = u + + def filter(self, g): + if jnp.shape(g) != self.x.shape: + raise ValueError('Shape does not match.') + return g * self.x * self.u + + def __repr__(self): + return f'{self.__class__.__name__}(tau_f={self.tau_f}, tau_d={self.tau_d}, U={self.U}, method={self.method})' diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index 964d6f465..371e0128a 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -17,14 +17,14 @@ from brainpy.types import ArrayType, Shape __all__ = [ - # 
general - 'DynamicalSystem', + # general + 'DynamicalSystem', - # containers - 'DynSysGroup', 'Network', 'Sequential', + # containers + 'DynSysGroup', 'Network', 'Sequential', - # category - 'Dynamic', 'Projection', + # category + 'Dynamic', 'Projection', ] IonChaDyn = None @@ -35,907 +35,902 @@ def not_implemented(fun): - def new_fun(*args, **kwargs): - return fun(*args, **kwargs) + def new_fun(*args, **kwargs): + return fun(*args, **kwargs) - new_fun._not_implemented = True - return new_fun + new_fun._not_implemented = True + return new_fun class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj): - """Base Dynamical System class. - - .. note:: - In general, every instance of :py:class:`~.DynamicalSystem` implemented in - BrainPy only defines the evolving function at each time step :math:`t`. - - If users want to define the logic of running models across multiple steps, - we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`, - :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`. - - To be compatible with previous APIs, :py:class:`~.DynamicalSystem` inherits - from the :py:class:`~.DelayRegister`. It's worthy to note that the methods of - :py:class:`~.DelayRegister` will be removed in the future, including: - - - ``.register_delay()`` - - ``.get_delay_data()`` - - ``.update_local_delays()`` - - ``.reset_local_delays()`` - - Parameters:: - - name : optional, str - The name of the dynamical system. - mode: optional, Mode - The model computation mode. It should be an instance of :py:class:`~.Mode`. 
- """ - - supported_modes: Optional[Sequence[bm.Mode]] = None - '''Supported computing modes.''' - - def __init__( - self, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - # mode setting - mode = bm.get_mode() if mode is None else mode - if not isinstance(mode, bm.Mode): - raise ValueError(f'Should be instance of {bm.Mode.__name__}, ' - f'but we got {type(mode)}: {mode}') - self._mode = mode - - if self.supported_modes is not None: - if not self.mode.is_parent_of(*self.supported_modes): - raise UnsupportedError(f'The mode only supports computing modes ' - f'which are parents of {self.supported_modes}, ' - f'but we got {self.mode}.') - - # Attribute for "SupportInputProj" - # each instance of "SupportInputProj" should have a "cur_inputs" attribute - self._current_inputs: Optional[Dict[str, Callable]] = None - self._delta_inputs: Optional[Dict[str, Callable]] = None - - # the before- / after-updates used for computing - # added after the version of 2.4.3 - self._before_updates: Optional[Dict[str, Callable]] = None - self._after_updates: Optional[Dict[str, Callable]] = None - - # super initialization - super().__init__(name=name) - - @property - def current_inputs(self): - if self._current_inputs is None: - self._current_inputs = bm.node_dict() - return self._current_inputs - - @property - def delta_inputs(self): - if self._delta_inputs is None: - self._delta_inputs = bm.node_dict() - return self._delta_inputs - - @property - def before_updates(self): - if self._before_updates is None: - self._before_updates = bm.node_dict() - return self._before_updates - - @property - def after_updates(self): - if self._after_updates is None: - self._after_updates = bm.node_dict() - return self._after_updates - - def add_bef_update(self, key: Any, fun: Callable): - """Add the before update into this node""" - if key in self.before_updates: - raise KeyError(f'{key} has been registered in before_updates of {self}') - self.before_updates[key] = fun - - def 
add_aft_update(self, key: Any, fun: Callable): - """Add the after update into this node""" - if key in self.after_updates: - raise KeyError(f'{key} has been registered in after_updates of {self}') - self.after_updates[key] = fun - - def get_bef_update(self, key: Any): - """Get the before update of this node by the given ``key``.""" - if key not in self.before_updates: - raise KeyError(f'{key} is not registered in before_updates of {self}') - return self.before_updates.get(key) - - def get_aft_update(self, key: Any): - """Get the after update of this node by the given ``key``.""" - if key not in self.after_updates: - raise KeyError(f'{key} is not registered in after_updates of {self}') - return self.after_updates.get(key) - - def has_bef_update(self, key: Any): - """Whether this node has the before update of the given ``key``.""" - return key in self.before_updates - - def has_aft_update(self, key: Any): - """Whether this node has the after update of the given ``key``.""" - return key in self.after_updates - - def update(self, *args, **kwargs): - """The function to specify the updating rule. - """ - raise NotImplementedError('Must implement "update" function by subclass self.') + """Base Dynamical System class. - def reset(self, *args, **kwargs): - """Reset function which reset the whole variables in the model (including its children models). + .. note:: + In general, every instance of :py:class:`~.DynamicalSystem` implemented in + BrainPy only defines the evolving function at each time step :math:`t`. - ``reset()`` function is a collective behavior which resets all states in this model. + If users want to define the logic of running models across multiple steps, + we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`, + :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`. - See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. 
- """ - global reset_state - if reset_state is None: - from brainpy._src.helpers import reset_state - reset_state(self, *args, **kwargs) + To be compatible with previous APIs, :py:class:`~.DynamicalSystem` inherits + from the :py:class:`~.DelayRegister`. It's worthy to note that the methods of + :py:class:`~.DelayRegister` will be removed in the future, including: - @not_implemented - def reset_state(self, *args, **kwargs): - """Reset function which resets local states in this model. + - ``.register_delay()`` + - ``.get_delay_data()`` + - ``.update_local_delays()`` + - ``.reset_local_delays()`` - Simply speaking, this function should implement the logic of resetting of - local variables in this node. + Parameters:: - See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. + name : optional, str + The name of the dynamical system. + mode: optional, Mode + The model computation mode. It should be an instance of :py:class:`~.Mode`. """ - pass - def clear_input(self, *args, **kwargs): - """Clear the input at the current time step.""" - pass + supported_modes: Optional[Sequence[bm.Mode]] = None + '''Supported computing modes.''' + + def __init__( + self, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + # mode setting + mode = bm.get_mode() if mode is None else mode + if not isinstance(mode, bm.Mode): + raise ValueError(f'Should be instance of {bm.Mode.__name__}, ' + f'but we got {type(mode)}: {mode}') + self._mode = mode + + if self.supported_modes is not None: + if not self.mode.is_parent_of(*self.supported_modes): + raise UnsupportedError(f'The mode only supports computing modes ' + f'which are parents of {self.supported_modes}, ' + f'but we got {self.mode}.') + + # Attribute for "SupportInputProj" + # each instance of "SupportInputProj" should have a "cur_inputs" attribute + self._current_inputs: Optional[Dict[str, Callable]] = None + self._delta_inputs: Optional[Dict[str, Callable]] = None + + # the 
before- / after-updates used for computing + # added after the version of 2.4.3 + self._before_updates: Optional[Dict[str, Callable]] = None + self._after_updates: Optional[Dict[str, Callable]] = None + + # super initialization + super().__init__(name=name) + + @property + def current_inputs(self): + if self._current_inputs is None: + self._current_inputs = bm.node_dict() + return self._current_inputs + + @property + def delta_inputs(self): + if self._delta_inputs is None: + self._delta_inputs = bm.node_dict() + return self._delta_inputs + + @property + def before_updates(self): + if self._before_updates is None: + self._before_updates = bm.node_dict() + return self._before_updates + + @property + def after_updates(self): + if self._after_updates is None: + self._after_updates = bm.node_dict() + return self._after_updates + + def add_bef_update(self, key: Any, fun: Callable): + """Add the before update into this node""" + if key in self.before_updates: + raise KeyError(f'{key} has been registered in before_updates of {self}') + self.before_updates[key] = fun + + def add_aft_update(self, key: Any, fun: Callable): + """Add the after update into this node""" + if key in self.after_updates: + raise KeyError(f'{key} has been registered in after_updates of {self}') + self.after_updates[key] = fun + + def get_bef_update(self, key: Any): + """Get the before update of this node by the given ``key``.""" + if key not in self.before_updates: + raise KeyError(f'{key} is not registered in before_updates of {self}') + return self.before_updates.get(key) + + def get_aft_update(self, key: Any): + """Get the after update of this node by the given ``key``.""" + if key not in self.after_updates: + raise KeyError(f'{key} is not registered in after_updates of {self}') + return self.after_updates.get(key) + + def has_bef_update(self, key: Any): + """Whether this node has the before update of the given ``key``.""" + return key in self.before_updates + + def has_aft_update(self, key: Any): 
+ """Whether this node has the after update of the given ``key``.""" + return key in self.after_updates + + def update(self, *args, **kwargs): + """The function to specify the updating rule. + """ + raise NotImplementedError('Must implement "update" function by subclass self.') + + def reset(self, *args, **kwargs): + """Reset function which reset the whole variables in the model (including its children models). + + ``reset()`` function is a collective behavior which resets all states in this model. + + See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. + """ + global reset_state + if reset_state is None: + from brainpy._src.helpers import reset_state + reset_state(self, *args, **kwargs) + + @not_implemented + def reset_state(self, *args, **kwargs): + """Reset function which resets local states in this model. + + Simply speaking, this function should implement the logic of resetting of + local variables in this node. + + See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. + """ + pass + + def clear_input(self, *args, **kwargs): + """Clear the input at the current time step.""" + pass + + def step_run(self, i, *args, **kwargs): + """The step run function. + + This function can be directly applied to run the dynamical system. + Particularly, ``i`` denotes the running index. + + Args: + i: The current running index. + *args: The arguments of ``update()`` function. + **kwargs: The arguments of ``update()`` function. + + Returns: + out: The update function returns. + """ + global clear_input + if clear_input is None: + from brainpy._src.helpers import clear_input + share.save(i=i, t=i * bm.dt) + out = self.update(*args, **kwargs) + clear_input(self) + return out + + @bm.cls_jit(inline=True) + def jit_step_run(self, i, *args, **kwargs): + """The jitted step function for running. + + Args: + i: The current running index. + *args: The arguments of ``update()`` function. 
+ **kwargs: The arguments of ``update()`` function. + + Returns: + out: The update function returns. + """ + return self.step_run(i, *args, **kwargs) + + @property + def mode(self) -> bm.Mode: + """Mode of the model, which is useful to control the multiple behaviors of the model.""" + return self._mode + + @mode.setter + def mode(self, value): + if not isinstance(value, bm.Mode): + raise ValueError(f'Must be instance of {bm.Mode.__name__}, ' + f'but we got {type(value)}: {value}') + self._mode = value + + def register_local_delay( + self, + var_name: str, + delay_name: str, + delay_time: Union[numbers.Number, ArrayType] = None, + delay_step: Union[numbers.Number, ArrayType] = None, + ): + """Register local relay at the given delay time. + + Args: + var_name: str. The name of the delay target variable. + delay_name: str. The name of the current delay data. + delay_time: The delay time. Float. + delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``. + """ + delay_identifier, init_delay_by_return = _get_delay_tool() + delay_identifier = delay_identifier + var_name + # check whether the "var_name" has been registered + try: + target = getattr(self, var_name) + except AttributeError: + raise AttributeError(f'This node {self} does not has attribute of "{var_name}".') + if not self.has_aft_update(delay_identifier): + # add a model to receive the return of the target model + # moreover, the model should not receive the return of the update function + model = not_receive_update_output(init_delay_by_return(target)) + # register the model + self.add_aft_update(delay_identifier, model) + delay_cls = self.get_aft_update(delay_identifier) + delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step) + + def get_local_delay(self, var_name, delay_name): + """Get the delay at the given identifier (`name`). + + Args: + var_name: The name of the target delay variable. 
+ delay_name: The identifier of the delay. + + Returns: + The delayed data at the given delay position. + """ + delay_identifier, init_delay_by_return = _get_delay_tool() + delay_identifier = delay_identifier + var_name + return self.get_aft_update(delay_identifier).at(delay_name) + + def _compatible_update(self, *args, **kwargs): + update_fun = super().__getattribute__('update') + update_args = tuple(inspect.signature(update_fun).parameters.values()) + + if len(update_args) and update_args[0].name in ['tdi', 'sh', 'sha']: + # define the update function with: + # update(tdi, *args, **kwargs) + # + if len(args) > 0: + if isinstance(args[0], dict) and all([bm.ndim(v) == 0 for v in args[0].values()]): + # define: + # update(tdi, *args, **kwargs) + # call: + # update(tdi, *args, **kwargs) + ret = update_fun(*args, **kwargs) + warnings.warn(_update_deprecate_msg, UserWarning) + else: + # define: + # update(tdi, *args, **kwargs) + # call: + # update(*args, **kwargs) + ret = update_fun(share.get_shargs(), *args, **kwargs) + warnings.warn(_update_deprecate_msg, UserWarning) + else: + if update_args[0].name in kwargs: + # define: + # update(tdi, *args, **kwargs) + # call: + # update(tdi=??, **kwargs) + ret = update_fun(**kwargs) + warnings.warn(_update_deprecate_msg, UserWarning) + else: + # define: + # update(tdi, *args, **kwargs) + # call: + # update(**kwargs) + ret = update_fun(share.get_shargs(), *args, **kwargs) + warnings.warn(_update_deprecate_msg, UserWarning) + return ret - def step_run(self, i, *args, **kwargs): - """The step run function. 
+ try: + ba = inspect.signature(update_fun).bind(*args, **kwargs) + except TypeError: + if len(args) and isinstance(args[0], dict): + # user define ``update()`` function which does not receive the shared argument, + # but do provide these shared arguments when calling ``update()`` function + # ----- + # change + # update(tdi, *args, **kwargs) + # as + # update(*args, **kwargs) + share.save(**args[0]) + ret = update_fun(*args[1:], **kwargs) + warnings.warn(_update_deprecate_msg, UserWarning) + return ret + else: + # user define ``update()`` function which receives the shared argument, + # but not provide these shared arguments when calling ``update()`` function + # ----- + # change + # update(*args, **kwargs) + # as + # update(tdi, *args, **kwargs) + ret = update_fun(share.get_shargs(), *args, **kwargs) + warnings.warn(_update_deprecate_msg, UserWarning) + return ret + else: + if len(args) and isinstance(args[0], dict) and all([bm.ndim(v) == 0 for v in args[0].values()]): + try: + ba = inspect.signature(update_fun).bind(*args[1:], **kwargs) + except TypeError: + pass + else: + # ----- + # define as: + # update(x=None) + # call as + # update(tdi) + share.save(**args[0]) + ret = update_fun(*args[1:], **kwargs) + warnings.warn(_update_deprecate_msg, UserWarning) + return ret + return update_fun(*args, **kwargs) + + def _compatible_reset_state(self, *args, **kwargs): + global the_top_layer_reset_state + the_top_layer_reset_state = False + try: + if hasattr(self.reset_state, '_not_implemented'): + self.reset(*args, **kwargs) + warnings.warn( + ''' + From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_saving_and_loading.html for details. + + 1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use + "bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)". 
+ ".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes). + + 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass. + + ''', + DeprecationWarning + ) + else: + self.reset_state(*args, **kwargs) + finally: + the_top_layer_reset_state = True + + def _get_update_fun(self): + return object.__getattribute__(self, 'update') + + def __getattribute__(self, item): + if item == 'update': + return self._compatible_update # update function compatible with previous ``update()`` function + if item == 'reset_state': + if the_top_layer_reset_state: + return self._compatible_reset_state # reset_state function compatible with previous ``reset_state()`` function + return super().__getattribute__(item) + + def __repr__(self): + return f'{self.name}(mode={self.mode})' + + def __call__(self, *args, **kwargs): + """The shortcut to call ``update`` methods.""" + + # ``before_updates`` + for model in self.before_updates.values(): + if hasattr(model, '_receive_update_input'): + model(*args, **kwargs) + else: + model() + + # update the model self + ret = self.update(*args, **kwargs) + + # ``after_updates`` + for model in self.after_updates.values(): + if hasattr(model, '_not_receive_update_output'): + model() + else: + model(ret) + return ret - This function can be directly applied to run the dynamical system. - Particularly, ``i`` denotes the running index. + def __rrshift__(self, other): + """Support using right shift operator to call modules. - Args: - i: The current running index. - *args: The arguments of ``update()`` function. - **kwargs: The arguments of ``update()`` function. + Examples:: - Returns: - out: The update function returns. 
- """ - global clear_input - if clear_input is None: - from brainpy._src.helpers import clear_input - share.save(i=i, t=i * bm.dt) - out = self.update(*args, **kwargs) - clear_input(self) - return out - - @bm.cls_jit(inline=True) - def jit_step_run(self, i, *args, **kwargs): - """The jitted step function for running. + >>> import brainpy as bp + >>> x = bp.math.random.rand((10, 10)) + >>> l = bp.layers.Activation(bm.tanh) + >>> y = x >> l + """ + return self.__call__(other) - Args: - i: The current running index. - *args: The arguments of ``update()`` function. - **kwargs: The arguments of ``update()`` function. - Returns: - out: The update function returns. - """ - return self.step_run(i, *args, **kwargs) - - @property - def mode(self) -> bm.Mode: - """Mode of the model, which is useful to control the multiple behaviors of the model.""" - return self._mode - - @mode.setter - def mode(self, value): - if not isinstance(value, bm.Mode): - raise ValueError(f'Must be instance of {bm.Mode.__name__}, ' - f'but we got {type(value)}: {value}') - self._mode = value - - def register_local_delay( - self, - var_name: str, - delay_name: str, - delay_time: Union[numbers.Number, ArrayType] = None, - delay_step: Union[numbers.Number, ArrayType] = None, - ): - """Register local relay at the given delay time. +class DynSysGroup(DynamicalSystem, Container): + """A group of :py:class:`~.DynamicalSystem`s in which the updating order does not matter. Args: - var_name: str. The name of the delay target variable. - delay_name: str. The name of the current delay data. - delay_time: The delay time. Float. - delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``. + children_as_tuple: The children objects. + children_as_dict: The children objects. + name: The object name. + mode: The mode which controls the model computation. + child_type: The type of the children object. Default is :py:class:`DynamicalSystem`. 
""" - delay_identifier, init_delay_by_return = _get_delay_tool() - delay_identifier = delay_identifier + var_name - # check whether the "var_name" has been registered - try: - target = getattr(self, var_name) - except AttributeError: - raise AttributeError(f'This node {self} does not has attribute of "{var_name}".') - if not self.has_aft_update(delay_identifier): - # add a model to receive the return of the target model - # moreover, the model should not receive the return of the update function - model = not_receive_update_output(init_delay_by_return(target)) - # register the model - self.add_aft_update(delay_identifier, model) - delay_cls = self.get_aft_update(delay_identifier) - delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step) - - def get_local_delay(self, var_name, delay_name): - """Get the delay at the given identifier (`name`). - Args: - var_name: The name of the target delay variable. - delay_name: The identifier of the delay. + def __init__( + self, + *children_as_tuple, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + child_type: type = DynamicalSystem, + **children_as_dict + ): + super().__init__(name=name, mode=mode) + + # Attribute of "Container" + self.children = bm.node_dict(self.format_elements(child_type, *children_as_tuple, **children_as_dict)) + + def update(self, *args, **kwargs): + """Step function of a network. + + In this update function, the update functions in children systems are + iteratively called. + """ + nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView) + + # update nodes of projections + for node in nodes.subset(Projection).values(): + node() + + # update nodes of dynamics + for node in nodes.subset(Dynamic).values(): + node() + + # update nodes with other types, including delays, ... + for node in nodes.not_subset(Dynamic).not_subset(Projection).values(): + node() + - Returns: - The delayed data at the given delay position. 
+class Network(DynSysGroup): + """A group of :py:class:`~.DynamicalSystem`s which defines the nodes and edges in a network. """ - delay_identifier, init_delay_by_return = _get_delay_tool() - delay_identifier = delay_identifier + var_name - return self.get_aft_update(delay_identifier).at(delay_name) - - def _compatible_update(self, *args, **kwargs): - update_fun = super().__getattribute__('update') - update_args = tuple(inspect.signature(update_fun).parameters.values()) - - if len(update_args) and update_args[0].name in ['tdi', 'sh', 'sha']: - # define the update function with: - # update(tdi, *args, **kwargs) - # - if len(args) > 0: - if isinstance(args[0], dict) and all([bm.ndim(v) == 0 for v in args[0].values()]): - # define: - # update(tdi, *args, **kwargs) - # call: - # update(tdi, *args, **kwargs) - ret = update_fun(*args, **kwargs) - warnings.warn(_update_deprecate_msg, UserWarning) - else: - # define: - # update(tdi, *args, **kwargs) - # call: - # update(*args, **kwargs) - ret = update_fun(share.get_shargs(), *args, **kwargs) - warnings.warn(_update_deprecate_msg, UserWarning) - else: - if update_args[0].name in kwargs: - # define: - # update(tdi, *args, **kwargs) - # call: - # update(tdi=??, **kwargs) - ret = update_fun(**kwargs) - warnings.warn(_update_deprecate_msg, UserWarning) - else: - # define: - # update(tdi, *args, **kwargs) - # call: - # update(**kwargs) - ret = update_fun(share.get_shargs(), *args, **kwargs) - warnings.warn(_update_deprecate_msg, UserWarning) - return ret - - try: - ba = inspect.signature(update_fun).bind(*args, **kwargs) - except TypeError: - if len(args) and isinstance(args[0], dict): - # user define ``update()`` function which does not receive the shared argument, - # but do provide these shared arguments when calling ``update()`` function - # ----- - # change - # update(tdi, *args, **kwargs) - # as - # update(*args, **kwargs) - share.save(**args[0]) - ret = update_fun(*args[1:], **kwargs) - warnings.warn(_update_deprecate_msg, 
UserWarning) - return ret - else: - # user define ``update()`` function which receives the shared argument, - # but not provide these shared arguments when calling ``update()`` function - # ----- - # change - # update(*args, **kwargs) - # as - # update(tdi, *args, **kwargs) - ret = update_fun(share.get_shargs(), *args, **kwargs) - warnings.warn(_update_deprecate_msg, UserWarning) - return ret - else: - if len(args) and isinstance(args[0], dict) and all([bm.ndim(v) == 0 for v in args[0].values()]): - try: - ba = inspect.signature(update_fun).bind(*args[1:], **kwargs) - except TypeError: - pass - else: - # ----- - # define as: - # update(x=None) - # call as - # update(tdi) - share.save(**args[0]) - ret = update_fun(*args[1:], **kwargs) - warnings.warn(_update_deprecate_msg, UserWarning) - return ret - return update_fun(*args, **kwargs) - - def _compatible_reset_state(self, *args, **kwargs): - global the_top_layer_reset_state - the_top_layer_reset_state = False - try: - if hasattr(self.reset_state, '_not_implemented'): - self.reset(*args, **kwargs) - warnings.warn( - ''' - From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_saving_and_loading.html for details. - - 1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use - "bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)". - ".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes). - - 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass. 
- - ''', - DeprecationWarning - ) - else: - self.reset_state(*args, **kwargs) - finally: - the_top_layer_reset_state = True - - def _get_update_fun(self): - return object.__getattribute__(self, 'update') - - def __getattribute__(self, item): - if item == 'update': - return self._compatible_update # update function compatible with previous ``update()`` function - if item == 'reset_state': - if the_top_layer_reset_state: - return self._compatible_reset_state # reset_state function compatible with previous ``reset_state()`` function - return super().__getattribute__(item) - - def __repr__(self): - return f'{self.name}(mode={self.mode})' - - def __call__(self, *args, **kwargs): - """The shortcut to call ``update`` methods.""" - - # ``before_updates`` - for model in self.before_updates.values(): - if hasattr(model, '_receive_update_input'): - model(*args, **kwargs) - else: - model() - - # update the model self - ret = self.update(*args, **kwargs) - - # ``after_updates`` - for model in self.after_updates.values(): - if hasattr(model, '_not_receive_update_output'): - model() - else: - model(ret) - return ret - - def __rrshift__(self, other): - """Support using right shift operator to call modules. + pass + + +class Sequential(DynamicalSystem, SupportAutoDelay, Container): + """A sequential `input-output` module. + + Modules will be added to it in the order they are passed in the + constructor. Alternatively, an ``dict`` of modules can be + passed in. The ``update()`` method of ``Sequential`` accepts any + input and forwards it to the first module it contains. It then + "chains" outputs to inputs sequentially for each subsequent module, + finally returning the output of the last module. 
+ + The value a ``Sequential`` provides over manually calling a sequence + of modules is that it allows treating the whole container as a + single module, such that performing a transformation on the + ``Sequential`` applies to each of the modules it stores (which are + each a registered submodule of the ``Sequential``). + + What's the difference between a ``Sequential`` and a + :py:class:`Container`? A ``Container`` is exactly what it + sounds like--a container to store :py:class:`DynamicalSystem` s! + On the other hand, the layers in a ``Sequential`` are connected + in a cascading way. Examples:: - + >>> import brainpy as bp - >>> x = bp.math.random.rand((10, 10)) - >>> l = bp.layers.Activation(bm.tanh) - >>> y = x >> l + >>> import brainpy.math as bm + >>> + >>> # composing ANN models + >>> l = bp.Sequential(bp.layers.Dense(100, 10), + >>> bm.relu, + >>> bp.layers.Dense(10, 2)) + >>> l(bm.random.random((256, 100))) + >>> + >>> # Using Sequential with Dict. This is functionally the + >>> # same as the above code + >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10), + >>> l2=bm.relu, + >>> l3=bp.layers.Dense(10, 2)) + >>> l(bm.random.random((256, 100))) + + + Args: + modules_as_tuple: The children modules. + modules_as_dict: The children modules. + name: The object name. + mode: The object computing context/mode. Default is ``None``. """ - return self.__call__(other) + def __init__( + self, + *modules_as_tuple, + name: str = None, + mode: bm.Mode = None, + **modules_as_dict + ): + super().__init__(name=name, mode=mode) + + # Attribute of "Container" + self.children = bm.node_dict(self.format_elements(object, *modules_as_tuple, **modules_as_dict)) + + def update(self, x): + """Update function of a sequential model. 
+ """ + for m in self.children.values(): + x = m(x) + return x + + def return_info(self): + last = self[-1] + if not isinstance(last, SupportAutoDelay): + raise UnsupportedError(f'Does not support "return_info()" because the last node is ' + f'not instance of {SupportAutoDelay.__name__}') + return last.return_info() + + def __getitem__(self, key: Union[int, slice, str]): + if isinstance(key, str): + if key in self.children: + return self.children[key] + else: + raise KeyError(f'Does not find a component named {key} in\n {str(self)}') + elif isinstance(key, slice): + return Sequential(**dict(tuple(self.children.items())[key])) + elif isinstance(key, int): + return tuple(self.children.values())[key] + elif isinstance(key, (tuple, list)): + _all_nodes = tuple(self.children.items()) + return Sequential(**dict(_all_nodes[k] for k in key)) + else: + raise KeyError(f'Unknown type of key: {type(key)}') -class DynSysGroup(DynamicalSystem, Container): - """A group of :py:class:`~.DynamicalSystem`s in which the updating order does not matter. - - Args: - children_as_tuple: The children objects. - children_as_dict: The children objects. - name: The object name. - mode: The mode which controls the model computation. - child_type: The type of the children object. Default is :py:class:`DynamicalSystem`. - """ - - def __init__( - self, - *children_as_tuple, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - child_type: type = DynamicalSystem, - **children_as_dict - ): - super().__init__(name=name, mode=mode) - - # Attribute of "Container" - self.children = bm.node_dict(self.format_elements(child_type, *children_as_tuple, **children_as_dict)) - - def update(self, *args, **kwargs): - """Step function of a network. - - In this update function, the update functions in children systems are - iteratively called. 
+ def __repr__(self): + nodes = self.children.values() + entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes)) + return f'{self.__class__.__name__}(\n{entries}\n)' + + +class Projection(DynamicalSystem): + """Base class to model synaptic projections. + + Args: + name: The name of the dynamic system. + mode: The computing mode. It should be an instance of :py:class:`~.Mode`. """ - nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView) - # update nodes of projections - for node in nodes.subset(Projection).values(): - node() + def update(self, *args, **kwargs): + nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()) + if len(nodes): + for node in nodes: + node.update(*args, **kwargs) + else: + raise ValueError('Do not implement the update() function.') - # update nodes of dynamics - for node in nodes.subset(Dynamic).values(): - node() + def clear_input(self, *args, **kwargs): + """Empty function of clearing inputs.""" + pass - # update nodes with other types, including delays, ... - for node in nodes.not_subset(Dynamic).not_subset(Projection).values(): - node() + def reset_state(self, *args, **kwargs): + pass -class Network(DynSysGroup): - """A group of :py:class:`~.DynamicalSystem`s which defines the nodes and edges in a network. - """ - pass +class Dynamic(DynamicalSystem): + """Base class to model dynamics. + There are several essential attributes: -class Sequential(DynamicalSystem, SupportAutoDelay, Container): - """A sequential `input-output` module. - - Modules will be added to it in the order they are passed in the - constructor. Alternatively, an ``dict`` of modules can be - passed in. The ``update()`` method of ``Sequential`` accepts any - input and forwards it to the first module it contains. It then - "chains" outputs to inputs sequentially for each subsequent module, - finally returning the output of the last module. 
- - The value a ``Sequential`` provides over manually calling a sequence - of modules is that it allows treating the whole container as a - single module, such that performing a transformation on the - ``Sequential`` applies to each of the modules it stores (which are - each a registered submodule of the ``Sequential``). - - What's the difference between a ``Sequential`` and a - :py:class:`Container`? A ``Container`` is exactly what it - sounds like--a container to store :py:class:`DynamicalSystem` s! - On the other hand, the layers in a ``Sequential`` are connected - in a cascading way. - - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> - >>> # composing ANN models - >>> l = bp.Sequential(bp.layers.Dense(100, 10), - >>> bm.relu, - >>> bp.layers.Dense(10, 2)) - >>> l(bm.random.random((256, 100))) - >>> - >>> # Using Sequential with Dict. This is functionally the - >>> # same as the above code - >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10), - >>> l2=bm.relu, - >>> l3=bp.layers.Dense(10, 2)) - >>> l(bm.random.random((256, 100))) - - - Args: - modules_as_tuple: The children modules. - modules_as_dict: The children modules. - name: The object name. - mode: The object computing context/mode. Default is ``None``. - """ - - def __init__( - self, - *modules_as_tuple, - name: str = None, - mode: bm.Mode = None, - **modules_as_dict - ): - super().__init__(name=name, mode=mode) - - # Attribute of "Container" - self.children = bm.node_dict(self.format_elements(object, *modules_as_tuple, **modules_as_dict)) - - def update(self, x): - """Update function of a sequential model. + - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of + neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes + a 3-dimensional neuron group. + - ``num``: the flattened number of neurons in the group. 
For example, `size=(10, )` => \ + `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`. + + Args: + size: The neuron group geometry. + name: The name of the dynamic system. + keep_size: Whether keep the geometry information. + mode: The computing mode. """ - for m in self.children.values(): - x = m(x) - return x - - def return_info(self): - last = self[-1] - if not isinstance(last, SupportAutoDelay): - raise UnsupportedError(f'Does not support "return_info()" because the last node is ' - f'not instance of {SupportAutoDelay.__name__}') - return last.return_info() - - def __getitem__(self, key: Union[int, slice, str]): - if isinstance(key, str): - if key in self.children: - return self.children[key] - else: - raise KeyError(f'Does not find a component named {key} in\n {str(self)}') - elif isinstance(key, slice): - return Sequential(**dict(tuple(self.children.items())[key])) - elif isinstance(key, int): - return tuple(self.children.values())[key] - elif isinstance(key, (tuple, list)): - _all_nodes = tuple(self.children.items()) - return Sequential(**dict(_all_nodes[k] for k in key)) - else: - raise KeyError(f'Unknown type of key: {type(key)}') - - def __repr__(self): - nodes = self.children.values() - entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes)) - return f'{self.__class__.__name__}(\n{entries}\n)' + def __init__( + self, + size: Shape, + keep_size: bool = False, + sharding: Optional[Any] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + method: str = 'exp_auto' + ): + # size + if isinstance(size, (list, tuple)): + if len(size) <= 0: + raise ValueError(f'size must be int, or a tuple/list of int. ' + f'But we got {type(size)}') + if not isinstance(size[0], (int, np.integer)): + raise ValueError('size must be int, or a tuple/list of int.' 
+ f'But we got {type(size)}') + size = tuple(size) + elif isinstance(size, (int, np.integer)): + size = (size,) + else: + raise ValueError('size must be int, or a tuple/list of int.' + f'But we got {type(size)}') + self.size = size + self.keep_size = keep_size -class Projection(DynamicalSystem): - """Base class to model synaptic projections. - - Args: - name: The name of the dynamic system. - mode: The computing mode. It should be an instance of :py:class:`~.Mode`. - """ - - def update(self, *args, **kwargs): - nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()) - if len(nodes): - for node in nodes: - node.update(*args, **kwargs) - else: - raise ValueError('Do not implement the update() function.') - - def clear_input(self, *args, **kwargs): - """Empty function of clearing inputs.""" - pass + # number of neurons + self.num = tools.size2num(size) - def reset_state(self, *args, **kwargs): - pass + # axis names for parallelization + self.sharding = sharding + # integration method + self.method = method -class Dynamic(DynamicalSystem): - """Base class to model dynamics. - - There are several essential attributes: - - - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of - neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes - a 3-dimensional neuron group. - - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \ - `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`. - - Args: - size: The neuron group geometry. - name: The name of the dynamic system. - keep_size: Whether keep the geometry information. - mode: The computing mode. 
- """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - sharding: Optional[Any] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - method: str = 'exp_auto' - ): - # size - if isinstance(size, (list, tuple)): - if len(size) <= 0: - raise ValueError(f'size must be int, or a tuple/list of int. ' - f'But we got {type(size)}') - if not isinstance(size[0], (int, np.integer)): - raise ValueError('size must be int, or a tuple/list of int.' - f'But we got {type(size)}') - size = tuple(size) - elif isinstance(size, (int, np.integer)): - size = (size,) - else: - raise ValueError('size must be int, or a tuple/list of int.' - f'But we got {type(size)}') - self.size = size - self.keep_size = keep_size - - # number of neurons - self.num = tools.size2num(size) - - # axis names for parallelization - self.sharding = sharding - - # integration method - self.method = method - - # initialize - super().__init__(name=name, mode=mode) - - @property - def varshape(self): - """The shape of variables in the neuron group.""" - return self.size if self.keep_size else (self.num,) - - def get_batch_shape(self, batch_size=None): - if batch_size is None: - return self.varshape - else: - return (batch_size,) + self.varshape - - def update(self, *args, **kwargs): - """The function to specify the updating rule. - """ - raise NotImplementedError(f'Subclass of {self.__class__.__name__} must ' - f'implement "update" function.') + # initialize + super().__init__(name=name, mode=mode) - def init_param(self, param, shape=None, sharding=None): - """Initialize parameters. + @property + def varshape(self): + """The shape of variables in the neuron group.""" + return self.size if self.keep_size else (self.num,) - If ``sharding`` is provided and ``param`` is array, this function will - partition the parameter across the default device mesh. 
+ def get_batch_shape(self, batch_size=None): + if batch_size is None: + return self.varshape + else: + return (batch_size,) + self.varshape - See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting. - """ - shape = self.varshape if shape is None else shape - sharding = self.sharding if sharding is None else sharding - return parameter(param, - sizes=shape, - allow_none=False, - sharding=sharding) + def update(self, *args, **kwargs): + """The function to specify the updating rule. + """ + raise NotImplementedError(f'Subclass of {self.__class__.__name__} must ' + f'implement "update" function.') - def init_variable(self, var_data, batch_or_mode, shape=None, sharding=None): - """Initialize variables. + def init_param(self, param, shape=None, sharding=None): + """Initialize parameters. - If ``sharding`` is provided and ``var_data`` is array, this function will - partition the variable across the default device mesh. + If ``sharding`` is provided and ``param`` is array, this function will + partition the parameter across the default device mesh. - See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting. - """ - shape = self.varshape if shape is None else shape - sharding = self.sharding if sharding is None else sharding - return variable_(var_data, - sizes=shape, - batch_or_mode=batch_or_mode, - axis_names=sharding, - batch_axis_name=bm.sharding.BATCH_AXIS) - - def __repr__(self): - return f'{self.name}(mode={self.mode}, size={self.size})' - - def __getitem__(self, item): - return DynView(target=self, index=item) - - def clear_input(self, *args, **kwargs): - """Empty function of clearing inputs.""" - pass + See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting. 
+ """ + shape = self.varshape if shape is None else shape + sharding = self.sharding if sharding is None else sharding + return parameter(param, + sizes=shape, + allow_none=False, + sharding=sharding) + + def init_variable(self, var_data, batch_or_mode, shape=None, sharding=None): + """Initialize variables. + + If ``sharding`` is provided and ``var_data`` is array, this function will + partition the variable across the default device mesh. + + See :py:func:`~.brainpy.math.sharding.device_mesh` for the mesh setting. + """ + shape = self.varshape if shape is None else shape + sharding = self.sharding if sharding is None else sharding + return variable_(var_data, + sizes=shape, + batch_or_mode=batch_or_mode, + axis_names=sharding, + batch_axis_name=bm.sharding.BATCH_AXIS) + + def __repr__(self): + return f'{self.name}(mode={self.mode}, size={self.size})' + + def __getitem__(self, item): + return DynView(target=self, index=item) + + def clear_input(self, *args, **kwargs): + """Empty function of clearing inputs.""" + pass class DynView(Dynamic): - """DSView, an object used to get a view of a dynamical system instance. - - It can get a subset view of variables in a dynamical system instance. - For instance, - - >>> import brainpy as bp - >>> hh = bp.neurons.HH(10) - >>> DynView(hh, slice(5, 10, None)) - >>> # or, simply - >>> hh[5:] - """ - - def __init__( - self, - target: Dynamic, - index: Union[slice, Sequence, ArrayType], - name: Optional[str] = None, - ): - # check target - if not isinstance(target, Dynamic): - raise TypeError(f'Should be instance of {Dynamic.__name__}, but we got {type(target)}.') - self.target = target # the target object to slice - - # check slicing - if isinstance(index, (int, slice)): - index = (index,) - self.index = index # the slice - if len(self.index) > len(target.varshape): - raise ValueError(f"Length of the index should be less than " - f"that of the target's varshape. 
But we " - f"got {len(self.index)} > {len(target.varshape)}") - - # get all variables for slicing - if hasattr(self.target, SLICE_VARS): - all_vars = {} - for var_str in getattr(self.target, SLICE_VARS): - v = eval(f'target.{var_str}') - all_vars[var_str] = v - else: - all_vars = target.vars(level=1, include_self=True, method='relative') - all_vars = {k: v for k, v in all_vars.items()} # TODO - # all_vars = {k: v for k, v in all_vars.items() if v.nobatch_shape == varshape} - - # slice variables - self.slice_vars = dict() - for k, v in all_vars.items(): - if v.batch_axis is not None: - index = ( - (self.index[:v.batch_axis] + (slice(None, None, None),) + self.index[v.batch_axis:]) - if (len(self.index) > v.batch_axis) else - (self.index + tuple([slice(None, None, None) for _ in range(v.batch_axis - len(self.index) + 1)])) - ) - else: - index = self.index - self.slice_vars[k] = bm.VariableView(v, index) - - # sub-nodes - # nodes = target.nodes(method='relative', level=0, include_self=True).subset(DynamicalSystem) - # for k, node in nodes.items(): - # if isinstance(node, Dynamic): - # node = DynView(node, self.index) - # else: - # node = DynView(node, self.index) - # setattr(self, k, node) - - # initialization - # get size - size = [] - for i, idx in enumerate(self.index): - if isinstance(idx, int): - size.append(1) - elif isinstance(idx, slice): - size.append(_slice_to_num(idx, target.varshape[i])) - else: - # should be a list/tuple/array of int - # do not check again - if not isinstance(idx, collections.abc.Iterable): - raise TypeError('Should be an iterable object of int.') - size.append(len(idx)) - size += list(target.varshape[len(self.index):]) - - super().__init__(size, keep_size=target.keep_size, name=name, mode=target.mode) - - def __repr__(self): - return f'{self.name}(target={self.target}, index={self.index})' - - def __getattribute__(self, item): - try: - slice_vars = object.__getattribute__(self, 'slice_vars') - if item in slice_vars: - value = 
slice_vars[item] - return value - return object.__getattribute__(self, item) - except AttributeError: - return object.__getattribute__(self, item) - - def __setattr__(self, key, value): - if hasattr(self, 'slice_vars'): - slice_vars = super().__getattribute__('slice_vars') - if key in slice_vars: - v = slice_vars[key] - v.value = value - return - super(DynView, self).__setattr__(key, value) - - def update(self, *args, **kwargs): - raise NoImplementationError(f'{DynView.__name__} {self} cannot be updated. ' - f'Please update its parent {self.target}') - - def reset_state(self, batch_size=None): - pass + """DSView, an object used to get a view of a dynamical system instance. + + It can get a subset view of variables in a dynamical system instance. + For instance, + + >>> import brainpy as bp + >>> hh = bp.neurons.HH(10) + >>> DynView(hh, slice(5, 10, None)) + >>> # or, simply + >>> hh[5:] + """ + + def __init__( + self, + target: Dynamic, + index: Union[slice, Sequence, ArrayType], + name: Optional[str] = None, + ): + # check target + if not isinstance(target, Dynamic): + raise TypeError(f'Should be instance of {Dynamic.__name__}, but we got {type(target)}.') + self.target = target # the target object to slice + + # check slicing + if isinstance(index, (int, slice)): + index = (index,) + self.index = index # the slice + if len(self.index) > len(target.varshape): + raise ValueError(f"Length of the index should be less than " + f"that of the target's varshape. 
But we " + f"got {len(self.index)} > {len(target.varshape)}") + + # get all variables for slicing + if hasattr(self.target, SLICE_VARS): + all_vars = {} + for var_str in getattr(self.target, SLICE_VARS): + v = eval(f'target.{var_str}') + all_vars[var_str] = v + else: + all_vars = target.vars(level=1, include_self=True, method='relative') + all_vars = {k: v for k, v in all_vars.items()} # TODO + # all_vars = {k: v for k, v in all_vars.items() if v.nobatch_shape == varshape} + + # slice variables + self.slice_vars = dict() + for k, v in all_vars.items(): + if v.batch_axis is not None: + index = ( + (self.index[:v.batch_axis] + (slice(None, None, None),) + self.index[v.batch_axis:]) + if (len(self.index) > v.batch_axis) else + (self.index + tuple([slice(None, None, None) for _ in range(v.batch_axis - len(self.index) + 1)])) + ) + else: + index = self.index + self.slice_vars[k] = bm.VariableView(v, index) + + # sub-nodes + # nodes = target.nodes(method='relative', level=0, include_self=True).subset(DynamicalSystem) + # for k, node in nodes.items(): + # if isinstance(node, Dynamic): + # node = DynView(node, self.index) + # else: + # node = DynView(node, self.index) + # setattr(self, k, node) + + # initialization + # get size + size = [] + for i, idx in enumerate(self.index): + if isinstance(idx, int): + size.append(1) + elif isinstance(idx, slice): + size.append(_slice_to_num(idx, target.varshape[i])) + else: + # should be a list/tuple/array of int + # do not check again + if not isinstance(idx, collections.abc.Iterable): + raise TypeError('Should be an iterable object of int.') + size.append(len(idx)) + size += list(target.varshape[len(self.index):]) + + super().__init__(size, keep_size=target.keep_size, name=name, mode=target.mode) + + def __repr__(self): + return f'{self.name}(target={self.target}, index={self.index})' + + def __getattribute__(self, item): + try: + slice_vars = object.__getattribute__(self, 'slice_vars') + if item in slice_vars: + value = 
slice_vars[item] + return value + return object.__getattribute__(self, item) + except AttributeError: + return object.__getattribute__(self, item) + + def __setattr__(self, key, value): + if hasattr(self, 'slice_vars'): + slice_vars = super().__getattribute__('slice_vars') + if key in slice_vars: + v = slice_vars[key] + v.value = value + return + super(DynView, self).__setattr__(key, value) + + def update(self, *args, **kwargs): + raise NoImplementationError(f'{DynView.__name__} {self} cannot be updated. ' + f'Please update its parent {self.target}') + + def reset_state(self, batch_size=None): + pass @tools.numba_jit def _slice_to_num(slice_: slice, length: int): - # start - start = slice_.start - if start is None: - start = 0 - if start < 0: - start = length + start - start = max(start, 0) - # stop - stop = slice_.stop - if stop is None: - stop = length - if stop < 0: - stop = length + stop - stop = min(stop, length) - # step - step = slice_.step - if step is None: - step = 1 - # number - num = 0 - while start < stop: - start += step - num += 1 - return num + # start + start = slice_.start + if start is None: + start = 0 + if start < 0: + start = length + start + start = max(start, 0) + # stop + stop = slice_.stop + if stop is None: + stop = length + if stop < 0: + stop = length + stop + stop = min(stop, length) + # step + step = slice_.step + if step is None: + step = 1 + # number + num = 0 + while start < stop: + start += step + num += 1 + return num def receive_update_output(cls: object): - """ - The decorator to mark the object (as the after updates) to receive the output of the update function. + """ + The decorator to mark the object (as the after updates) to receive the output of the update function. 
- That is, the `aft_update` will receive the return of the update function:: + That is, the `aft_update` will receive the return of the update function:: - ret = model.update(*args, **kwargs) - for fun in model.aft_updates: - fun(ret) + ret = model.update(*args, **kwargs) + for fun in model.aft_updates: + fun(ret) - """ - # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' - if hasattr(cls, '_not_receive_update_output'): - delattr(cls, '_not_receive_update_output') - return cls + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + if hasattr(cls, '_not_receive_update_output'): + delattr(cls, '_not_receive_update_output') + return cls def not_receive_update_output(cls: object): - """ - The decorator to mark the object (as the after updates) to not receive the output of the update function. + """ + The decorator to mark the object (as the after updates) to not receive the output of the update function. - That is, the `aft_update` will not receive the return of the update function:: + That is, the `aft_update` will not receive the return of the update function:: - ret = model.update(*args, **kwargs) - for fun in model.aft_updates: - fun() + ret = model.update(*args, **kwargs) + for fun in model.aft_updates: + fun() - """ - # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' - cls._not_receive_update_output = True - return cls + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + cls._not_receive_update_output = True + return cls def receive_update_input(cls: object): - """ - The decorator to mark the object (as the before updates) to receive the input of the update function. - - That is, the `bef_update` will receive the input of the update function:: + """ + The decorator to mark the object (as the before updates) to receive the input of the update function. 
+ That is, the `bef_update` will receive the input of the update function:: - for fun in model.bef_updates: - fun(*args, **kwargs) - model.update(*args, **kwargs) - """ - # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' - cls._receive_update_input = True - return cls - + for fun in model.bef_updates: + fun(*args, **kwargs) + model.update(*args, **kwargs) -def not_receive_update_input(cls: object): - """ - The decorator to mark the object (as the before updates) to not receive the input of the update function. + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + cls._receive_update_input = True + return cls - That is, the `bef_update` will not receive the input of the update function:: - for fun in model.bef_updates: - fun() - model.update() - - """ - # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' - if hasattr(cls, '_receive_update_input'): - delattr(cls, '_receive_update_input') - return cls - +def not_receive_update_input(cls: object): + """ + The decorator to mark the object (as the before updates) to not receive the input of the update function. - + That is, the `bef_update` will not receive the input of the update function:: + for fun in model.bef_updates: + fun() + model.update() + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' 
+ if hasattr(cls, '_receive_update_input'): + delattr(cls, '_receive_update_input') + return cls diff --git a/brainpy/_src/encoding/__init__.py b/brainpy/_src/encoding/__init__.py index 397151c01..1b67bbc87 100644 --- a/brainpy/_src/encoding/__init__.py +++ b/brainpy/_src/encoding/__init__.py @@ -4,4 +4,3 @@ from .base import * from .stateful_encoding import * from .stateless_encoding import * - diff --git a/brainpy/_src/encoding/base.py b/brainpy/_src/encoding/base.py index d2a53242d..bb2472118 100644 --- a/brainpy/_src/encoding/base.py +++ b/brainpy/_src/encoding/base.py @@ -1,22 +1,20 @@ # -*- coding: utf-8 -*- -import brainpy.math as bm from brainpy._src.math.object_transform.base import BrainPyObject __all__ = [ - 'Encoder' + 'Encoder' ] class Encoder(BrainPyObject): - """Base class for encoding rate values as spike trains.""" + """Base class for encoding rate values as spike trains.""" - def __repr__(self): - return self.__class__.__name__ + def __repr__(self): + return self.__class__.__name__ - def single_step(self, *args, **kwargs): - raise NotImplementedError('Please implement the function for single step encoding.') - - def multi_steps(self, *args, **kwargs): - raise NotImplementedError('Encode implement the function for multiple-step encoding.') + def single_step(self, *args, **kwargs): + raise NotImplementedError('Please implement the function for single step encoding.') + def multi_steps(self, *args, **kwargs): + raise NotImplementedError('Encode implement the function for multiple-step encoding.') diff --git a/brainpy/_src/encoding/stateful_encoding.py b/brainpy/_src/encoding/stateful_encoding.py index 3fca9295e..f97ebc242 100644 --- a/brainpy/_src/encoding/stateful_encoding.py +++ b/brainpy/_src/encoding/stateful_encoding.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- -import math -from typing import Union, Callable, Optional +from typing import Callable, Optional -import jax import numpy as np import brainpy.math as bm @@ -12,186 +10,186 @@ from .base 
import Encoder __all__ = [ - 'LatencyEncoder', - 'WeightedPhaseEncoder', + 'LatencyEncoder', + 'WeightedPhaseEncoder', ] class WeightedPhaseEncoder(Encoder): - r"""Encode the rate input into the spike train according to [1]_. - - The main idea of the weighted spikes is assigning different weights - to different phases (or to spikes in those phases) in order to pack - more information into the spikes. This is the major difference from - a conventional rate coding scheme that assigns the same weight to every spike [1]_. - - Parameters:: - - min_val: float - The minimal value in the given data `x`, used to the data normalization. - max_val: float - The maximum value in the given data `x`, used to the data normalization. - num_phase: int - The number of the encoding period. - weight_fun: Callable - The function to generate weight at the phase :math:`i`. - - References:: - - .. [1] Kim, Jaehyun et al. “Deep neural networks with weighted spikes.” Neurocomputing 311 (2018): 373-386. - """ - - def __init__(self, - min_val: float, - max_val: float, - num_phase: int, - weight_fun: Callable = None): - super().__init__() - - check.is_callable(weight_fun, 'weight_fun', allow_none=True) - self.num_phase = check.is_integer(num_phase, 'num_phase', min_bound=1) - self.min_val = check.is_float(min_val, 'min_val') - self.max_val = check.is_float(max_val, 'max_val') - self.weight_fun = (lambda i: 2 ** (-(i % num_phase + 1))) if weight_fun is None else weight_fun - self.scale = (1 - self.weight_fun(self.num_phase - 1)) / (self.max_val - self.min_val) - - def __call__(self, x: ArrayType, num_step: int): - """Encoding function. + r"""Encode the rate input into the spike train according to [1]_. + + The main idea of the weighted spikes is assigning different weights + to different phases (or to spikes in those phases) in order to pack + more information into the spikes. This is the major difference from + a conventional rate coding scheme that assigns the same weight to every spike [1]_. 
Parameters:: - x: ArrayType - The input rate value. - num_step: int - The number of time steps. + min_val: float + The minimal value in the given data `x`, used to the data normalization. + max_val: float + The maximum value in the given data `x`, used to the data normalization. + num_phase: int + The number of the encoding period. + weight_fun: Callable + The function to generate weight at the phase :math:`i`. - Returns:: + References:: - out: ArrayType - The encoded spike train. + .. [1] Kim, Jaehyun et al. “Deep neural networks with weighted spikes.” Neurocomputing 311 (2018): 373-386. """ - # normalize all input signals to fit into the range [1, 1-2^K] - x = (x - self.min_val) * self.scale - # run - inputs = bm.Variable(x) + def __init__(self, + min_val: float, + max_val: float, + num_phase: int, + weight_fun: Callable = None): + super().__init__() + + check.is_callable(weight_fun, 'weight_fun', allow_none=True) + self.num_phase = check.is_integer(num_phase, 'num_phase', min_bound=1) + self.min_val = check.is_float(min_val, 'min_val') + self.max_val = check.is_float(max_val, 'max_val') + self.weight_fun = (lambda i: 2 ** (-(i % num_phase + 1))) if weight_fun is None else weight_fun + self.scale = (1 - self.weight_fun(self.num_phase - 1)) / (self.max_val - self.min_val) + + def __call__(self, x: ArrayType, num_step: int): + """Encoding function. + + Parameters:: + + x: ArrayType + The input rate value. + num_step: int + The number of time steps. + + Returns:: + + out: ArrayType + The encoded spike train. 
+ """ + # normalize all input signals to fit into the range [1, 1-2^K] + x = (x - self.min_val) * self.scale + + # run + inputs = bm.Variable(x) - def f(i): - w = self.weight_fun(i) - spike = inputs >= w - inputs.value -= w * spike - return spike + def f(i): + w = self.weight_fun(i) + spike = inputs >= w + inputs.value -= w * spike + return spike - return bm.for_loop(f, bm.arange(num_step).value) + return bm.for_loop(f, bm.arange(num_step).value) class LatencyEncoder(Encoder): - r"""Encode the rate input as the spike train using the latency encoding. - - Use input features to determine time-to-first spike. - - Expected inputs should be between 0 and 1. If not, the latency encoder will encode ``x`` - (normalized into ``[0, 1]`` according to - :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`) - to spikes whose firing time is :math:`0 \le t_f \le \text{num_period}-1`. - A larger ``x`` will cause the earlier firing time. - - - Example:: - - >>> a = bm.array([0.02, 0.5, 1]) - >>> encoder = LatencyEncoder(method='linear', normalize=True) - >>> encoder.multi_steps(a, n_time=5) - Array([[0., 0., 1.], - [0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.], - [1., 0., 0.]]) - - - Args: - min_val: float. The minimal value in the given data `x`, used to the data normalization. - max_val: float. The maximum value in the given data `x`, used to the data normalization. - method: str. How to convert intensity to firing time. Currently, we support `linear` or `log`. - - If ``method='linear'``, the firing rate is calculated as - :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`. - - If ``method='log'``, the firing rate is calculated as - :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`, - where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`. - threshold: float. Input features below the threhold will fire at the - final time step unless ``clip=True`` in which case they will not - fire at all, defaults to ``0.01``. 
- clip: bool. Option to remove spikes from features that fall - below the threshold, defaults to ``False``. - tau: float. RC Time constant for LIF model used to calculate - firing time, defaults to ``1``. - normalize: bool. Option to normalize the latency code such that - the final spike(s) occur within num_steps, defaults to ``False``. - epsilon: float. A tiny positive value to avoid rounding errors when - using torch.arange, defaults to ``1e-7``. - """ - - def __init__( - self, - min_val: float = None, - max_val: float = None, - method: str = 'log', - threshold: float = 0.01, - clip: bool = False, - tau: float = 1., - normalize: bool = False, - first_spk_time: float = 0., - epsilon: float = 1e-7, - ): - super().__init__() - - if method not in ['linear', 'log']: - raise ValueError('The conversion method can only be "linear" and "log".') - self.method = method - self.min_val = check.is_float(min_val, 'min_val', allow_none=True) - self.max_val = check.is_float(max_val, 'max_val', allow_none=True) - if threshold < 0 or threshold > 1: - raise ValueError(f"``threshold`` [{threshold}] must be between [0, 1]") - self.threshold = threshold - self.clip = clip - self.tau = tau - self.normalize = normalize - self.first_spk_time = check.is_float(first_spk_time) - self.first_spk_step = int(first_spk_time / bm.get_dt()) - self.epsilon = epsilon - - def single_step(self, x, i_step: int = None): - raise NotImplementedError - - def multi_steps(self, data, n_time: Optional[float] = None): - """Generate latency spikes according to the given input data. - - Ensuring x in [0., 1.]. + r"""Encode the rate input as the spike train using the latency encoding. - Args: - data: The rate-based input. - n_time: float. The total time to generate data. If None, use ``tau`` instead. + Use input features to determine time-to-first spike. + + Expected inputs should be between 0 and 1. 
If not, the latency encoder will encode ``x`` + (normalized into ``[0, 1]`` according to + :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`) + to spikes whose firing time is :math:`0 \le t_f \le \text{num_period}-1`. + A larger ``x`` will cause the earlier firing time. + + + Example:: - Returns: - out: array. The output spiking trains. + >>> a = bm.array([0.02, 0.5, 1]) + >>> encoder = LatencyEncoder(method='linear', normalize=True) + >>> encoder.multi_steps(a, n_time=5) + Array([[0., 0., 1.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [1., 0., 0.]]) + + + Args: + min_val: float. The minimal value in the given data `x`, used to the data normalization. + max_val: float. The maximum value in the given data `x`, used to the data normalization. + method: str. How to convert intensity to firing time. Currently, we support `linear` or `log`. + - If ``method='linear'``, the firing rate is calculated as + :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`. + - If ``method='log'``, the firing rate is calculated as + :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`, + where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`. + threshold: float. Input features below the threhold will fire at the + final time step unless ``clip=True`` in which case they will not + fire at all, defaults to ``0.01``. + clip: bool. Option to remove spikes from features that fall + below the threshold, defaults to ``False``. + tau: float. RC Time constant for LIF model used to calculate + firing time, defaults to ``1``. + normalize: bool. Option to normalize the latency code such that + the final spike(s) occur within num_steps, defaults to ``False``. + epsilon: float. A tiny positive value to avoid rounding errors when + using torch.arange, defaults to ``1e-7``. 
""" - if n_time is None: - n_time = self.tau - tau = n_time if self.normalize else self.tau - x = data - if self.min_val is not None and self.max_val is not None: - x = (x - self.min_val) / (self.max_val - self.min_val) - if self.method == 'linear': - spike_time = (tau - self.first_spk_time - bm.dt) * (1 - x) + self.first_spk_time - - elif self.method == 'log': - x = bm.maximum(x, self.threshold + self.epsilon) # saturates all values below threshold. - spike_time = (tau - self.first_spk_time - bm.dt) * bm.log(x / (x - self.threshold)) + self.first_spk_time - - else: - raise ValueError(f'Unsupported method: {self.method}. Only support "log" and "linear".') - - if self.clip: - spike_time = bm.where(data < self.threshold, np.inf, spike_time) - spike_steps = bm.round(spike_time / bm.get_dt()).astype(int) - return bm.one_hot(spike_steps, num_classes=int(n_time / bm.get_dt()), axis=0, dtype=x.dtype) + + def __init__( + self, + min_val: float = None, + max_val: float = None, + method: str = 'log', + threshold: float = 0.01, + clip: bool = False, + tau: float = 1., + normalize: bool = False, + first_spk_time: float = 0., + epsilon: float = 1e-7, + ): + super().__init__() + + if method not in ['linear', 'log']: + raise ValueError('The conversion method can only be "linear" and "log".') + self.method = method + self.min_val = check.is_float(min_val, 'min_val', allow_none=True) + self.max_val = check.is_float(max_val, 'max_val', allow_none=True) + if threshold < 0 or threshold > 1: + raise ValueError(f"``threshold`` [{threshold}] must be between [0, 1]") + self.threshold = threshold + self.clip = clip + self.tau = tau + self.normalize = normalize + self.first_spk_time = check.is_float(first_spk_time) + self.first_spk_step = int(first_spk_time / bm.get_dt()) + self.epsilon = epsilon + + def single_step(self, x, i_step: int = None): + raise NotImplementedError + + def multi_steps(self, data, n_time: Optional[float] = None): + """Generate latency spikes according to the given 
input data. + + Ensuring x in [0., 1.]. + + Args: + data: The rate-based input. + n_time: float. The total time to generate data. If None, use ``tau`` instead. + + Returns: + out: array. The output spiking trains. + """ + if n_time is None: + n_time = self.tau + tau = n_time if self.normalize else self.tau + x = data + if self.min_val is not None and self.max_val is not None: + x = (x - self.min_val) / (self.max_val - self.min_val) + if self.method == 'linear': + spike_time = (tau - self.first_spk_time - bm.dt) * (1 - x) + self.first_spk_time + + elif self.method == 'log': + x = bm.maximum(x, self.threshold + self.epsilon) # saturates all values below threshold. + spike_time = (tau - self.first_spk_time - bm.dt) * bm.log(x / (x - self.threshold)) + self.first_spk_time + + else: + raise ValueError(f'Unsupported method: {self.method}. Only support "log" and "linear".') + + if self.clip: + spike_time = bm.where(data < self.threshold, np.inf, spike_time) + spike_steps = bm.round(spike_time / bm.get_dt()).astype(int) + return bm.one_hot(spike_steps, num_classes=int(n_time / bm.get_dt()), axis=0, dtype=x.dtype) diff --git a/brainpy/_src/encoding/stateless_encoding.py b/brainpy/_src/encoding/stateless_encoding.py index 5410d736c..6eebbf7dd 100644 --- a/brainpy/_src/encoding/stateless_encoding.py +++ b/brainpy/_src/encoding/stateless_encoding.py @@ -7,183 +7,183 @@ from .base import Encoder __all__ = [ - 'PoissonEncoder', - 'DiffEncoder', + 'PoissonEncoder', + 'DiffEncoder', ] class PoissonEncoder(Encoder): - r"""Encode the rate input as the Poisson spike train. + r"""Encode the rate input as the Poisson spike train. - Expected inputs should be between 0 and 1. If not, the input :math:`x` will be - normalized to :math:`x_{\text{normalize}}` within ``[0, 1]`` according - to :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`. + Expected inputs should be between 0 and 1. 
If not, the input :math:`x` will be + normalized to :math:`x_{\text{normalize}}` within ``[0, 1]`` according + to :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`. - Given the input :math:`x`, the poisson encoder will output - spikes whose firing probability is :math:`x_{\text{normalize}}`. + Given the input :math:`x`, the poisson encoder will output + spikes whose firing probability is :math:`x_{\text{normalize}}`. - Examples:: + Examples:: - import brainpy as bp - import brainpy.math as bm + import brainpy as bp + import brainpy.math as bm - img = bm.random.random((10, 2)) # image to encode (normalized to [0., 1.]) - encoder = bp.encoding.PoissonEncoder() # the encoder + img = bm.random.random((10, 2)) # image to encode (normalized to [0., 1.]) + encoder = bp.encoding.PoissonEncoder() # the encoder - # encode the image at each time - for run_index in range(100): - spike = encoder.single_step(img) - # do something + # encode the image at each time + for run_index in range(100): + spike = encoder.single_step(img) + # do something - # or, encode the image at multiple times once - spikes = encoder.multi_steps(img, n_time=10.) + # or, encode the image at multiple times once + spikes = encoder.multi_steps(img, n_time=10.) - Args: - min_val: float. The minimal value in the given data `x`, used to the data normalization. - max_val: float. The maximum value in the given data `x`, used to the data normalization. - gain: float. Scale input features by the gain, defaults to ``1``. - offset: float. Shift input features by the offset, defaults to ``0``. - first_spk_time: float. The time to first spike, defaults to ``0``. 
- """ - - def __init__( - self, - min_val: Optional[float] = None, - max_val: Optional[float] = None, - gain: float = 1.0, - offset: float = 0.0, - first_spk_time: float = 0., - ): - super().__init__() - - self.min_val = check.is_float(min_val, 'min_val', allow_none=True) - self.max_val = check.is_float(max_val, 'max_val', allow_none=True) - self.gain = check.is_float(gain, allow_none=False) - self.offset = check.is_float(offset, allow_none=False) - self.first_spk_time = check.is_float(first_spk_time) - self.first_spk_step = int(self.first_spk_time / bm.get_dt()) + Args: + min_val: float. The minimal value in the given data `x`, used to the data normalization. + max_val: float. The maximum value in the given data `x`, used to the data normalization. + gain: float. Scale input features by the gain, defaults to ``1``. + offset: float. Shift input features by the offset, defaults to ``0``. + first_spk_time: float. The time to first spike, defaults to ``0``. + """ - def single_step(self, x, i_step: int = None): - """Generate spikes at the single step according to the inputs. + def __init__( + self, + min_val: Optional[float] = None, + max_val: Optional[float] = None, + gain: float = 1.0, + offset: float = 0.0, + first_spk_time: float = 0., + ): + super().__init__() + + self.min_val = check.is_float(min_val, 'min_val', allow_none=True) + self.max_val = check.is_float(max_val, 'max_val', allow_none=True) + self.gain = check.is_float(gain, allow_none=False) + self.offset = check.is_float(offset, allow_none=False) + self.first_spk_time = check.is_float(first_spk_time) + self.first_spk_step = int(self.first_spk_time / bm.get_dt()) + + def single_step(self, x, i_step: int = None): + """Generate spikes at the single step according to the inputs. + + Args: + x: Array. The rate input. + i_step: int. The time step to generate spikes. + + Returns: + out: Array. The encoded spike train. 
+ """ + if i_step is None: + return self.multi_steps(x, n_time=None) + else: + return bm.cond(bm.as_jax(i_step < self.first_spk_step), self._zero_out, self.multi_steps, x) + + def multi_steps(self, x, n_time: Optional[float]): + """Generate spikes at multiple steps according to the inputs. + + Args: + x: Array. The rate input. + n_time: float. Encode rate values as spike trains in the given time length. + ``n_time`` is converted into the ``n_step`` according to `n_step = int(n_time / brainpy.math.dt)`. + - If ``n_time=None``, encode the rate values at the current time step. + Users should repeatedly call it to encode `x` as a spike train. + - Else, given the ``x`` with shape ``(S, ...)``, the encoded + spike train is the array with shape ``(n_step, S, ...)``. + + Returns: + out: Array. The encoded spike train. + """ + n_time = int(n_time / bm.get_dt()) + + if (self.min_val is not None) and (self.max_val is not None): + x = (x - self.min_val) / (self.max_val - self.min_val) + x = x * self.gain + self.offset + if n_time is not None and self.first_spk_step > 0: + pre = bm.zeros((self.first_spk_step,) + x.shape, dtype=x.dtype) + shape = ((n_time - self.first_spk_step,) + x.shape) + post = bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) + return bm.cat([pre, post], axis=0) + else: + shape = x.shape if (n_time is None) else ((n_time - self.first_spk_step,) + x.shape) + return bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) + + def _zero_out(self, x): + return bm.zeros_like(x) - Args: - x: Array. The rate input. - i_step: int. The time step to generate spikes. - Returns: - out: Array. The encoded spike train. - """ - if i_step is None: - return self.multi_steps(x, n_time=None) - else: - return bm.cond(bm.as_jax(i_step < self.first_spk_step), self._zero_out, self.multi_steps, x) +class DiffEncoder(Encoder): + """Generate spike only when the difference between two subsequent + time steps meets a threshold. 
- def multi_steps(self, x, n_time: Optional[float]): - """Generate spikes at multiple steps according to the inputs. + Optionally include `off_spikes` for negative changes. - Args: - x: Array. The rate input. - n_time: float. Encode rate values as spike trains in the given time length. - ``n_time`` is converted into the ``n_step`` according to `n_step = int(n_time / brainpy.math.dt)`. - - If ``n_time=None``, encode the rate values at the current time step. - Users should repeatedly call it to encode `x` as a spike train. - - Else, given the ``x`` with shape ``(S, ...)``, the encoded - spike train is the array with shape ``(n_step, S, ...)``. - - Returns: - out: Array. The encoded spike train. - """ - n_time = int(n_time / bm.get_dt()) + Example:: - if (self.min_val is not None) and (self.max_val is not None): - x = (x - self.min_val) / (self.max_val - self.min_val) - x = x * self.gain + self.offset - if n_time is not None and self.first_spk_step > 0: - pre = bm.zeros((self.first_spk_step,) + x.shape, dtype=x.dtype) - shape = ((n_time - self.first_spk_step,) + x.shape) - post = bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) - return bm.cat([pre, post], axis=0) - else: - shape = x.shape if (n_time is None) else ((n_time - self.first_spk_step,) + x.shape) - return bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) + >>> a = bm.array([1, 2, 2.9, 3, 3.9]) + >>> encoder = DiffEncoder(threshold=1) + >>> encoder.multi_steps(a) + Array([1., 0., 0., 0.]) - def _zero_out(self, x): - return bm.zeros_like(x) + >>> encoder = DiffEncoder(threshold=1, padding=True) + >>> encoder.multi_steps(a) + Array([0., 1., 0., 0., 0.]) + >>> b = bm.array([1, 2, 0, 2, 2.9]) + >>> encoder = DiffEncoder(threshold=1, off_spike=True) + >>> encoder.multi_steps(b) + Array([ 1., 1., -1., 1., 0.]) -class DiffEncoder(Encoder): - """Generate spike only when the difference between two subsequent - time steps meets a threshold. - - Optionally include `off_spikes` for negative changes. 
- - Example:: - - >>> a = bm.array([1, 2, 2.9, 3, 3.9]) - >>> encoder = DiffEncoder(threshold=1) - >>> encoder.multi_steps(a) - Array([1., 0., 0., 0.]) - - >>> encoder = DiffEncoder(threshold=1, padding=True) - >>> encoder.multi_steps(a) - Array([0., 1., 0., 0., 0.]) - - >>> b = bm.array([1, 2, 0, 2, 2.9]) - >>> encoder = DiffEncoder(threshold=1, off_spike=True) - >>> encoder.multi_steps(b) - Array([ 1., 1., -1., 1., 0.]) - - >>> encoder = DiffEncoder(threshold=1, padding=True, off_spike=True) - >>> encoder.multi_steps(b) - Array([ 0., 1., -1., 1., 0.]) - - Args: - threshold: float. Input features with a change greater than the thresold - across one timestep will generate a spike, defaults to ``0.1``. - padding: bool. Used to change how the first time step of spikes are - measured. If ``True``, the first time step will be repeated with itself - resulting in ``0``'s for the output spikes. - If ``False``, the first time step will be padded with ``0``'s, defaults - to ``False``. - off_spike: bool. If ``True``, negative spikes for changes less than - ``-threshold``, defaults to ``False``. - """ - - def __init__( - self, - threshold: float = 0.1, - padding: bool = False, - off_spike: bool = False, - ): - super().__init__() - - self.threshold = threshold - self.padding = padding - self.off_spike = off_spike - - def single_step(self, *args, **kwargs): - raise NotImplementedError(f'{DiffEncoder.__class__.__name__} does not support single-step encoding.') - - def multi_steps(self, x): - """Encoding multistep inputs with the spiking trains. + >>> encoder = DiffEncoder(threshold=1, padding=True, off_spike=True) + >>> encoder.multi_steps(b) + Array([ 0., 1., -1., 1., 0.]) Args: - x: Array. The array with the shape of `(num_step, ....)`. - - Returns: - out: Array. The spike train. + threshold: float. Input features with a change greater than the thresold + across one timestep will generate a spike, defaults to ``0.1``. + padding: bool. 
Used to change how the first time step of spikes are + measured. If ``True``, the first time step will be repeated with itself + resulting in ``0``'s for the output spikes. + If ``False``, the first time step will be padded with ``0``'s, defaults + to ``False``. + off_spike: bool. If ``True``, negative spikes for changes less than + ``-threshold``, defaults to ``False``. """ - if self.padding: - diff = bm.diff(x, axis=0, prepend=x[:1]) - else: - diff = bm.diff(x, axis=0, prepend=bm.zeros((1,) + x.shape[1:], dtype=x.dtype)) - - if self.off_spike: - on_spk = bm.asarray(diff >= self.threshold, dtype=x.dtype) - off_spk = -bm.asarray(diff <= -self.threshold, dtype=x.dtype) - return on_spk + off_spk - - else: - return bm.asarray(diff >= self.threshold, dtype=x.dtype) + + def __init__( + self, + threshold: float = 0.1, + padding: bool = False, + off_spike: bool = False, + ): + super().__init__() + + self.threshold = threshold + self.padding = padding + self.off_spike = off_spike + + def single_step(self, *args, **kwargs): + raise NotImplementedError(f'{DiffEncoder.__class__.__name__} does not support single-step encoding.') + + def multi_steps(self, x): + """Encoding multistep inputs with the spiking trains. + + Args: + x: Array. The array with the shape of `(num_step, ....)`. + + Returns: + out: Array. The spike train. 
+ """ + if self.padding: + diff = bm.diff(x, axis=0, prepend=x[:1]) + else: + diff = bm.diff(x, axis=0, prepend=bm.zeros((1,) + x.shape[1:], dtype=x.dtype)) + + if self.off_spike: + on_spk = bm.asarray(diff >= self.threshold, dtype=x.dtype) + off_spk = -bm.asarray(diff <= -self.threshold, dtype=x.dtype) + return on_spk + off_spk + + else: + return bm.asarray(diff >= self.threshold, dtype=x.dtype) diff --git a/brainpy/_src/encoding/tests/test_stateless_encoding.py b/brainpy/_src/encoding/tests/test_stateless_encoding.py index 8dcd19cd9..8991cf17a 100644 --- a/brainpy/_src/encoding/tests/test_stateless_encoding.py +++ b/brainpy/_src/encoding/tests/test_stateless_encoding.py @@ -1,79 +1,75 @@ import unittest -import brainpy.math as bm + import brainpy as bp +import brainpy.math as bm class TestDiffEncoder(unittest.TestCase): - def test_delta(self): - a = bm.array([1, 2, 2.9, 3, 3.9]) - encoder = bp.encoding.DiffEncoder(threshold=1) - r = encoder.multi_steps(a) - excepted = bm.asarray([1., 1., 0., 0., 0.]) - self.assertTrue(bm.allclose(r, excepted)) + def test_delta(self): + a = bm.array([1, 2, 2.9, 3, 3.9]) + encoder = bp.encoding.DiffEncoder(threshold=1) + r = encoder.multi_steps(a) + excepted = bm.asarray([1., 1., 0., 0., 0.]) + self.assertTrue(bm.allclose(r, excepted)) - encoder = bp.encoding.DiffEncoder(threshold=1, padding=True) - r = encoder.multi_steps(a) - excepted = bm.asarray([0., 1., 0., 0., 0.]) - self.assertTrue(bm.allclose(r, excepted)) + encoder = bp.encoding.DiffEncoder(threshold=1, padding=True) + r = encoder.multi_steps(a) + excepted = bm.asarray([0., 1., 0., 0., 0.]) + self.assertTrue(bm.allclose(r, excepted)) - + def test_delta_off_spike(self): + b = bm.array([1, 2, 0, 2, 2.9]) + encoder = bp.encoding.DiffEncoder(threshold=1, off_spike=True) + r = encoder.multi_steps(b) + excepted = bm.asarray([1., 1., -1., 1., 0.]) + self.assertTrue(bm.allclose(r, excepted)) - def test_delta_off_spike(self): - b = bm.array([1, 2, 0, 2, 2.9]) - encoder = 
bp.encoding.DiffEncoder(threshold=1, off_spike=True) - r = encoder.multi_steps(b) - excepted = bm.asarray([1., 1., -1., 1., 0.]) - self.assertTrue(bm.allclose(r, excepted)) - - encoder = bp.encoding.DiffEncoder(threshold=1, padding=True, off_spike=True) - r = encoder.multi_steps(b) - excepted = bm.asarray([0., 1., -1., 1., 0.]) - self.assertTrue(bm.allclose(r, excepted)) - - + encoder = bp.encoding.DiffEncoder(threshold=1, padding=True, off_spike=True) + r = encoder.multi_steps(b) + excepted = bm.asarray([0., 1., -1., 1., 0.]) + self.assertTrue(bm.allclose(r, excepted)) class TestLatencyEncoder(unittest.TestCase): - def test_latency(self): - a = bm.array([0.02, 0.5, 1]) - encoder = bp.encoding.LatencyEncoder(method='linear') - - r = encoder.multi_steps(a, n_time=0.5) - excepted = bm.asarray( - [[0., 0., 1.], - [0., 0., 0.], - [0., 0., 0.], - [0., 0., 0.], - [0., 1., 0.], - ] - ) - self.assertTrue(bm.allclose(r, excepted)) + def test_latency(self): + a = bm.array([0.02, 0.5, 1]) + encoder = bp.encoding.LatencyEncoder(method='linear') - r = encoder.multi_steps(a, n_time=1.0) - excepted = bm.asarray( - [[0., 0., 1.], - [0., 0., 0.], - [0., 0., 0.], - [0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.], - [0., 0., 0.], - [0., 0., 0.], - [0., 0., 0.], - [1., 0., 0.], - ] - ) - self.assertTrue(bm.allclose(r, excepted)) + r = encoder.multi_steps(a, n_time=0.5) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 1., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) - encoder = bp.encoding.LatencyEncoder(method='linear', normalize=True) - r = encoder.multi_steps(a, n_time=0.5) - excepted = bm.asarray( - [[0., 0., 1.], - [0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.], - [1., 0., 0.], - ] - ) - self.assertTrue(bm.allclose(r, excepted)) + r = encoder.multi_steps(a, n_time=1.0) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 
0., 0.], + [1., 0., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) + encoder = bp.encoding.LatencyEncoder(method='linear', normalize=True) + r = encoder.multi_steps(a, n_time=0.5) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [1., 0., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) diff --git a/brainpy/_src/helpers.py b/brainpy/_src/helpers.py index ab0a306e9..f6ea20552 100644 --- a/brainpy/_src/helpers.py +++ b/brainpy/_src/helpers.py @@ -6,120 +6,118 @@ from brainpy._src.math.object_transform.base import StateLoadResult __all__ = [ - 'reset_level', - 'reset_state', - 'load_state', - 'save_state', - 'clear_input', + 'reset_level', + 'reset_state', + 'load_state', + 'save_state', + 'clear_input', ] - _max_level = 10 def reset_level(level: int = 0): - """The decorator for indicating the resetting level. + """The decorator for indicating the resetting level. - The function takes an optional integer argument level with a default value of 0. + The function takes an optional integer argument level with a default value of 0. - The lower the level, the earlier the function is called. + The lower the level, the earlier the function is called. - >>> import brainpy as bp - >>> bp.reset_level(0) - >>> bp.reset_level(-1) - >>> bp.reset_level(-2) + >>> import brainpy as bp + >>> bp.reset_level(0) + >>> bp.reset_level(-1) + >>> bp.reset_level(-2) - """ - if level < 0: - level = _max_level + level - if level < 0 or level >= _max_level: - raise ValueError(f'"reset_level" must be an integer in [0, 10). but we got {level}') + """ + if level < 0: + level = _max_level + level + if level < 0 or level >= _max_level: + raise ValueError(f'"reset_level" must be an integer in [0, 10). 
but we got {level}') - def wrap(fun: Callable): - fun.reset_level = level - return fun + def wrap(fun: Callable): + fun.reset_level = level + return fun - return wrap + return wrap def reset_state(target: DynamicalSystem, *args, **kwargs): - """Reset states of all children nodes in the given target. + """Reset states of all children nodes in the given target. - See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. + See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. - Args: - target: The target DynamicalSystem. - """ - dynsys.the_top_layer_reset_state = False + Args: + target: The target DynamicalSystem. + """ + dynsys.the_top_layer_reset_state = False - try: - nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values()) - nodes_with_level = [] + try: + nodes = list(target.nodes().subset(DynamicalSystem).not_subset(DynView).not_subset(IonChaDyn).unique().values()) + nodes_with_level = [] - # reset node whose `reset_state` has no `reset_level` - for node in nodes: - if not hasattr(node.reset_state, 'reset_level'): - node.reset_state(*args, **kwargs) - else: - nodes_with_level.append(node) + # reset node whose `reset_state` has no `reset_level` + for node in nodes: + if not hasattr(node.reset_state, 'reset_level'): + node.reset_state(*args, **kwargs) + else: + nodes_with_level.append(node) - # reset the node's states - for l in range(_max_level): - for node in nodes_with_level: - if node.reset_state.reset_level == l: - node.reset_state(*args, **kwargs) + # reset the node's states + for l in range(_max_level): + for node in nodes_with_level: + if node.reset_state.reset_level == l: + node.reset_state(*args, **kwargs) - finally: - dynsys.the_top_layer_reset_state = True + finally: + dynsys.the_top_layer_reset_state = True def clear_input(target: DynamicalSystem, *args, **kwargs): - """Clear all inputs in the given target. 
+ """Clear all inputs in the given target. - Args: - target:The target DynamicalSystem. + Args: + target:The target DynamicalSystem. - """ - for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).unique().values(): - node.clear_input(*args, **kwargs) + """ + for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).unique().values(): + node.clear_input(*args, **kwargs) def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs): - """Copy parameters and buffers from :attr:`state_dict` into - this module and its descendants. - - Args: - target: DynamicalSystem. The dynamical system to load its states. - state_dict: dict. A dict containing parameters and persistent buffers. - - Returns: - ------- - ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: - - * **missing_keys** is a list of str containing the missing keys - * **unexpected_keys** is a list of str containing the unexpected keys - """ - nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique() - missing_keys = [] - unexpected_keys = [] - for name, node in nodes.items(): - r = node.load_state(state_dict[name], **kwargs) - if r is not None: - missing, unexpected = r - missing_keys.extend([f'{name}.{key}' for key in missing]) - unexpected_keys.extend([f'{name}.{key}' for key in unexpected]) - return StateLoadResult(missing_keys, unexpected_keys) + """Copy parameters and buffers from :attr:`state_dict` into + this module and its descendants. + + Args: + target: DynamicalSystem. The dynamical system to load its states. + state_dict: dict. A dict containing parameters and persistent buffers. 
+ + Returns: + ------- + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + """ + nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique() + missing_keys = [] + unexpected_keys = [] + for name, node in nodes.items(): + r = node.load_state(state_dict[name], **kwargs) + if r is not None: + missing, unexpected = r + missing_keys.extend([f'{name}.{key}' for key in missing]) + unexpected_keys.extend([f'{name}.{key}' for key in unexpected]) + return StateLoadResult(missing_keys, unexpected_keys) def save_state(target: DynamicalSystem, **kwargs) -> Dict: - """Save all states in the ``target`` as a dictionary for later disk serialization. - - Args: - target: DynamicalSystem. The node to save its states. + """Save all states in the ``target`` as a dictionary for later disk serialization. - Returns: - Dict. The state dict for serialization. - """ - nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique() # retrieve all nodes - return {key: node.save_state(**kwargs) for key, node in nodes.items()} + Args: + target: DynamicalSystem. The node to save its states. + Returns: + Dict. The state dict for serialization. 
+ """ + nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique() # retrieve all nodes + return {key: node.save_state(**kwargs) for key, node in nodes.items()} diff --git a/brainpy/_src/initialize/__init__.py b/brainpy/_src/initialize/__init__.py index f01a161ca..75dc0793d 100644 --- a/brainpy/_src/initialize/__init__.py +++ b/brainpy/_src/initialize/__init__.py @@ -1,9 +1,8 @@ # -*- coding: utf-8 -*- - from .base import * +from .decay_inits import * from .generic import * from .random_inits import * from .regular_inits import * -from .decay_inits import * diff --git a/brainpy/_src/initialize/base.py b/brainpy/_src/initialize/base.py index 0da0e1c9e..77ae21723 100644 --- a/brainpy/_src/initialize/base.py +++ b/brainpy/_src/initialize/base.py @@ -3,28 +3,28 @@ import abc __all__ = [ - 'Initializer', - '_InterLayerInitializer', - '_IntraLayerInitializer' + 'Initializer', + '_InterLayerInitializer', + '_IntraLayerInitializer' ] class Initializer(abc.ABC): - """Base Initialization Class.""" + """Base Initialization Class.""" - @abc.abstractmethod - def __call__(self, shape, dtype=None): - raise NotImplementedError + @abc.abstractmethod + def __call__(self, shape, dtype=None): + raise NotImplementedError - def __repr__(self): - return self.__class__.__name__ + def __repr__(self): + return self.__class__.__name__ class _InterLayerInitializer(Initializer): - """The superclass of Initializers that initialize the weights between two layers.""" - pass + """The superclass of Initializers that initialize the weights between two layers.""" + pass class _IntraLayerInitializer(Initializer): - """The superclass of Initializers that initialize the weights within a layer.""" - pass + """The superclass of Initializers that initialize the weights within a layer.""" + pass diff --git a/brainpy/_src/initialize/decay_inits.py b/brainpy/_src/initialize/decay_inits.py index 1d18422fb..a9c6ece74 100644 --- a/brainpy/_src/initialize/decay_inits.py +++ 
b/brainpy/_src/initialize/decay_inits.py @@ -1,165 +1,165 @@ # -*- coding: utf-8 -*- -import numpy as np +from functools import partial +import numpy as np from jax import vmap, jit, numpy as jnp -from functools import partial from brainpy import math as bm from brainpy.tools import to_size, size2num from .base import _IntraLayerInitializer __all__ = [ - 'GaussianDecay', - 'DOGDecay', + 'GaussianDecay', + 'DOGDecay', ] @jit @partial(vmap, in_axes=(0, None, None)) def gaussian_decay_dist_cal1(i_value, post_values, sigma): - dists = jnp.abs(i_value - post_values) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) + dists = jnp.abs(i_value - post_values) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) @jit @partial(vmap, in_axes=(0, None, None, None)) def gaussian_decay_dist_cal2(i_value, post_values, value_sizes, sigma): - dists = jnp.abs(i_value - post_values) - dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) + dists = jnp.abs(i_value - post_values) + dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) class GaussianDecay(_IntraLayerInitializer): - r"""Builds a Gaussian connectivity pattern within a population of neurons, - where the weights decay with gaussian function. - - Specifically, for any pair of neurons :math:`(i, j)`, the weight is computed as - - .. math:: - - w(i, j) = w_{max} \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) - - where :math:`v_k^i` is the $i$-th neuron's encoded value at dimension $k$. - - Parameters:: - - sigma : float - Width of the Gaussian function. - max_w : float - The weight amplitude of the Gaussian function. 
- min_w : float, None - The minimum weight value below which synapses are not created (default: :math:`0.005 * max\_w`). - include_self : bool - Whether create the conn at the same position. - encoding_values : optional, list, tuple, int, float - The value ranges to encode for neurons at each axis. - - - If `values` is not provided, the neuron only encodes each positional - information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is - the index in the high-dimensional space. - - If `values` is a single tuple/list of int/float, neurons at each dimension - will encode the same range of values. For example, `values=(0, np.pi)`, - neurons at each dimension will encode a continuous value space `[0, np.pi]`. - - If `values` is a tuple/list of list/tuple, it means the value space will be - different for each dimension. For example, `values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))`. - periodic_boundary : bool - Whether the neuron encode the value space with the periodic boundary. - normalize : bool - Whether normalize the connection probability. - """ - - def __init__(self, sigma, max_w, min_w=None, encoding_values=None, - periodic_boundary=False, include_self=True, normalize=False): - super(GaussianDecay, self).__init__() - self.sigma = sigma - self.max_w = max_w - self.min_w = max_w * 0.005 if min_w is None else min_w - self.encoding_values = encoding_values - self.periodic_boundary = periodic_boundary - self.include_self = include_self - self.normalize = normalize - - def __call__(self, shape, dtype=None): - """Build the weights. + r"""Builds a Gaussian connectivity pattern within a population of neurons, + where the weights decay with gaussian function. + + Specifically, for any pair of neurons :math:`(i, j)`, the weight is computed as + + .. math:: + + w(i, j) = w_{max} \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) + + where :math:`v_k^i` is the $i$-th neuron's encoded value at dimension $k$. 
Parameters:: - shape : tuple of int, list of int, int - The network shape. Note, this is not the weight shape. + sigma : float + Width of the Gaussian function. + max_w : float + The weight amplitude of the Gaussian function. + min_w : float, None + The minimum weight value below which synapses are not created (default: :math:`0.005 * max\_w`). + include_self : bool + Whether create the conn at the same position. + encoding_values : optional, list, tuple, int, float + The value ranges to encode for neurons at each axis. + + - If `values` is not provided, the neuron only encodes each positional + information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is + the index in the high-dimensional space. + - If `values` is a single tuple/list of int/float, neurons at each dimension + will encode the same range of values. For example, `values=(0, np.pi)`, + neurons at each dimension will encode a continuous value space `[0, np.pi]`. + - If `values` is a tuple/list of list/tuple, it means the value space will be + different for each dimension. For example, `values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))`. + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. + normalize : bool + Whether normalize the connection probability. 
""" - shape = to_size(shape) - net_size = size2num(shape) - - # value ranges to encode - if self.encoding_values is None: - value_ranges = tuple([(0, s) for s in shape]) - elif isinstance(self.encoding_values, (tuple, list)): - if len(self.encoding_values) == 0: - raise ValueError - elif isinstance(self.encoding_values[0], (int, float)): - assert len(self.encoding_values) == 2 - assert self.encoding_values[0] < self.encoding_values[1] - value_ranges = tuple([self.encoding_values for _ in shape]) - elif isinstance(self.encoding_values[0], (tuple, list)): - if len(self.encoding_values) != len(shape): - raise ValueError(f'The network size has {len(shape)} dimensions, while ' - f'the encoded values provided only has {len(self.encoding_values)}-D. ' - f'Error in {str(self)}.') - for v in self.encoding_values: - assert isinstance(v[0], (int, float)) - assert len(v) == 2 - value_ranges = tuple(self.encoding_values) - else: - raise ValueError(f'Unsupported encoding values: {self.encoding_values}') - else: - raise ValueError(f'Unsupported encoding values: {self.encoding_values}') - - # values - values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, shape)] - post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) - value_sizes = np.array([v[1] - v[0] for v in value_ranges]) - if value_sizes.ndim < post_values.ndim: - value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - - # connectivity matrix - i_value_list = np.zeros(shape=(net_size, len(shape), 1)) - for i in range(net_size): - list_index = i - # values for node i - i_coordinate = tuple() - for s in shape[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - i_value_list[list_index] = i_value - - if 
self.periodic_boundary: - conn_mat = gaussian_decay_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) - else: - conn_mat = gaussian_decay_dist_cal1(i_value_list, post_values, self.sigma) - - if self.normalize: - conn_mat /= conn_mat.max() - if not self.include_self: - bm.fill_diagonal(conn_mat, 0.) - - # connectivity weights - conn_mat *= self.max_w - conn_mat = bm.where(conn_mat < self.min_w, 0., conn_mat) - return bm.asarray(conn_mat, dtype=dtype) - - def __repr__(self): - name = self.__class__.__name__ - bank = ' ' * len(name) - return (f'{name}(sigma={self.sigma}, max_w={self.max_w}, min_w={self.min_w}, \n' - f'{bank}periodic_boundary={self.periodic_boundary}, ' - f'include_self={self.include_self}, ' - f'normalize={self.normalize})') + + def __init__(self, sigma, max_w, min_w=None, encoding_values=None, + periodic_boundary=False, include_self=True, normalize=False): + super(GaussianDecay, self).__init__() + self.sigma = sigma + self.max_w = max_w + self.min_w = max_w * 0.005 if min_w is None else min_w + self.encoding_values = encoding_values + self.periodic_boundary = periodic_boundary + self.include_self = include_self + self.normalize = normalize + + def __call__(self, shape, dtype=None): + """Build the weights. + + Parameters:: + + shape : tuple of int, list of int, int + The network shape. Note, this is not the weight shape. 
+ """ + shape = to_size(shape) + net_size = size2num(shape) + + # value ranges to encode + if self.encoding_values is None: + value_ranges = tuple([(0, s) for s in shape]) + elif isinstance(self.encoding_values, (tuple, list)): + if len(self.encoding_values) == 0: + raise ValueError + elif isinstance(self.encoding_values[0], (int, float)): + assert len(self.encoding_values) == 2 + assert self.encoding_values[0] < self.encoding_values[1] + value_ranges = tuple([self.encoding_values for _ in shape]) + elif isinstance(self.encoding_values[0], (tuple, list)): + if len(self.encoding_values) != len(shape): + raise ValueError(f'The network size has {len(shape)} dimensions, while ' + f'the encoded values provided only has {len(self.encoding_values)}-D. ' + f'Error in {str(self)}.') + for v in self.encoding_values: + assert isinstance(v[0], (int, float)) + assert len(v) == 2 + value_ranges = tuple(self.encoding_values) + else: + raise ValueError(f'Unsupported encoding values: {self.encoding_values}') + else: + raise ValueError(f'Unsupported encoding values: {self.encoding_values}') + + # values + values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, shape)] + post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) + value_sizes = np.array([v[1] - v[0] for v in value_ranges]) + if value_sizes.ndim < post_values.ndim: + value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + + # connectivity matrix + i_value_list = np.zeros(shape=(net_size, len(shape), 1)) + for i in range(net_size): + list_index = i + # values for node i + i_coordinate = tuple() + for s in shape[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + i_value_list[list_index] = i_value + + if 
self.periodic_boundary: + conn_mat = gaussian_decay_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) + else: + conn_mat = gaussian_decay_dist_cal1(i_value_list, post_values, self.sigma) + + if self.normalize: + conn_mat /= conn_mat.max() + if not self.include_self: + bm.fill_diagonal(conn_mat, 0.) + + # connectivity weights + conn_mat *= self.max_w + conn_mat = bm.where(conn_mat < self.min_w, 0., conn_mat) + return bm.asarray(conn_mat, dtype=dtype) + + def __repr__(self): + name = self.__class__.__name__ + bank = ' ' * len(name) + return (f'{name}(sigma={self.sigma}, max_w={self.max_w}, min_w={self.min_w}, \n' + f'{bank}periodic_boundary={self.periodic_boundary}, ' + f'include_self={self.include_self}, ' + f'normalize={self.normalize})') @jit @@ -168,19 +168,19 @@ def _dog_decay_pd(voxel_ids, values, post_values, value_sizes, max_w_p, sigma_p, max_w_n, sigma_n): - i_value = [] - for i in range(len(voxel_ids)): - p_id = voxel_ids[i] # position id - i_value.append(values[i][p_id]) - i_value = bm.array(i_value) - if i_value.ndim < post_values.ndim: - i_value = bm.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - # distances - dists = bm.abs(i_value - post_values) - dists = bm.where(dists > value_sizes / 2, value_sizes - dists, dists) - dists_exp_p = max_w_p * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_p) ** 2 / 2) - dists_exp_n = max_w_n * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_n) ** 2 / 2) - return dists_exp_p - dists_exp_n + i_value = [] + for i in range(len(voxel_ids)): + p_id = voxel_ids[i] # position id + i_value.append(values[i][p_id]) + i_value = bm.array(i_value) + if i_value.ndim < post_values.ndim: + i_value = bm.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + # distances + dists = bm.abs(i_value - post_values) + dists = bm.where(dists > value_sizes / 2, value_sizes - dists, dists) + dists_exp_p = max_w_p * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_p) ** 2 / 2) + 
dists_exp_n = max_w_n * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_n) ** 2 / 2) + return dists_exp_p - dists_exp_n @jit @@ -189,137 +189,137 @@ def _dog_decay(voxel_ids, values, post_values, max_w_p, sigma_p, max_w_n, sigma_n): - i_value = [] - for i in range(len(voxel_ids)): - p_id = voxel_ids[i] # position id - i_value.append(values[i][p_id]) - i_value = bm.array(i_value) - if i_value.ndim < post_values.ndim: - i_value = bm.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - # distances - dists = bm.abs(i_value - post_values) - dists_exp_p = max_w_p * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_p) ** 2 / 2) - dists_exp_n = max_w_n * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_n) ** 2 / 2) - return dists_exp_p - dists_exp_n + i_value = [] + for i in range(len(voxel_ids)): + p_id = voxel_ids[i] # position id + i_value.append(values[i][p_id]) + i_value = bm.array(i_value) + if i_value.ndim < post_values.ndim: + i_value = bm.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + # distances + dists = bm.abs(i_value - post_values) + dists_exp_p = max_w_p * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_p) ** 2 / 2) + dists_exp_n = max_w_n * bm.exp(-(bm.linalg.norm(dists, axis=0) / sigma_n) ** 2 / 2) + return dists_exp_p - dists_exp_n class DOGDecay(_IntraLayerInitializer): - r"""Builds a Difference-Of-Gaussian (dog) connectivity pattern within a population of neurons. - - Mathematically, for the given pair of neurons :math:`(i, j)`, the weight between them is computed as - - .. math:: - - w(i, j) = w_{max}^+ \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_+^2}) - - w_{max}^- \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_-^2}) - - where weights smaller than :math:`0.005 * max(w_{max}, w_{min})` are not created and - self-connections are avoided by default (parameter allow_self_connections). - - Parameters:: - - sigmas : tuple - Widths of the positive and negative Gaussian functions. 
- max_ws : tuple - The weight amplitudes of the positive and negative Gaussian functions. - min_w : float, None - The minimum weight value below which synapses are not created (default: :math:`0.005 * min(max\_ws)`). - include_self : bool - Whether create the connections at the same position (self-connections). - normalize : bool - Whether normalize the connection probability . - encoding_values : optional, list, tuple, int, float - The value ranges to encode for neurons at each axis. - - - If `values` is not provided, the neuron only encodes each positional - information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is - the index in the high-dimensional space. - - If `values` is a single tuple/list of int/float, neurons at each dimension - will encode the same range of values. For example, `values=(0, np.pi)`, - neurons at each dimension will encode a continuous value space `[0, np.pi]`. - - If `values` is a tuple/list of list/tuple, it means the value space will be - different for each dimension. For example, `values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))`. - periodic_boundary : bool - Whether the neuron encode the value space with the periodic boundary. - """ - - def __init__(self, sigmas, max_ws, min_w=None, encoding_values=None, - periodic_boundary=False, normalize=True, include_self=True): - super(DOGDecay, self).__init__() - self.sigma_p, self.sigma_n = sigmas - self.max_w_p, self.max_w_n = max_ws - self.min_w = 0.005 * min(self.max_w_p, self.max_w_n) if min_w is None else min_w - self.normalize = normalize - self.include_self = include_self - self.encoding_values = encoding_values - self.periodic_boundary = periodic_boundary - - def __call__(self, shape, dtype=None): - """Build the weights. + r"""Builds a Difference-Of-Gaussian (dog) connectivity pattern within a population of neurons. + + Mathematically, for the given pair of neurons :math:`(i, j)`, the weight between them is computed as + + .. 
math:: + + w(i, j) = w_{max}^+ \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_+^2}) - + w_{max}^- \cdot \exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2}{2\sigma_-^2}) + + where weights smaller than :math:`0.005 * max(w_{max}, w_{min})` are not created and + self-connections are avoided by default (parameter allow_self_connections). Parameters:: - shape : tuple of int, list of int, int - The network shape. Note, this is not the weight shape. + sigmas : tuple + Widths of the positive and negative Gaussian functions. + max_ws : tuple + The weight amplitudes of the positive and negative Gaussian functions. + min_w : float, None + The minimum weight value below which synapses are not created (default: :math:`0.005 * min(max\_ws)`). + include_self : bool + Whether create the connections at the same position (self-connections). + normalize : bool + Whether normalize the connection probability . + encoding_values : optional, list, tuple, int, float + The value ranges to encode for neurons at each axis. + + - If `values` is not provided, the neuron only encodes each positional + information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is + the index in the high-dimensional space. + - If `values` is a single tuple/list of int/float, neurons at each dimension + will encode the same range of values. For example, `values=(0, np.pi)`, + neurons at each dimension will encode a continuous value space `[0, np.pi]`. + - If `values` is a tuple/list of list/tuple, it means the value space will be + different for each dimension. For example, `values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))`. + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. 
""" - shape = to_size(shape) - - # value ranges to encode - if self.encoding_values is None: - value_ranges = tuple([(0, s) for s in shape]) - elif isinstance(self.encoding_values, (tuple, list)): - if len(self.encoding_values) == 0: - raise ValueError - elif isinstance(self.encoding_values[0], (int, float)): - assert len(self.encoding_values) == 2 - assert self.encoding_values[0] < self.encoding_values[1] - value_ranges = tuple([self.encoding_values for _ in shape]) - elif isinstance(self.encoding_values[0], (tuple, list)): - if len(self.encoding_values) != len(shape): - raise ValueError(f'The network size has {len(shape)} dimensions, while ' - f'the encoded values provided only has {len(self.encoding_values)}-D. ' - f'Error in {str(self)}.') - for v in self.encoding_values: - assert isinstance(v[0], (int, float)) - assert len(v) == 2 - value_ranges = tuple(self.encoding_values) - else: - raise ValueError(f'Unsupported encoding values: {self.encoding_values}') - else: - raise ValueError(f'Unsupported encoding values: {self.encoding_values}') - - # values - values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, shape)] - post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) - value_sizes = np.array([v[1] - v[0] for v in value_ranges]) - if value_sizes.ndim < post_values.ndim: - value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - voxel_ids = np.meshgrid(*[np.arange(s) for s in shape]) - if np.ndim(voxel_ids[0]) > 1: - voxel_ids = tuple(np.moveaxis(m, 0, 1).flatten() for m in voxel_ids) - - # connectivity matrix - if self.periodic_boundary: - conn_weights = _dog_decay_pd(voxel_ids, values, post_values, value_sizes, - self.max_w_p, self.sigma_p, - self.max_w_n, self.sigma_n) - else: - conn_weights = _dog_decay(voxel_ids, values, post_values, - self.max_w_p, self.sigma_p, - self.max_w_n, self.sigma_n) - if not self.include_self: - conn_weights = bm.asarray(conn_weights) - 
bm.fill_diagonal(conn_weights, 0.) - - # connectivity weights - conn_weights = bm.where(np.abs(conn_weights) < self.min_w, 0., conn_weights) - return bm.asarray(conn_weights, dtype=dtype) - - def __repr__(self): - name = self.__class__.__name__ - bank = ' ' * len(name) - return (f'{name}(sigmas={(self.sigma_p, self.sigma_n)}, ' - f'max_ws={(self.max_w_p, self.max_w_n)}, min_w={self.min_w}, \n' - f'{bank}periodic_boundary={self.periodic_boundary}, ' - f'include_self={self.include_self}, ' - f'normalize={self.normalize})') + + def __init__(self, sigmas, max_ws, min_w=None, encoding_values=None, + periodic_boundary=False, normalize=True, include_self=True): + super(DOGDecay, self).__init__() + self.sigma_p, self.sigma_n = sigmas + self.max_w_p, self.max_w_n = max_ws + self.min_w = 0.005 * min(self.max_w_p, self.max_w_n) if min_w is None else min_w + self.normalize = normalize + self.include_self = include_self + self.encoding_values = encoding_values + self.periodic_boundary = periodic_boundary + + def __call__(self, shape, dtype=None): + """Build the weights. + + Parameters:: + + shape : tuple of int, list of int, int + The network shape. Note, this is not the weight shape. + """ + shape = to_size(shape) + + # value ranges to encode + if self.encoding_values is None: + value_ranges = tuple([(0, s) for s in shape]) + elif isinstance(self.encoding_values, (tuple, list)): + if len(self.encoding_values) == 0: + raise ValueError + elif isinstance(self.encoding_values[0], (int, float)): + assert len(self.encoding_values) == 2 + assert self.encoding_values[0] < self.encoding_values[1] + value_ranges = tuple([self.encoding_values for _ in shape]) + elif isinstance(self.encoding_values[0], (tuple, list)): + if len(self.encoding_values) != len(shape): + raise ValueError(f'The network size has {len(shape)} dimensions, while ' + f'the encoded values provided only has {len(self.encoding_values)}-D. 
' + f'Error in {str(self)}.') + for v in self.encoding_values: + assert isinstance(v[0], (int, float)) + assert len(v) == 2 + value_ranges = tuple(self.encoding_values) + else: + raise ValueError(f'Unsupported encoding values: {self.encoding_values}') + else: + raise ValueError(f'Unsupported encoding values: {self.encoding_values}') + + # values + values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, shape)] + post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) + value_sizes = np.array([v[1] - v[0] for v in value_ranges]) + if value_sizes.ndim < post_values.ndim: + value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + voxel_ids = np.meshgrid(*[np.arange(s) for s in shape]) + if np.ndim(voxel_ids[0]) > 1: + voxel_ids = tuple(np.moveaxis(m, 0, 1).flatten() for m in voxel_ids) + + # connectivity matrix + if self.periodic_boundary: + conn_weights = _dog_decay_pd(voxel_ids, values, post_values, value_sizes, + self.max_w_p, self.sigma_p, + self.max_w_n, self.sigma_n) + else: + conn_weights = _dog_decay(voxel_ids, values, post_values, + self.max_w_p, self.sigma_p, + self.max_w_n, self.sigma_n) + if not self.include_self: + conn_weights = bm.asarray(conn_weights) + bm.fill_diagonal(conn_weights, 0.) 
+ + # connectivity weights + conn_weights = bm.where(np.abs(conn_weights) < self.min_w, 0., conn_weights) + return bm.asarray(conn_weights, dtype=dtype) + + def __repr__(self): + name = self.__class__.__name__ + bank = ' ' * len(name) + return (f'{name}(sigmas={(self.sigma_p, self.sigma_n)}, ' + f'max_ws={(self.max_w_p, self.max_w_n)}, min_w={self.min_w}, \n' + f'{bank}periodic_boundary={self.periodic_boundary}, ' + f'include_self={self.include_self}, ' + f'normalize={self.normalize})') diff --git a/brainpy/_src/initialize/generic.py b/brainpy/_src/initialize/generic.py index f4dc9453d..63b343fdd 100644 --- a/brainpy/_src/initialize/generic.py +++ b/brainpy/_src/initialize/generic.py @@ -12,26 +12,26 @@ from .base import Initializer __all__ = [ - 'parameter', - 'variable', - 'variable_', - 'noise', - 'delay', + 'parameter', + 'variable', + 'variable_', + 'noise', + 'delay', ] def _check_none(x, allow_none: bool = False): - pass + pass def _is_scalar(x): - return isinstance(x, (float, int, bool, complex)) + return isinstance(x, (float, int, bool, complex)) def _check_var(x): - if isinstance(x, bm.Variable): - x.ready_to_trace = True - return x + if isinstance(x, bm.Variable): + x.ready_to_trace = True + return x def parameter( @@ -41,65 +41,65 @@ def parameter( allow_scalar: bool = True, sharding: Optional[Sharding] = None ): - """Initialize parameters. - - Parameters:: - - param: callable, Initializer, bm.ndarray, jnp.ndarray, onp.ndarray, float, int, bool - The initialization of the parameter. - - If it is None, the created parameter will be None. - - If it is a callable function :math:`f`, the ``f(size)`` will be returned. - - If it is an instance of :py:class:`brainpy.init.Initializer``, the ``f(size)`` will be returned. - - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``. - sizes: int, sequence of int - The shape of the parameter. - allow_none: bool - Whether allow the parameter is None. 
- allow_scalar: bool - Whether allow the parameter is a scalar value. - sharding: Sharding - The axes for automatic array sharding. - - Returns:: - - param: ArrayType, float, int, bool, None - The initialized parameter. - - See Also:: - - variable_, noise, delay - """ - if param is None: - if allow_none: - return None + """Initialize parameters. + + Parameters:: + + param: callable, Initializer, bm.ndarray, jnp.ndarray, onp.ndarray, float, int, bool + The initialization of the parameter. + - If it is None, the created parameter will be None. + - If it is a callable function :math:`f`, the ``f(size)`` will be returned. + - If it is an instance of :py:class:`brainpy.init.Initializer``, the ``f(size)`` will be returned. + - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``. + sizes: int, sequence of int + The shape of the parameter. + allow_none: bool + Whether allow the parameter is None. + allow_scalar: bool + Whether allow the parameter is a scalar value. + sharding: Sharding + The axes for automatic array sharding. + + Returns:: + + param: ArrayType, float, int, bool, None + The initialized parameter. + + See Also:: + + variable_, noise, delay + """ + if param is None: + if allow_none: + return None + else: + raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or ' + f'Callable function, but we got None. 
') + sizes = to_size(sizes) + if allow_scalar and _is_scalar(param): + return param + + if callable(param): + v = bm.jit(param, + static_argnums=0, + out_shardings=bm.sharding.get_sharding(sharding))(sizes) + return _check_var(v) # TODO: checking the Variable need to be traced + + elif isinstance(param, (np.ndarray, jnp.ndarray)): + param = bm.asarray(param) + elif isinstance(param, bm.Variable): + param = param + elif isinstance(param, bm.BaseArray): + param = param else: - raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or ' - f'Callable function, but we got None. ') - sizes = to_size(sizes) - if allow_scalar and _is_scalar(param): - return param - - if callable(param): - v = bm.jit(param, - static_argnums=0, - out_shardings=bm.sharding.get_sharding(sharding))(sizes) - return _check_var(v) # TODO: checking the Variable need to be traced - - elif isinstance(param, (np.ndarray, jnp.ndarray)): - param = bm.asarray(param) - elif isinstance(param, bm.Variable): - param = param - elif isinstance(param, bm.BaseArray): - param = param - else: - raise ValueError(f'Unknown param type {type(param)}: {param}') - - if allow_scalar: - if param.shape == () or param.shape == (1,): - return param - if param.shape != sizes: - raise ValueError(f'The shape of the parameters should be {sizes}, but we got {param.shape}') - return bm.sharding.partition(param, sharding) + raise ValueError(f'Unknown param type {type(param)}: {param}') + + if allow_scalar: + if param.shape == () or param.shape == (1,): + return param + if param.shape != sizes: + raise ValueError(f'The shape of the parameters should be {sizes}, but we got {param.shape}') + return bm.sharding.partition(param, sharding) def variable_( @@ -110,19 +110,19 @@ def variable_( axis_names: Optional[Sequence[str]] = None, batch_axis_name: Optional[str] = None, ): - """Initialize a :math:`~.Variable` from a callable function or a data. 
+ """Initialize a :math:`~.Variable` from a callable function or a data. - See Also:: + See Also:: - variable + variable - """ - return variable(init, - batch_or_mode, - sizes=sizes, - batch_axis=batch_axis, - axis_names=axis_names, - batch_axis_name=batch_axis_name) + """ + return variable(init, + batch_or_mode, + sizes=sizes, + batch_axis=batch_axis, + axis_names=axis_names, + batch_axis_name=batch_axis_name) def variable( @@ -133,83 +133,83 @@ def variable( axis_names: Optional[Sequence[str]] = None, batch_axis_name: Optional[str] = None, ): - """Initialize variables. - - Parameters:: - - init: callable, ArrayType - The data to be initialized as a ``Variable``. - batch_or_mode: int, bool, Mode, optional - The batch size, mode ``Mode``, boolean state. - This is used to specify the batch size of this variable. - If it is a boolean or an instance of ``Mode``, the batch size will be 1. - If it is None, the variable has no batch axis. - sizes: Shape - The shape of the variable. - batch_axis: int - The batch axis. - axis_names: sequence of str - The name for each axis. These names should match the given ``axes``. - batch_axis_name: str - The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given. - - Returns:: - - variable: bm.Variable - The target ``Variable`` instance. 
- - See Also:: - - variable_, parameter, noise, delay - - """ - - sizes = to_size(sizes) - if axis_names is not None: - axis_names = list(axis_names) - assert len(sizes) == len(axis_names) - if batch_or_mode is not None and not isinstance(batch_or_mode, bm.NonBatchingMode): - axis_names.insert(batch_axis, batch_axis_name) - - if callable(init): - if sizes is None: - raise ValueError('"varshape" cannot be None when data is a callable function.') - if isinstance(batch_or_mode, bm.NonBatchingMode): - data = bm.Variable(init(sizes), axis_names=axis_names) - elif isinstance(batch_or_mode, bm.BatchingMode): - new_shape = sizes[:batch_axis] + (batch_or_mode.batch_size,) + sizes[batch_axis:] - data = bm.Variable(init(new_shape), batch_axis=batch_axis, axis_names=axis_names) - elif batch_or_mode in (None, False): - data = bm.Variable(init(sizes), axis_names=axis_names) - elif isinstance(batch_or_mode, int): - new_shape = sizes[:batch_axis] + (int(batch_or_mode),) + sizes[batch_axis:] - data = bm.Variable(init(new_shape), batch_axis=batch_axis, axis_names=axis_names) - else: - raise ValueError(f'Unknown batch_size_or_mode: {batch_or_mode}') - - else: - if sizes is not None: - if bm.shape(init) != sizes: - raise ValueError(f'The shape of "data" {bm.shape(init)} does not match with "var_shape" {sizes}') - if isinstance(batch_or_mode, bm.NonBatchingMode): - data = bm.Variable(init, axis_names=axis_names) - elif isinstance(batch_or_mode, bm.BatchingMode): - data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), - batch_or_mode.batch_size, - axis=batch_axis), - batch_axis=batch_axis, - axis_names=axis_names) - elif batch_or_mode in (None, False): - data = bm.Variable(init, axis_names=axis_names) - elif isinstance(batch_or_mode, int): - data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), - int(batch_or_mode), - axis=batch_axis), - batch_axis=batch_axis, - axis_names=axis_names) + """Initialize variables. 
+ + Parameters:: + + init: callable, ArrayType + The data to be initialized as a ``Variable``. + batch_or_mode: int, bool, Mode, optional + The batch size, mode ``Mode``, boolean state. + This is used to specify the batch size of this variable. + If it is a boolean or an instance of ``Mode``, the batch size will be 1. + If it is None, the variable has no batch axis. + sizes: Shape + The shape of the variable. + batch_axis: int + The batch axis. + axis_names: sequence of str + The name for each axis. These names should match the given ``axes``. + batch_axis_name: str + The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given. + + Returns:: + + variable: bm.Variable + The target ``Variable`` instance. + + See Also:: + + variable_, parameter, noise, delay + + """ + + sizes = to_size(sizes) + if axis_names is not None: + axis_names = list(axis_names) + assert len(sizes) == len(axis_names) + if batch_or_mode is not None and not isinstance(batch_or_mode, bm.NonBatchingMode): + axis_names.insert(batch_axis, batch_axis_name) + + if callable(init): + if sizes is None: + raise ValueError('"varshape" cannot be None when data is a callable function.') + if isinstance(batch_or_mode, bm.NonBatchingMode): + data = bm.Variable(init(sizes), axis_names=axis_names) + elif isinstance(batch_or_mode, bm.BatchingMode): + new_shape = sizes[:batch_axis] + (batch_or_mode.batch_size,) + sizes[batch_axis:] + data = bm.Variable(init(new_shape), batch_axis=batch_axis, axis_names=axis_names) + elif batch_or_mode in (None, False): + data = bm.Variable(init(sizes), axis_names=axis_names) + elif isinstance(batch_or_mode, int): + new_shape = sizes[:batch_axis] + (int(batch_or_mode),) + sizes[batch_axis:] + data = bm.Variable(init(new_shape), batch_axis=batch_axis, axis_names=axis_names) + else: + raise ValueError(f'Unknown batch_size_or_mode: {batch_or_mode}') + else: - raise ValueError('Unknown batch_size_or_mode.') - return bm.sharding.partition_by_axname(data, 
axis_names) + if sizes is not None: + if bm.shape(init) != sizes: + raise ValueError(f'The shape of "data" {bm.shape(init)} does not match with "var_shape" {sizes}') + if isinstance(batch_or_mode, bm.NonBatchingMode): + data = bm.Variable(init, axis_names=axis_names) + elif isinstance(batch_or_mode, bm.BatchingMode): + data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), + batch_or_mode.batch_size, + axis=batch_axis), + batch_axis=batch_axis, + axis_names=axis_names) + elif batch_or_mode in (None, False): + data = bm.Variable(init, axis_names=axis_names) + elif isinstance(batch_or_mode, int): + data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), + int(batch_or_mode), + axis=batch_axis), + batch_axis=batch_axis, + axis_names=axis_names) + else: + raise ValueError('Unknown batch_size_or_mode.') + return bm.sharding.partition_by_axname(data, axis_names) def noise( @@ -218,39 +218,39 @@ def noise( num_vars: int = 1, noise_idx: int = 0, ) -> Optional[Callable]: - """Initialize a noise function. + """Initialize a noise function. - Parameters:: + Parameters:: - noises: Any - size: Shape - The size of the noise. - num_vars: int - The number of variables. - noise_idx: int - The index of the current noise among all noise variables. + noises: Any + size: Shape + The size of the noise. + num_vars: int + The number of variables. + noise_idx: int + The index of the current noise among all noise variables. - Returns:: + Returns:: - noise_func: function, None - The noise function. + noise_func: function, None + The noise function. 
- See Also:: + See Also:: - variable_, parameter, delay + variable_, parameter, delay - """ - if callable(noises): - return noises - elif noises is None: - return None - else: - noises = parameter(noises, size, allow_none=False) - if num_vars > 1: - noises_ = [None] * num_vars - noises_[noise_idx] = noises - noises = tuple(noises_) - return lambda *args, **kwargs: noises + """ + if callable(noises): + return noises + elif noises is None: + return None + else: + noises = parameter(noises, size, allow_none=False) + if num_vars > 1: + noises_ = [None] * num_vars + noises_[noise_idx] = noises + noises = tuple(noises_) + return lambda *args, **kwargs: noises def delay( @@ -258,58 +258,58 @@ def delay( delay_target: Union[bm.ndarray, jnp.ndarray], delay_data: Union[bm.ndarray, jnp.ndarray] = None ): - """Initialize delay variable. - - Parameters:: - - delay_step: int, ndarray, ArrayType - The number of delay steps. It can an integer of an array of integers. - delay_target: ndarray, ArrayType - The target variable to delay. - delay_data: optional, ndarray, ArrayType - The initial delay data. - - Returns:: - - info: tuple - The triple of delay type, delay steps, and delay variable. - - See Also:: - - variable_, parameter, noise - """ - # check delay type - if delay_step is None: - delay_type = 'none' - elif isinstance(delay_step, int): - delay_type = 'homo' - elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)): - delay_type = 'heter' - delay_step = bm.asarray(delay_step) - elif callable(delay_step): - delay_step = parameter(delay_step, delay_target.shape, allow_none=False) - delay_type = 'heter' - else: - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' - f'integer, array of integers, callable function, brainpy.init.Initializer.') - if delay_type == 'heter': - if delay_step.dtype not in [bm.int32, bm.int64]: - raise ValueError('Only support delay steps of int32, int64. 
If your ' - 'provide delay time length, please divide the "dt" ' - 'then provide us the number of delay steps.') - if delay_target.shape[0] != delay_step.shape[0]: - raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}') - - # init delay data - if delay_type == 'homo': - delays = bm.LengthDelay(delay_target, delay_step, initial_delay_data=delay_data) - elif delay_type == 'heter': - if delay_step.size != delay_target.size: - raise ValueError('Heterogeneous delay must have a length ' - f'of the delay target {delay_target.shape}, ' - f'while we got {delay_step.shape}') - delays = bm.LengthDelay(delay_target, int(delay_step.max())) - else: - delays = None - - return delay_type, delay_step, delays + """Initialize delay variable. + + Parameters:: + + delay_step: int, ndarray, ArrayType + The number of delay steps. It can an integer of an array of integers. + delay_target: ndarray, ArrayType + The target variable to delay. + delay_data: optional, ndarray, ArrayType + The initial delay data. + + Returns:: + + info: tuple + The triple of delay type, delay steps, and delay variable. + + See Also:: + + variable_, parameter, noise + """ + # check delay type + if delay_step is None: + delay_type = 'none' + elif isinstance(delay_step, int): + delay_type = 'homo' + elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)): + delay_type = 'heter' + delay_step = bm.asarray(delay_step) + elif callable(delay_step): + delay_step = parameter(delay_step, delay_target.shape, allow_none=False) + delay_type = 'heter' + else: + raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' + f'integer, array of integers, callable function, brainpy.init.Initializer.') + if delay_type == 'heter': + if delay_step.dtype not in [bm.int32, bm.int64]: + raise ValueError('Only support delay steps of int32, int64. 
If your ' + 'provide delay time length, please divide the "dt" ' + 'then provide us the number of delay steps.') + if delay_target.shape[0] != delay_step.shape[0]: + raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}') + + # init delay data + if delay_type == 'homo': + delays = bm.LengthDelay(delay_target, delay_step, initial_delay_data=delay_data) + elif delay_type == 'heter': + if delay_step.size != delay_target.size: + raise ValueError('Heterogeneous delay must have a length ' + f'of the delay target {delay_target.shape}, ' + f'while we got {delay_step.shape}') + delays = bm.LengthDelay(delay_target, int(delay_step.max())) + else: + delays = None + + return delay_type, delay_step, delays diff --git a/brainpy/_src/initialize/others.py b/brainpy/_src/initialize/others.py index f856e8e06..bd19ae13f 100644 --- a/brainpy/_src/initialize/others.py +++ b/brainpy/_src/initialize/others.py @@ -1,26 +1,22 @@ - from typing import Callable + import brainpy.math as bm from .base import Initializer class Clip(Initializer): - def __init__(self, init: Callable, min=None, max=None): - self.min = min - self.max = max - self.init = init - - def __call__(self, shape, dtype=None): - x = self.init(shape, dtype) - if self.min is not None: - x = bm.maximum(self.min, x) - if self.max is not None: - x = bm.minimum(self.max, x) - return x - - def __repr__(self): - return f'{self.__class__.__name__}({self.init}, min={self.min}, max={self.max})' - - - - + def __init__(self, init: Callable, min=None, max=None): + self.min = min + self.max = max + self.init = init + + def __call__(self, shape, dtype=None): + x = self.init(shape, dtype) + if self.min is not None: + x = bm.maximum(self.min, x) + if self.max is not None: + x = bm.minimum(self.max, x) + return x + + def __repr__(self): + return f'{self.__class__.__name__}({self.init}, min={self.min}, max={self.max})' diff --git a/brainpy/_src/initialize/random_inits.py 
b/brainpy/_src/initialize/random_inits.py index 39686ca33..10716a7e2 100644 --- a/brainpy/_src/initialize/random_inits.py +++ b/brainpy/_src/initialize/random_inits.py @@ -5,466 +5,468 @@ import jax.numpy as jnp import numpy as np -from brainpy._src import math as bm from brainpy import tools +from brainpy._src import math as bm from .base import _InterLayerInitializer __all__ = [ - 'Normal', - 'TruncatedNormal', - 'Uniform', - 'VarianceScaling', - 'KaimingUniform', - 'KaimingNormal', - 'XavierUniform', - 'XavierNormal', - 'LecunUniform', - 'LecunNormal', - 'Orthogonal', - 'DeltaOrthogonal', + 'Normal', + 'TruncatedNormal', + 'Uniform', + 'VarianceScaling', + 'KaimingUniform', + 'KaimingNormal', + 'XavierUniform', + 'XavierNormal', + 'LecunUniform', + 'LecunNormal', + 'Orthogonal', + 'DeltaOrthogonal', ] def calculate_gain(nonlinearity, param=None): - r"""Return the recommended gain value for the given nonlinearity function. - The values are as follows: - - ================= ==================================================== - nonlinearity gain - ================= ==================================================== - Linear / Identity :math:`1` - Conv{1,2,3}D :math:`1` - Sigmoid :math:`1` - Tanh :math:`\frac{5}{3}` - ReLU :math:`\sqrt{2}` - Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` - SELU :math:`\frac{3}{4}` - ================= ==================================================== - - .. warning:: - In order to implement `Self-Normalizing Neural Networks`_ , - you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. - This gives the initial weights a variance of ``1 / N``, - which is necessary to induce a stable fixed point in the forward pass. - In contrast, the default gain for ``SELU`` sacrifices the normalisation - effect for more stable gradient flow in rectangular layers. - - Args: - nonlinearity: the non-linear function (`nn.functional` name) - param: optional parameter for the non-linear function - - .. 
_Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html - """ - linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] - if nonlinearity in linear_fns or nonlinearity == 'sigmoid': - return 1 - elif nonlinearity == 'tanh': - return 5.0 / 3 - elif nonlinearity == 'relu': - return math.sqrt(2.0) - elif nonlinearity == 'leaky_relu': - if param is None: - negative_slope = 0.01 - elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): - # True/False are instances of int, hence check above - negative_slope = param + r"""Return the recommended gain value for the given nonlinearity function. + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + SELU :math:`\frac{3}{4}` + ================= ==================================================== + + .. warning:: + In order to implement `Self-Normalizing Neural Networks`_ , + you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. + This gives the initial weights a variance of ``1 / N``, + which is necessary to induce a stable fixed point in the forward pass. + In contrast, the default gain for ``SELU`` sacrifices the normalisation + effect for more stable gradient flow in rectangular layers. + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + .. 
_Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + elif nonlinearity == 'tanh': + return 5.0 / 3 + elif nonlinearity == 'relu': + return math.sqrt(2.0) + elif nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + elif nonlinearity == 'selu': + return 3.0 / 4 else: - raise ValueError("negative_slope {} not a valid number".format(param)) - return math.sqrt(2.0 / (1 + negative_slope ** 2)) - elif nonlinearity == 'selu': - return 3.0 / 4 - else: - raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) def _format_shape(shape): - if isinstance(shape, int): - return (shape, ) - if len(shape) == 0: - raise ValueError('Please provide shape.') - if len(shape) == 1: - if isinstance(shape[0], (tuple, list)): - return shape[0] + if isinstance(shape, int): + return (shape,) + if len(shape) == 0: + raise ValueError('Please provide shape.') + if len(shape) == 1: + if isinstance(shape[0], (tuple, list)): + return shape[0] + else: + return shape else: - return shape - else: - return shape + return shape def _compute_fans(shape, in_axis=-2, out_axis=-1): - receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] - fan_in = shape[in_axis] * receptive_field_size - fan_out = shape[out_axis] * receptive_field_size - return fan_in, fan_out + receptive_field_size = np.prod(shape) / shape[in_axis] / 
shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out class Normal(_InterLayerInitializer): - """Initialize weights with normal distribution. + """Initialize weights with normal distribution. - Parameters:: + Parameters:: - scale : float - The gain of the derivation of the normal distribution. + scale : float + The gain of the derivation of the normal distribution. - """ + """ - def __init__(self, mean=0., scale=1., seed=None): - super(Normal, self).__init__() - self.scale = scale - self.mean = mean - self.rng = bm.random.default_rng(seed, clone=False) + def __init__(self, mean=0., scale=1., seed=None): + super(Normal, self).__init__() + self.scale = scale + self.mean = mean + self.rng = bm.random.default_rng(seed, clone=False) - def __call__(self, shape, dtype=None): - shape = _format_shape(shape) - weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale) - return bm.asarray(weights, dtype=dtype) + def __call__(self, shape, dtype=None): + shape = _format_shape(shape) + weights = self.rng.normal(size=shape, loc=self.mean, scale=self.scale) + return bm.asarray(weights, dtype=dtype) - def __repr__(self): - return f'{self.__class__.__name__}(scale={self.scale}, rng={self.rng})' + def __repr__(self): + return f'{self.__class__.__name__}(scale={self.scale}, rng={self.rng})' class TruncatedNormal(_InterLayerInitializer): - """Initialize weights with truncated normal distribution. - - Parameters:: - - loc : float, ndarray - Mean ("centre") of the distribution before truncating. Note that - the mean of the truncated distribution will not be exactly equal - to ``loc``. - scale : float - The standard deviation of the normal distribution before truncating. - lower : float, ndarray - A float or array of floats representing the lower bound for - truncation. Must be broadcast-compatible with ``upper``. 
- upper : float, ndarray - A float or array of floats representing the upper bound for - truncation. Must be broadcast-compatible with ``lower``. - - """ - - def __init__(self, loc=0., scale=1., lower=None, upper=None, seed=None): - super(TruncatedNormal, self).__init__() - assert scale > 0, '`scale` must be positive.' - self.scale = scale - self.loc = loc - self.lower = lower - self.upper = upper - self.rng = bm.random.default_rng(seed, clone=False) - - def __call__(self, shape, dtype=None): - shape = _format_shape(shape) - weights = self.rng.truncated_normal( - size=shape, - scale=self.scale, - lower=self.lower, - upper=self.upper, - loc=self.loc - ) - return bm.asarray(weights, dtype=dtype) - - def __repr__(self): - return f'{self.__class__.__name__}(loc={self.loc}, scale={self.scale}, lower={self.lower}, upper={self.upper}, rng={self.rng})' - + """Initialize weights with truncated normal distribution. + + Parameters:: + + loc : float, ndarray + Mean ("centre") of the distribution before truncating. Note that + the mean of the truncated distribution will not be exactly equal + to ``loc``. + scale : float + The standard deviation of the normal distribution before truncating. + lower : float, ndarray + A float or array of floats representing the lower bound for + truncation. Must be broadcast-compatible with ``upper``. + upper : float, ndarray + A float or array of floats representing the upper bound for + truncation. Must be broadcast-compatible with ``lower``. + + """ + + def __init__(self, loc=0., scale=1., lower=None, upper=None, seed=None): + super(TruncatedNormal, self).__init__() + assert scale > 0, '`scale` must be positive.' 
+ self.scale = scale + self.loc = loc + self.lower = lower + self.upper = upper + self.rng = bm.random.default_rng(seed, clone=False) + + def __call__(self, shape, dtype=None): + shape = _format_shape(shape) + weights = self.rng.truncated_normal( + size=shape, + scale=self.scale, + lower=self.lower, + upper=self.upper, + loc=self.loc + ) + return bm.asarray(weights, dtype=dtype) + + def __repr__(self): + return f'{self.__class__.__name__}(loc={self.loc}, scale={self.scale}, lower={self.lower}, upper={self.upper}, rng={self.rng})' + class Gamma(_InterLayerInitializer): - """Initialize weights with Gamma distribution. + """Initialize weights with Gamma distribution. - Parameters:: + Parameters:: - shape: float, Array - Shape parameter. - scale: float, Array - The gain of the derivation of the Gamma distribution. + shape: float, Array + Shape parameter. + scale: float, Array + The gain of the derivation of the Gamma distribution. - """ - def __init__(self, shape, scale=None, seed=None): - self.shape = shape - self.scale = scale - self.rng = bm.random.default_rng(seed, clone=False) + """ - def __call__(self, shape, dtype=None): - weights = self.rng.gamma(self.shape, scale=self.scale, size=shape) - return bm.asarray(weights, dtype=dtype) + def __init__(self, shape, scale=None, seed=None): + self.shape = shape + self.scale = scale + self.rng = bm.random.default_rng(seed, clone=False) - def __repr__(self): - return f'{self.__class__.__name__}(shape={self.shape}, scale={self.scale})' + def __call__(self, shape, dtype=None): + weights = self.rng.gamma(self.shape, scale=self.scale, size=shape) + return bm.asarray(weights, dtype=dtype) + + def __repr__(self): + return f'{self.__class__.__name__}(shape={self.shape}, scale={self.scale})' class Exponential(_InterLayerInitializer): - """Initialize weights with Gamma distribution. + """Initialize weights with Gamma distribution. 
+ + Parameters:: - Parameters:: + scale: float, Array + The gain of the derivation of the Exponential distribution. - scale: float, Array - The gain of the derivation of the Exponential distribution. + """ - """ - def __init__(self, scale=None, seed=None): - self.scale = scale - self.rng = bm.random.default_rng(seed, clone=False) + def __init__(self, scale=None, seed=None): + self.scale = scale + self.rng = bm.random.default_rng(seed, clone=False) - def __call__(self, shape, dtype=None): - weights = self.rng.exponential(scale=self.scale, size=shape) - return bm.asarray(weights, dtype=dtype) + def __call__(self, shape, dtype=None): + weights = self.rng.exponential(scale=self.scale, size=shape) + return bm.asarray(weights, dtype=dtype) - def __repr__(self): - return f'{self.__class__.__name__}(scale={self.scale})' + def __repr__(self): + return f'{self.__class__.__name__}(scale={self.scale})' class Uniform(_InterLayerInitializer): - """Initialize weights with uniform distribution. + """Initialize weights with uniform distribution. - Parameters:: + Parameters:: - min_val : float - The lower limit of the uniform distribution. - max_val : float - The upper limit of the uniform distribution. - """ + min_val : float + The lower limit of the uniform distribution. + max_val : float + The upper limit of the uniform distribution. 
+ """ - def __init__(self, min_val: float = 0., max_val: float = 1., seed=None): - super(Uniform, self).__init__() - self.min_val = min_val - self.max_val = max_val - self.rng = bm.random.default_rng(seed, clone=False) + def __init__(self, min_val: float = 0., max_val: float = 1., seed=None): + super(Uniform, self).__init__() + self.min_val = min_val + self.max_val = max_val + self.rng = bm.random.default_rng(seed, clone=False) - def __call__(self, shape, dtype=None): - shape = _format_shape(shape) - r = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape) - return bm.asarray(r, dtype=dtype) + def __call__(self, shape, dtype=None): + shape = _format_shape(shape) + r = self.rng.uniform(low=self.min_val, high=self.max_val, size=shape) + return bm.asarray(r, dtype=dtype) - def __repr__(self): - return (f'{self.__class__.__name__}(min_val={self.min_val}, ' - f'max_val={self.max_val}, rng={self.rng})') + def __repr__(self): + return (f'{self.__class__.__name__}(min_val={self.min_val}, ' + f'max_val={self.max_val}, rng={self.rng})') class VarianceScaling(_InterLayerInitializer): - def __init__( - self, - scale: float, - mode: str, - distribution: str, - in_axis: int = -2, - out_axis: int = -1, - seed: int = None - ): - assert mode in ['fan_in', 'fan_out', 'fan_avg'] - assert distribution in ['truncated_normal', 'normal', 'uniform'] - self.scale = scale - self.mode = mode - self.in_axis = in_axis - self.out_axis = out_axis - self.distribution = distribution - self.rng = bm.random.default_rng(seed, clone=False) - - def __call__(self, shape, dtype=None): - shape = _format_shape(shape) - fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis) - if self.mode == "fan_in": - denominator = fan_in - elif self.mode == "fan_out": - denominator = fan_out - elif self.mode == "fan_avg": - denominator = (fan_in + fan_out) / 2 - else: - raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode)) - variance = 
(self.scale / denominator).astype(dtype) - if self.distribution == "truncated_normal": - stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype) - res = self.rng.truncated_normal(-2, 2, shape).astype(dtype) * stddev - elif self.distribution == "normal": - res = self.rng.randn(*shape) * jnp.sqrt(variance).astype(dtype) - elif self.distribution == "uniform": - res = self.rng.uniform(low=-1, high=1, size=shape) * jnp.sqrt(3 * variance).astype(dtype) - else: - raise ValueError("invalid distribution for variance scaling initializer") - return bm.asarray(res, dtype=dtype) - - def __repr__(self): - name = self.__class__.__name__ - blank = ' ' * len(name) - return (f'{name}(scale={self.scale}, mode={self.mode}, in_axis={self.in_axis}, \n' - f'{blank}out_axis={self.out_axis}, distribution={self.distribution}, rng={self.rng})') + def __init__( + self, + scale: float, + mode: str, + distribution: str, + in_axis: int = -2, + out_axis: int = -1, + seed: int = None + ): + assert mode in ['fan_in', 'fan_out', 'fan_avg'] + assert distribution in ['truncated_normal', 'normal', 'uniform'] + self.scale = scale + self.mode = mode + self.in_axis = in_axis + self.out_axis = out_axis + self.distribution = distribution + self.rng = bm.random.default_rng(seed, clone=False) + + def __call__(self, shape, dtype=None): + shape = _format_shape(shape) + fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis) + if self.mode == "fan_in": + denominator = fan_in + elif self.mode == "fan_out": + denominator = fan_out + elif self.mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode)) + variance = (self.scale / denominator).astype(dtype) + if self.distribution == "truncated_normal": + stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype) + res = self.rng.truncated_normal(-2, 2, shape).astype(dtype) * stddev + elif self.distribution == "normal": + res = 
self.rng.randn(*shape) * jnp.sqrt(variance).astype(dtype) + elif self.distribution == "uniform": + res = self.rng.uniform(low=-1, high=1, size=shape) * jnp.sqrt(3 * variance).astype(dtype) + else: + raise ValueError("invalid distribution for variance scaling initializer") + return bm.asarray(res, dtype=dtype) + + def __repr__(self): + name = self.__class__.__name__ + blank = ' ' * len(name) + return (f'{name}(scale={self.scale}, mode={self.mode}, in_axis={self.in_axis}, \n' + f'{blank}out_axis={self.out_axis}, distribution={self.distribution}, rng={self.rng})') class KaimingUniform(VarianceScaling): - def __init__( - self, - scale: float = 2.0, - mode: str = "fan_in", - distribution: str = "uniform", - in_axis: int = -2, - out_axis: int = -1, - seed: int = None - ): - super().__init__(scale, - mode, - distribution, - in_axis=in_axis, - out_axis=out_axis, - seed=seed) + def __init__( + self, + scale: float = 2.0, + mode: str = "fan_in", + distribution: str = "uniform", + in_axis: int = -2, + out_axis: int = -1, + seed: int = None + ): + super().__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, + seed=seed) class KaimingNormal(VarianceScaling): - def __init__( - self, - scale: float = 2.0, - mode: str = "fan_in", - distribution: str = "truncated_normal", - in_axis: int = -2, - out_axis: int = -1, - seed: int = None - ): - super().__init__(scale, - mode, - distribution, - in_axis=in_axis, - out_axis=out_axis, - seed=seed) + def __init__( + self, + scale: float = 2.0, + mode: str = "fan_in", + distribution: str = "truncated_normal", + in_axis: int = -2, + out_axis: int = -1, + seed: int = None + ): + super().__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, + seed=seed) class XavierUniform(VarianceScaling): - def __init__( - self, - scale: float = 1.0, - mode: str = "fan_avg", - distribution: str = "uniform", - in_axis: int = -2, - out_axis: int = -1, - seed: int = None - ): - super().__init__(scale, - mode, - 
distribution, - in_axis=in_axis, - out_axis=out_axis, - seed=seed) + def __init__( + self, + scale: float = 1.0, + mode: str = "fan_avg", + distribution: str = "uniform", + in_axis: int = -2, + out_axis: int = -1, + seed: int = None + ): + super().__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, + seed=seed) class XavierNormal(VarianceScaling): - def __init__( - self, - scale: float = 1.0, - mode: str = "fan_avg", - distribution: str = "truncated_normal", - in_axis: int = -2, - out_axis: int = -1, - seed: int = None - ): - super().__init__(scale, - mode, - distribution, - in_axis=in_axis, - out_axis=out_axis, - seed=seed) + def __init__( + self, + scale: float = 1.0, + mode: str = "fan_avg", + distribution: str = "truncated_normal", + in_axis: int = -2, + out_axis: int = -1, + seed: int = None + ): + super().__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, + seed=seed) class LecunUniform(VarianceScaling): - def __init__( - self, - scale: float = 1.0, - mode: str = "fan_in", - distribution: str = "uniform", - in_axis: int = -2, - out_axis: int = -1, - seed: int = None - ): - super().__init__(scale, - mode, - distribution, - in_axis=in_axis, - out_axis=out_axis, - seed=seed) + def __init__( + self, + scale: float = 1.0, + mode: str = "fan_in", + distribution: str = "uniform", + in_axis: int = -2, + out_axis: int = -1, + seed: int = None + ): + super().__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, + seed=seed) class LecunNormal(VarianceScaling): - def __init__( - self, - scale: float = 1.0, - mode: str = "fan_in", - distribution: str = "truncated_normal", - in_axis: int = -2, - out_axis: int = -1, - seed: int = None - ): - super().__init__(scale, - mode, - distribution, - in_axis=in_axis, - out_axis=out_axis, - seed=seed) + def __init__( + self, + scale: float = 1.0, + mode: str = "fan_in", + distribution: str = "truncated_normal", + in_axis: int = -2, + out_axis: int = -1, 
+ seed: int = None + ): + super().__init__(scale, + mode, + distribution, + in_axis=in_axis, + out_axis=out_axis, + seed=seed) class Orthogonal(_InterLayerInitializer): - """ - Construct an initializer for uniformly distributed orthogonal matrices. - - If the shape is not square, the matrix will have orthonormal rows or columns - depending on which side is smaller. - """ - - def __init__( - self, - scale: float = 1., - axis: int = -1, - seed: int = None - ): - super().__init__() - self.scale = scale - self.axis = axis - self.rng = bm.random.default_rng(seed, clone=False) - - def __call__(self, shape, dtype=None): - shape = _format_shape(shape) - n_rows = shape[self.axis] - n_cols = np.prod(shape) // n_rows - matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows) - norm_dst = self.rng.normal(size=matrix_shape) - q_mat, r_mat = jnp.linalg.qr(bm.as_jax(norm_dst)) - # Enforce Q is uniformly distributed - q_mat *= jnp.sign(jnp.diag(r_mat)) - if n_rows < n_cols: - q_mat = q_mat.T - q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis))) - q_mat = jnp.moveaxis(q_mat, 0, self.axis) - return self.scale * bm.asarray(q_mat, dtype=dtype) - - def __repr__(self): - return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, rng={self.rng})' + """ + Construct an initializer for uniformly distributed orthogonal matrices. + + If the shape is not square, the matrix will have orthonormal rows or columns + depending on which side is smaller. 
+ """ + + def __init__( + self, + scale: float = 1., + axis: int = -1, + seed: int = None + ): + super().__init__() + self.scale = scale + self.axis = axis + self.rng = bm.random.default_rng(seed, clone=False) + + def __call__(self, shape, dtype=None): + shape = _format_shape(shape) + n_rows = shape[self.axis] + n_cols = np.prod(shape) // n_rows + matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows) + norm_dst = self.rng.normal(size=matrix_shape) + q_mat, r_mat = jnp.linalg.qr(bm.as_jax(norm_dst)) + # Enforce Q is uniformly distributed + q_mat *= jnp.sign(jnp.diag(r_mat)) + if n_rows < n_cols: + q_mat = q_mat.T + q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis))) + q_mat = jnp.moveaxis(q_mat, 0, self.axis) + return self.scale * bm.asarray(q_mat, dtype=dtype) + + def __repr__(self): + return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis}, rng={self.rng})' class DeltaOrthogonal(_InterLayerInitializer): - """ - Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393. - - The shape must be 3D, 4D or 5D. - """ - - def __init__(self, scale=1.0, axis=-1, ): - super(DeltaOrthogonal, self).__init__() - self.scale = scale - self.axis = axis - - def __call__(self, shape, dtype=None): - shape = [tools.size2num(d) for d in shape] - if len(shape) not in [3, 4, 5]: - raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.") - if shape[-1] < shape[-2]: - raise ValueError("`fan_in` must be less or equal than `fan_out`. ") - ortho_init = Orthogonal(scale=self.scale, axis=self.axis) - ortho_matrix = ortho_init(shape[-2:], dtype=dtype) - W = bm.zeros(shape, dtype=dtype) - if len(shape) == 3: - k = shape[0] - W[(k - 1) // 2, ...] = ortho_matrix - elif len(shape) == 4: - k1, k2 = shape[:2] - W[(k1 - 1) // 2, (k2 - 1) // 2, ...] = ortho_matrix - else: - k1, k2, k3 = shape[:3] - W[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2, ...] 
= ortho_matrix - return W - - def __repr__(self): - return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis})' + """ + Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393. + + The shape must be 3D, 4D or 5D. + """ + + def __init__(self, scale=1.0, axis=-1, ): + super(DeltaOrthogonal, self).__init__() + self.scale = scale + self.axis = axis + + def __call__(self, shape, dtype=None): + shape = [tools.size2num(d) for d in shape] + if len(shape) not in [3, 4, 5]: + raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.") + if shape[-1] < shape[-2]: + raise ValueError("`fan_in` must be less or equal than `fan_out`. ") + ortho_init = Orthogonal(scale=self.scale, axis=self.axis) + ortho_matrix = ortho_init(shape[-2:], dtype=dtype) + W = bm.zeros(shape, dtype=dtype) + if len(shape) == 3: + k = shape[0] + W[(k - 1) // 2, ...] = ortho_matrix + elif len(shape) == 4: + k1, k2 = shape[:2] + W[(k1 - 1) // 2, (k2 - 1) // 2, ...] = ortho_matrix + else: + k1, k2, k3 = shape[:3] + W[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2, ...] = ortho_matrix + return W + + def __repr__(self): + return f'{self.__class__.__name__}(scale={self.scale}, axis={self.axis})' diff --git a/brainpy/_src/initialize/regular_inits.py b/brainpy/_src/initialize/regular_inits.py index 6246b42a3..e542c7686 100644 --- a/brainpy/_src/initialize/regular_inits.py +++ b/brainpy/_src/initialize/regular_inits.py @@ -4,93 +4,93 @@ from .base import _InterLayerInitializer __all__ = [ - 'ZeroInit', - 'Constant', - 'OneInit', - 'Identity', + 'ZeroInit', + 'Constant', + 'OneInit', + 'Identity', ] class ZeroInit(_InterLayerInitializer): - """Zero initializer. + """Zero initializer. - Initialize the weights with zeros. - """ + Initialize the weights with zeros. 
+ """ - def __call__(self, shape, dtype=None): - shape = [tools.size2num(d) for d in shape] - return bm.zeros(shape, dtype=dtype) + def __call__(self, shape, dtype=None): + shape = [tools.size2num(d) for d in shape] + return bm.zeros(shape, dtype=dtype) - def __repr__(self): - return self.__class__.__name__ + def __repr__(self): + return self.__class__.__name__ class Constant(_InterLayerInitializer): - """Constant initializer. + """Constant initializer. - Initialize the weights with the given values. + Initialize the weights with the given values. - Parameters:: + Parameters:: - value : float, int, bm.ndarray - The value to specify. - """ + value : float, int, bm.ndarray + The value to specify. + """ - def __init__(self, value=1.): - super(Constant, self).__init__() - self.value = value + def __init__(self, value=1.): + super(Constant, self).__init__() + self.value = value - def __call__(self, shape, dtype=None): - shape = [tools.size2num(d) for d in shape] - return bm.ones(shape, dtype=dtype) * self.value + def __call__(self, shape, dtype=None): + shape = [tools.size2num(d) for d in shape] + return bm.ones(shape, dtype=dtype) * self.value - def __repr__(self): - return f'{self.__class__.__name__}(value={self.value})' + def __repr__(self): + return f'{self.__class__.__name__}(value={self.value})' class OneInit(Constant): - """One initializer. - """ - pass + """One initializer. + """ + pass class Identity(_InterLayerInitializer): - """Returns the identity matrix. + """Returns the identity matrix. - This initializer was proposed in (Le, et al., 2015) [1]_. + This initializer was proposed in (Le, et al., 2015) [1]_. - Parameters:: + Parameters:: - value : float - The optional scaling factor. + value : float + The optional scaling factor. - Returns:: + Returns:: - shape: tuple of int - The weight shape/size. + shape: tuple of int + The weight shape/size. - References:: + References:: - .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. 
"A simple way to - initialize recurrent networks of rectified linear units." arXiv preprint - arXiv:1504.00941 (2015). - """ + .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to + initialize recurrent networks of rectified linear units." arXiv preprint + arXiv:1504.00941 (2015). + """ - def __init__(self, value=1.): - super(Identity, self).__init__() - self.value = value + def __init__(self, value=1.): + super(Identity, self).__init__() + self.value = value - def __call__(self, shape, dtype=None): - if isinstance(shape, int): - shape = (shape,) - elif isinstance(shape, (tuple, list)): - if len(shape) > 2: - raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.') - else: - raise ValueError(f'Only support shape of int, or tuple/list of int ' - f'in {self.__class__.__name__}, but we got {shape}.') - shape = [tools.size2num(d) for d in shape] - return bm.eye(*shape, dtype=dtype) * self.value + def __call__(self, shape, dtype=None): + if isinstance(shape, int): + shape = (shape,) + elif isinstance(shape, (tuple, list)): + if len(shape) > 2: + raise ValueError(f'Only support initialize 2D weights for {self.__class__.__name__}.') + else: + raise ValueError(f'Only support shape of int, or tuple/list of int ' + f'in {self.__class__.__name__}, but we got {shape}.') + shape = [tools.size2num(d) for d in shape] + return bm.eye(*shape, dtype=dtype) * self.value - def __repr__(self): - return f'{self.__class__.__name__}(value={self.value})' + def __repr__(self): + return f'{self.__class__.__name__}(value={self.value})' diff --git a/brainpy/_src/initialize/tests/test_decay_inits.py b/brainpy/_src/initialize/tests/test_decay_inits.py index 22e1fa023..298523573 100644 --- a/brainpy/_src/initialize/tests/test_decay_inits.py +++ b/brainpy/_src/initialize/tests/test_decay_inits.py @@ -13,79 +13,79 @@ # visualization def mat_visualize(matrix, cmap=None): - if cmap is None: - cmap = plt.colormaps.get_cmap('coolwarm') - 
plt.colormaps.get_cmap('coolwarm') - im = plt.matshow(matrix, cmap=cmap) - plt.colorbar(mappable=im, shrink=0.8, aspect=15) - plt.show() + if cmap is None: + cmap = plt.colormaps.get_cmap('coolwarm') + plt.colormaps.get_cmap('coolwarm') + im = plt.matshow(matrix, cmap=cmap) + plt.colorbar(mappable=im, shrink=0.8, aspect=15) + plt.show() def _size2len(size): - if isinstance(size, int): - return size - elif isinstance(size, (tuple, list)): - length = 1 - for e in size: - length *= e - return length - else: - raise ValueError(f'Must be a list/tuple of int, but got {size}') + if isinstance(size, int): + return size + elif isinstance(size, (tuple, list)): + length = 1 + for e in size: + length *= e + return length + else: + raise ValueError(f'Must be a list/tuple of int, but got {size}') class TestGaussianDecayInit(unittest.TestCase): - def test_gaussian_decay_init1(self): - init = bp.init.GaussianDecay(sigma=4, max_w=1.) - for size in [10, (10, 20), (10, 20, 30)]: - weights = init(size) - shape = _size2len(size) - assert weights.shape == (shape, shape) - assert isinstance(weights, bp.math.ndarray) - - def test_gaussian_decay_init2(self): - init = bp.init.GaussianDecay(sigma=4, max_w=1., min_w=0.1, periodic_boundary=True, - encoding_values=((-bm.pi, bm.pi), (10, 20), (0, 2 * bm.pi)), - include_self=False, normalize=True) - size = (10, 20, 30) - weights = init(size) - shape = _size2len(size) - assert weights.shape == (shape, shape) - assert isinstance(weights, bp.math.ndarray) + def test_gaussian_decay_init1(self): + init = bp.init.GaussianDecay(sigma=4, max_w=1.) 
+ for size in [10, (10, 20), (10, 20, 30)]: + weights = init(size) + shape = _size2len(size) + assert weights.shape == (shape, shape) + assert isinstance(weights, bp.math.ndarray) + + def test_gaussian_decay_init2(self): + init = bp.init.GaussianDecay(sigma=4, max_w=1., min_w=0.1, periodic_boundary=True, + encoding_values=((-bm.pi, bm.pi), (10, 20), (0, 2 * bm.pi)), + include_self=False, normalize=True) + size = (10, 20, 30) + weights = init(size) + shape = _size2len(size) + assert weights.shape == (shape, shape) + assert isinstance(weights, bp.math.ndarray) class TestDOGDecayInit(unittest.TestCase): - def test_dog_decay_init1(self): - init = bp.init.DOGDecay(sigmas=(1., 2.5), max_ws=(1.0, 0.7)) - for size in [10, (10, 20), (10, 20, 30)]: - weights = init(size) - shape = _size2len(size) - assert weights.shape == (shape, shape) - assert isinstance(weights, bp.math.ndarray) - - def test_dog_decay_init2(self): - init = bp.init.DOGDecay(sigmas=(1., 2.5), - max_ws=(1.0, 0.7), min_w=0.1, - periodic_boundary=True, - encoding_values=((-bm.pi, bm.pi), (10, 20), (0, 2 * bm.pi)), - include_self=False, - normalize=True) - size = (10, 20, 30) - weights = init(size) - shape = _size2len(size) - assert weights.shape == (shape, shape) - assert isinstance(weights, bp.math.ndarray) - - def test_dog_decay3(self): - size = (10, 12) - dog_init = bp.init.DOGDecay(sigmas=(1., 3.), - max_ws=(10., 5.), - min_w=0.1, - include_self=True) - weights = dog_init(size) - print('shape of weights: {}'.format(weights.shape)) - # out: shape of weights: (120, 120) - self.assertTrue(weights.shape == (np.prod(size), np.prod(size))) - - # visualize neuron(3, 4) - mat_visualize(weights[:, 3 * 12 + 4].reshape((10, 12)), cmap=matplotlib.colormaps['Reds']) - plt.close() + def test_dog_decay_init1(self): + init = bp.init.DOGDecay(sigmas=(1., 2.5), max_ws=(1.0, 0.7)) + for size in [10, (10, 20), (10, 20, 30)]: + weights = init(size) + shape = _size2len(size) + assert weights.shape == (shape, shape) + assert 
isinstance(weights, bp.math.ndarray) + + def test_dog_decay_init2(self): + init = bp.init.DOGDecay(sigmas=(1., 2.5), + max_ws=(1.0, 0.7), min_w=0.1, + periodic_boundary=True, + encoding_values=((-bm.pi, bm.pi), (10, 20), (0, 2 * bm.pi)), + include_self=False, + normalize=True) + size = (10, 20, 30) + weights = init(size) + shape = _size2len(size) + assert weights.shape == (shape, shape) + assert isinstance(weights, bp.math.ndarray) + + def test_dog_decay3(self): + size = (10, 12) + dog_init = bp.init.DOGDecay(sigmas=(1., 3.), + max_ws=(10., 5.), + min_w=0.1, + include_self=True) + weights = dog_init(size) + print('shape of weights: {}'.format(weights.shape)) + # out: shape of weights: (120, 120) + self.assertTrue(weights.shape == (np.prod(size), np.prod(size))) + + # visualize neuron(3, 4) + mat_visualize(weights[:, 3 * 12 + 4].reshape((10, 12)), cmap=matplotlib.colormaps['Reds']) + plt.close() diff --git a/brainpy/_src/initialize/tests/test_random_inits.py b/brainpy/_src/initialize/tests/test_random_inits.py index b9f21b62a..4750044e7 100644 --- a/brainpy/_src/initialize/tests/test_random_inits.py +++ b/brainpy/_src/initialize/tests/test_random_inits.py @@ -6,172 +6,172 @@ class TestNormalInit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() - - def test_normal_init1(self): - init = bp.init.Normal() - for size in [(100,), (10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size - - def test_normal_init2(self): - init = bp.init.Normal(scale=0.5) - for size in [(100,), (10, 20)]: - weights = init(size) - assert weights.shape == size - - def test_normal_init3(self): - init1 = bp.init.Normal(scale=0.5, seed=10) - init2 = bp.init.Normal(scale=0.5, seed=10) - size = (10,) - weights1 = init1(size) - weights2 = init2(size) - assert weights1.shape == size - assert (weights1 == weights2).all() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + 
bp.math.random.seed() + + def test_normal_init1(self): + init = bp.init.Normal() + for size in [(100,), (10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size + + def test_normal_init2(self): + init = bp.init.Normal(scale=0.5) + for size in [(100,), (10, 20)]: + weights = init(size) + assert weights.shape == size + + def test_normal_init3(self): + init1 = bp.init.Normal(scale=0.5, seed=10) + init2 = bp.init.Normal(scale=0.5, seed=10) + size = (10,) + weights1 = init1(size) + weights2 = init2(size) + assert weights1.shape == size + assert (weights1 == weights2).all() class TestUniformInit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_uniform_init1(self): - init = bp.init.Normal() - for size in [(100,), (10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_uniform_init1(self): + init = bp.init.Normal() + for size in [(100,), (10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size - def test_uniform_init2(self): - init = bp.init.Uniform(min_val=10, max_val=20) - for size in [(100,), (10, 20)]: - weights = init(size) - assert weights.shape == size + def test_uniform_init2(self): + init = bp.init.Uniform(min_val=10, max_val=20) + for size in [(100,), (10, 20)]: + weights = init(size) + assert weights.shape == size class TestVarianceScaling(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() - - def test_var_scaling1(self): - init = bp.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal') - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size - - def test_var_scaling2(self): - init = bp.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal') - for size in [(10, 20), 
(10, 20, 30)]: - weights = init(size) - assert weights.shape == size - - def test_var_scaling3(self): - init = bp.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1, - distribution='uniform') - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() + + def test_var_scaling1(self): + init = bp.init.VarianceScaling(scale=1., mode='fan_in', distribution='truncated_normal') + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size + + def test_var_scaling2(self): + init = bp.init.VarianceScaling(scale=2, mode='fan_out', distribution='normal') + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size + + def test_var_scaling3(self): + init = bp.init.VarianceScaling(scale=2 / 4, mode='fan_avg', in_axis=0, out_axis=1, + distribution='uniform') + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestKaimingUniformUnit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_kaiming_uniform_init(self): - init = bp.init.KaimingUniform() - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_kaiming_uniform_init(self): + init = bp.init.KaimingUniform() + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestKaimingNormalUnit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_kaiming_normal_init(self): - init = bp.init.KaimingNormal() - for size in [(10, 20), (10, 
20, 30)]: - weights = init(size) - assert weights.shape == size + def test_kaiming_normal_init(self): + init = bp.init.KaimingNormal() + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestXavierUniformUnit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_xavier_uniform_init(self): - init = bp.init.XavierUniform() - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_xavier_uniform_init(self): + init = bp.init.XavierUniform() + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestXavierNormalUnit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_xavier_normal_init(self): - init = bp.init.XavierNormal() - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_xavier_normal_init(self): + init = bp.init.XavierNormal() + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestLecunUniformUnit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_lecun_uniform_init(self): - init = bp.init.LecunUniform() - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_lecun_uniform_init(self): + init = bp.init.LecunUniform() + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestLecunNormalUnit(unittest.TestCase): - def 
__init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_lecun_normal_init(self): - init = bp.init.LecunNormal() - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_lecun_normal_init(self): + init = bp.init.LecunNormal() + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestOrthogonalUnit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() - def test_orthogonal_init1(self): - init = bp.init.Orthogonal() - for size in [(20, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_orthogonal_init1(self): + init = bp.init.Orthogonal() + for size in [(20, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size - def test_orthogonal_init2(self): - init = bp.init.Orthogonal(scale=2., axis=0) - for size in [(10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size + def test_orthogonal_init2(self): + init = bp.init.Orthogonal(scale=2., axis=0) + for size in [(10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size class TestDeltaOrthogonalUnit(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bp.math.random.seed() - - def test_delta_orthogonal_init1(self): - init = bp.init.DeltaOrthogonal() - for size in [(20, 20, 20), (10, 20, 30, 40), (50, 40, 30, 20, 20)]: - weights = init(size) - assert weights.shape == size + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bp.math.random.seed() + + def test_delta_orthogonal_init1(self): + init = bp.init.DeltaOrthogonal() + for size in [(20, 20, 20), (10, 20, 30, 40), 
(50, 40, 30, 20, 20)]: + weights = init(size) + assert weights.shape == size diff --git a/brainpy/_src/initialize/tests/test_regular_inits.py b/brainpy/_src/initialize/tests/test_regular_inits.py index 8316fc661..b434e04c4 100644 --- a/brainpy/_src/initialize/tests/test_regular_inits.py +++ b/brainpy/_src/initialize/tests/test_regular_inits.py @@ -5,32 +5,32 @@ class TestZeroInit(unittest.TestCase): - def test_zero_init(self): - init = bp.init.ZeroInit() - for size in [(100,), (10, 20), (10, 20, 30)]: - weights = init(size) - assert weights.shape == size - assert isinstance(weights, bp.math.ndarray) + def test_zero_init(self): + init = bp.init.ZeroInit() + for size in [(100,), (10, 20), (10, 20, 30)]: + weights = init(size) + assert weights.shape == size + assert isinstance(weights, bp.math.ndarray) class TestOneInit(unittest.TestCase): - def test_one_init(self): - for size in [(100,), (10, 20), (10, 20, 30)]: - for value in [0., 1., -1.]: - init = bp.init.OneInit(value=value) - weights = init(size) - assert weights.shape == size - assert (weights == value).all() + def test_one_init(self): + for size in [(100,), (10, 20), (10, 20, 30)]: + for value in [0., 1., -1.]: + init = bp.init.OneInit(value=value) + weights = init(size) + assert weights.shape == size + assert (weights == value).all() class TestIdentityInit(unittest.TestCase): - def test_identity_init(self): - for size in [(100,), (10, 20)]: - for value in [0., 1., -1.]: - init = bp.init.Identity(value=value) - weights = init(size) - if len(size) == 1: - assert weights.shape == (size[0], size[0]) - else: - assert weights.shape == size - assert isinstance(weights, bp.math.ndarray) + def test_identity_init(self): + for size in [(100,), (10, 20)]: + for value in [0., 1., -1.]: + init = bp.init.Identity(value=value) + weights = init(size) + if len(size) == 1: + assert weights.shape == (size[0], size[0]) + else: + assert weights.shape == size + assert isinstance(weights, bp.math.ndarray) diff --git 
a/brainpy/_src/inputs/__init__.py b/brainpy/_src/inputs/__init__.py index 792e10e82..50a6ce2f3 100644 --- a/brainpy/_src/inputs/__init__.py +++ b/brainpy/_src/inputs/__init__.py @@ -7,4 +7,3 @@ """ from .currents import * - diff --git a/brainpy/_src/inputs/currents.py b/brainpy/_src/inputs/currents.py index 9ef98a4e0..d336c8b8c 100644 --- a/brainpy/_src/inputs/currents.py +++ b/brainpy/_src/inputs/currents.py @@ -10,406 +10,406 @@ from brainpy.check import is_float, is_integer __all__ = [ - 'section_input', - 'constant_input', 'constant_current', - 'spike_input', 'spike_current', - 'ramp_input', 'ramp_current', - 'wiener_process', - 'ou_process', - 'sinusoidal_input', - 'square_input', + 'section_input', + 'constant_input', 'constant_current', + 'spike_input', 'spike_current', + 'ramp_input', 'ramp_current', + 'wiener_process', + 'ou_process', + 'sinusoidal_input', + 'square_input', ] def section_input(values, durations, dt=None, return_length=False): - """Format an input current with different sections. - - For example: - - If you want to get an input where the size is 0 bwteen 0-100 ms, - and the size is 1. between 100-200 ms. - - >>> section_input(values=[0, 1], - >>> durations=[100, 100]) - - Parameters:: - - values : list, np.ndarray - The current values for each period duration. - durations : list, np.ndarray - The duration for each period. - dt : float - Default is None. - return_length : bool - Return the final duration length. 
- - Returns:: - - current_and_duration - """ - if len(durations) != len(values): - raise ValueError(f'"values" and "durations" must be the same length, while ' - f'we got {len(values)} != {len(durations)}.') - - dt = bm.get_dt() if dt is None else dt - - # get input current shape, and duration - I_duration = sum(durations) - I_shape = () - for val in values: - shape = jnp.shape(val) - if len(shape) > len(I_shape): - I_shape = shape - - # get the current - start = 0 - I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape) - for c_size, duration in zip(values, durations): - length = int(duration / dt) - I_current[start: start + length] = c_size - start += length - - if return_length: - return I_current, I_duration - else: - return I_current + """Format an input current with different sections. + For example: -def constant_input(I_and_duration, dt=None): - """Format constant input in durations. + If you want to get an input where the size is 0 bwteen 0-100 ms, + and the size is 1. between 100-200 ms. + + >>> section_input(values=[0, 1], + >>> durations=[100, 100]) - For example: + Parameters:: - If you want to get an input where the size is 0 bwteen 0-100 ms, - and the size is 1. between 100-200 ms. + values : list, np.ndarray + The current values for each period duration. + durations : list, np.ndarray + The duration for each period. + dt : float + Default is None. + return_length : bool + Return the final duration length. - >>> import brainpy.math as bm - >>> constant_input([(0, 100), (1, 100)]) - >>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)]) + Returns:: - Parameters:: + current_and_duration + """ + if len(durations) != len(values): + raise ValueError(f'"values" and "durations" must be the same length, while ' + f'we got {len(values)} != {len(durations)}.') - I_and_duration : list - This parameter receives the current size and the current - duration pairs, like `[(Isize1, duration1), (Isize2, duration2)]`. 
- dt : float - Default is None. + dt = bm.get_dt() if dt is None else dt - Returns:: + # get input current shape, and duration + I_duration = sum(durations) + I_shape = () + for val in values: + shape = jnp.shape(val) + if len(shape) > len(I_shape): + I_shape = shape - current_and_duration : tuple - (The formatted current, total duration) - """ - dt = bm.get_dt() if dt is None else dt + # get the current + start = 0 + I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape) + for c_size, duration in zip(values, durations): + length = int(duration / dt) + I_current[start: start + length] = c_size + start += length - # get input current dimension, shape, and duration - I_duration = 0. - I_shape = () - for I in I_and_duration: - I_duration += I[1] - shape = jnp.shape(I[0]) - if len(shape) > len(I_shape): - I_shape = shape + if return_length: + return I_current, I_duration + else: + return I_current - # get the current - start = 0 - I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape) - for c_size, duration in I_and_duration: - length = int(duration / dt) - I_current[start: start + length] = c_size - start += length - return I_current, I_duration + +def constant_input(I_and_duration, dt=None): + """Format constant input in durations. + + For example: + + If you want to get an input where the size is 0 bwteen 0-100 ms, + and the size is 1. between 100-200 ms. + + >>> import brainpy.math as bm + >>> constant_input([(0, 100), (1, 100)]) + >>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)]) + + Parameters:: + + I_and_duration : list + This parameter receives the current size and the current + duration pairs, like `[(Isize1, duration1), (Isize2, duration2)]`. + dt : float + Default is None. + + Returns:: + + current_and_duration : tuple + (The formatted current, total duration) + """ + dt = bm.get_dt() if dt is None else dt + + # get input current dimension, shape, and duration + I_duration = 0. 
+ I_shape = () + for I in I_and_duration: + I_duration += I[1] + shape = jnp.shape(I[0]) + if len(shape) > len(I_shape): + I_shape = shape + + # get the current + start = 0 + I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape) + for c_size, duration in I_and_duration: + length = int(duration / dt) + I_current[start: start + length] = c_size + start += length + return I_current, I_duration def constant_current(*args, **kwargs): - """Format constant input in durations. + """Format constant input in durations. - .. deprecated:: 2.1.13 - Use ``constant_current()`` instead. - """ - warnings.warn('Please use "brainpy.inputs.constant_input()" instead. ' - '"brainpy.inputs.constant_current()" is deprecated since version 2.1.13.', - DeprecationWarning) - return constant_input(*args, **kwargs) + .. deprecated:: 2.1.13 + Use ``constant_current()`` instead. + """ + warnings.warn('Please use "brainpy.inputs.constant_input()" instead. ' + '"brainpy.inputs.constant_current()" is deprecated since version 2.1.13.', + DeprecationWarning) + return constant_input(*args, **kwargs) def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None): - """Format current input like a series of short-time spikes. - - For example: - - If you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, - and each spike lasts 1 ms and the spike current is 0.5, then you can use the - following funtions: - - >>> spike_input(sp_times=[10, 20, 30, 200, 300], - >>> sp_lens=1., # can be a list to specify the spike length at each point - >>> sp_sizes=0.5, # can be a list to specify the current size at each point - >>> duration=400.) - - Parameters:: - - sp_times : list, tuple - The spike time-points. Must be an iterable object. - sp_lens : int, float, list, tuple - The length of each point-current, mimicking the spike durations. - sp_sizes : int, float, list, tuple - The current sizes. - duration : int, float - The total current duration. - dt : float - The default is None. 
- - Returns:: - - current : bm.ndarray - The formatted input current. - """ - dt = bm.get_dt() if dt is None else dt - assert isinstance(sp_times, (list, tuple)) - if isinstance(sp_lens, (float, int)): - sp_lens = [sp_lens] * len(sp_times) - if isinstance(sp_sizes, (float, int)): - sp_sizes = [sp_sizes] * len(sp_times) - - current = bm.zeros(int(np.ceil(duration / dt))) - for time, dur, size in zip(sp_times, sp_lens, sp_sizes): - pp = int(time / dt) - p_len = int(dur / dt) - current[pp: pp + p_len] = size - return current + """Format current input like a series of short-time spikes. + + For example: + + If you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, + and each spike lasts 1 ms and the spike current is 0.5, then you can use the + following funtions: + + >>> spike_input(sp_times=[10, 20, 30, 200, 300], + >>> sp_lens=1., # can be a list to specify the spike length at each point + >>> sp_sizes=0.5, # can be a list to specify the current size at each point + >>> duration=400.) + + Parameters:: + + sp_times : list, tuple + The spike time-points. Must be an iterable object. + sp_lens : int, float, list, tuple + The length of each point-current, mimicking the spike durations. + sp_sizes : int, float, list, tuple + The current sizes. + duration : int, float + The total current duration. + dt : float + The default is None. + + Returns:: + + current : bm.ndarray + The formatted input current. + """ + dt = bm.get_dt() if dt is None else dt + assert isinstance(sp_times, (list, tuple)) + if isinstance(sp_lens, (float, int)): + sp_lens = [sp_lens] * len(sp_times) + if isinstance(sp_sizes, (float, int)): + sp_sizes = [sp_sizes] * len(sp_times) + + current = bm.zeros(int(np.ceil(duration / dt))) + for time, dur, size in zip(sp_times, sp_lens, sp_sizes): + pp = int(time / dt) + p_len = int(dur / dt) + current[pp: pp + p_len] = size + return current def spike_current(*args, **kwargs): - """Format current input like a series of short-time spikes. 
+ """Format current input like a series of short-time spikes. - .. deprecated:: 2.1.13 - Use ``spike_current()`` instead. - """ - warnings.warn('Please use "brainpy.inputs.spike_input()" instead. ' - '"brainpy.inputs.spike_current()" is deprecated since version 2.1.13.', - DeprecationWarning) - return constant_input(*args, **kwargs) + .. deprecated:: 2.1.13 + Use ``spike_current()`` instead. + """ + warnings.warn('Please use "brainpy.inputs.spike_input()" instead. ' + '"brainpy.inputs.spike_current()" is deprecated since version 2.1.13.', + DeprecationWarning) + return constant_input(*args, **kwargs) def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None): - """Get the gradually changed input current. - - Parameters:: - - c_start : float - The minimum (or maximum) current size. - c_end : float - The maximum (or minimum) current size. - duration : int, float - The total duration. - t_start : float - The ramped current start time-point. - t_end : float - The ramped current end time-point. Default is the None. - dt : float, int, optional - The numerical precision. - - Returns:: - - current : bm.ndarray - The formatted current - """ - dt = bm.get_dt() if dt is None else dt - t_end = duration if t_end is None else t_end - - current = bm.zeros(int(np.ceil(duration / dt))) - p1 = int(np.ceil(t_start / dt)) - p2 = int(np.ceil(t_end / dt)) - cc = jnp.array(jnp.linspace(c_start, c_end, p2 - p1)) - current[p1: p2] = cc - return current + """Get the gradually changed input current. + + Parameters:: + + c_start : float + The minimum (or maximum) current size. + c_end : float + The maximum (or minimum) current size. + duration : int, float + The total duration. + t_start : float + The ramped current start time-point. + t_end : float + The ramped current end time-point. Default is the None. + dt : float, int, optional + The numerical precision. 
+ + Returns:: + + current : bm.ndarray + The formatted current + """ + dt = bm.get_dt() if dt is None else dt + t_end = duration if t_end is None else t_end + + current = bm.zeros(int(np.ceil(duration / dt))) + p1 = int(np.ceil(t_start / dt)) + p2 = int(np.ceil(t_end / dt)) + cc = jnp.array(jnp.linspace(c_start, c_end, p2 - p1)) + current[p1: p2] = cc + return current def ramp_current(*args, **kwargs): - """Get the gradually changed input current. + """Get the gradually changed input current. - .. deprecated:: 2.1.13 - Use ``ramp_input()`` instead. - """ - warnings.warn('Please use "brainpy.inputs.ramp_input()" instead. ' - '"brainpy.inputs.ramp_current()" is deprecated since version 2.1.13.', - DeprecationWarning) - return constant_input(*args, **kwargs) + .. deprecated:: 2.1.13 + Use ``ramp_input()`` instead. + """ + warnings.warn('Please use "brainpy.inputs.ramp_input()" instead. ' + '"brainpy.inputs.ramp_current()" is deprecated since version 2.1.13.', + DeprecationWarning) + return constant_input(*args, **kwargs) def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None): - """Stimulus sampled from a Wiener process, i.e. - drawn from standard normal distribution N(0, sqrt(dt)). - - Parameters:: - - duration: float - The input duration. - dt: float - The numerical precision. - n: int - The variable number. - t_start: float - The start time. - t_end: float - The end time. - seed: int - The noise seed. - """ - dt = bm.get_dt() if dt is None else dt - is_float(dt, 'dt', allow_none=False, min_bound=0.) - is_integer(n, 'n', allow_none=False, min_bound=0) - rng = bm.random.default_rng(seed, clone=False) - t_end = duration if t_end is None else t_end - i_start = int(t_start / dt) - i_end = int(t_end / dt) - noises = rng.standard_normal((i_end - i_start, n)) * jnp.sqrt(dt) - currents = bm.zeros((int(duration / dt), n)) - currents[i_start: i_end] = noises - return currents + """Stimulus sampled from a Wiener process, i.e. 
+ drawn from standard normal distribution N(0, sqrt(dt)). + + Parameters:: + + duration: float + The input duration. + dt: float + The numerical precision. + n: int + The variable number. + t_start: float + The start time. + t_end: float + The end time. + seed: int + The noise seed. + """ + dt = bm.get_dt() if dt is None else dt + is_float(dt, 'dt', allow_none=False, min_bound=0.) + is_integer(n, 'n', allow_none=False, min_bound=0) + rng = bm.random.default_rng(seed, clone=False) + t_end = duration if t_end is None else t_end + i_start = int(t_start / dt) + i_end = int(t_end / dt) + noises = rng.standard_normal((i_end - i_start, n)) * jnp.sqrt(dt) + currents = bm.zeros((int(duration / dt), n)) + currents[i_start: i_end] = noises + return currents def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, seed=None): - r"""Ornstein–Uhlenbeck input. - - .. math:: - - dX = (mu - X)/\tau * dt + \sigma*dW - - Parameters:: - - mean: float - Drift of the OU process. - sigma: float - Standard deviation of the Wiener process, i.e. strength of the noise. - tau: float - Timescale of the OU process, in ms. - duration: float - The input duration. - dt: float - The numerical precision. - n: int - The variable number. - t_start: float - The start time. - t_end: float - The end time. - seed: optional, int - The random seed. - """ - dt = bm.get_dt() if dt is None else dt - dt_sqrt = jnp.sqrt(dt) - is_float(dt, 'dt', allow_none=False, min_bound=0.) 
- is_integer(n, 'n', allow_none=False, min_bound=0) - rng = bm.random.default_rng(seed, clone=False) - x = bm.Variable(jnp.ones(n) * mean) - - def _f(t): - x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.randn(n) - return x.value - - noises = bm.for_loop(_f, jnp.arange(t_start, t_end, dt)) - - t_end = duration if t_end is None else t_end - i_start = int(t_start / dt) - i_end = int(t_end / dt) - currents = bm.zeros((int(duration / dt), n)) - currents[i_start: i_end] = noises - return currents + r"""Ornstein–Uhlenbeck input. + + .. math:: + + dX = (mu - X)/\tau * dt + \sigma*dW + + Parameters:: + + mean: float + Drift of the OU process. + sigma: float + Standard deviation of the Wiener process, i.e. strength of the noise. + tau: float + Timescale of the OU process, in ms. + duration: float + The input duration. + dt: float + The numerical precision. + n: int + The variable number. + t_start: float + The start time. + t_end: float + The end time. + seed: optional, int + The random seed. + """ + dt = bm.get_dt() if dt is None else dt + dt_sqrt = jnp.sqrt(dt) + is_float(dt, 'dt', allow_none=False, min_bound=0.) + is_integer(n, 'n', allow_none=False, min_bound=0) + rng = bm.random.default_rng(seed, clone=False) + x = bm.Variable(jnp.ones(n) * mean) + + def _f(t): + x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.randn(n) + return x.value + + noises = bm.for_loop(_f, jnp.arange(t_start, t_end, dt)) + + t_end = duration if t_end is None else t_end + i_start = int(t_start / dt) + i_end = int(t_end / dt) + currents = bm.zeros((int(duration / dt), n)) + currents[i_start: i_end] = noises + return currents def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end=None, bias=False): - """Sinusoidal input. - - Parameters:: - - amplitude: float - Amplitude of the sinusoid. - frequency: float - Frequency of the sinus oscillation, in Hz - duration: float - The input duration. - t_start: float - The start time. 
- t_end: float - The end time. - dt: float - The numerical precision. - bias: bool - Whether the sinusoid oscillates around 0 (False), or - has a positive DC bias, thus non-negative (True). - """ - dt = bm.get_dt() if dt is None else dt - is_float(dt, 'dt', allow_none=False, min_bound=0.) - if t_end is None: - t_end = duration - times = jnp.arange(0, t_end - t_start, dt) - start_i = int(t_start / dt) - end_i = int(t_end / dt) - sin_inputs = amplitude * jnp.sin(2 * jnp.pi * times * (frequency / 1000.0)) - if bias: sin_inputs += amplitude - currents = bm.zeros(int(duration / dt)) - currents[start_i:end_i] = sin_inputs - return currents + """Sinusoidal input. + + Parameters:: + + amplitude: float + Amplitude of the sinusoid. + frequency: float + Frequency of the sinus oscillation, in Hz + duration: float + The input duration. + t_start: float + The start time. + t_end: float + The end time. + dt: float + The numerical precision. + bias: bool + Whether the sinusoid oscillates around 0 (False), or + has a positive DC bias, thus non-negative (True). + """ + dt = bm.get_dt() if dt is None else dt + is_float(dt, 'dt', allow_none=False, min_bound=0.) 
+ if t_end is None: + t_end = duration + times = jnp.arange(0, t_end - t_start, dt) + start_i = int(t_start / dt) + end_i = int(t_end / dt) + sin_inputs = amplitude * jnp.sin(2 * jnp.pi * times * (frequency / 1000.0)) + if bias: sin_inputs += amplitude + currents = bm.zeros(int(duration / dt)) + currents[start_i:end_i] = sin_inputs + return currents def _square(t, duty=0.5): - t, w = np.asarray(t), np.asarray(duty) - w = np.asarray(w + (t - t)) - t = np.asarray(t + (w - w)) - if t.dtype.char in 'fFdD': - ytype = t.dtype.char - else: - ytype = 'd' + t, w = np.asarray(t), np.asarray(duty) + w = np.asarray(w + (t - t)) + t = np.asarray(t + (w - w)) + if t.dtype.char in 'fFdD': + ytype = t.dtype.char + else: + ytype = 'd' - y = np.zeros(t.shape, ytype) + y = np.zeros(t.shape, ytype) - # width must be between 0 and 1 inclusive - mask1 = (w > 1) | (w < 0) - np.place(y, mask1, np.nan) + # width must be between 0 and 1 inclusive + mask1 = (w > 1) | (w < 0) + np.place(y, mask1, np.nan) - # on the interval 0 to duty*2*pi function is 1 - tmod = np.mod(t, 2 * np.pi) - mask2 = (1 - mask1) & (tmod < w * 2 * np.pi) - np.place(y, mask2, 1) + # on the interval 0 to duty*2*pi function is 1 + tmod = np.mod(t, 2 * np.pi) + mask2 = (1 - mask1) & (tmod < w * 2 * np.pi) + np.place(y, mask2, 1) - # on the interval duty*2*pi to 2*pi function is - # (pi*(w+1)-tmod) / (pi*(1-w)) - mask3 = (1 - mask1) & (1 - mask2) - np.place(y, mask3, -1) - return y + # on the interval duty*2*pi to 2*pi function is + # (pi*(w+1)-tmod) / (pi*(1-w)) + mask3 = (1 - mask1) & (1 - mask2) + np.place(y, mask3, -1) + return y def square_input(amplitude, frequency, duration, dt=None, bias=False, t_start=0., t_end=None): - """Oscillatory square input. - - Parameters:: - - amplitude: float - Amplitude of the square oscillation. - frequency: float - Frequency of the square oscillation, in Hz. - duration: float - The input duration. - t_start: float - The start time. - t_end: float - The end time. 
- dt: float - The numerical precision. - bias: bool - Whether the sinusoid oscillates around 0 (False), or - has a positive DC bias, thus non-negative (True). - """ - dt = bm.get_dt() if dt is None else dt - is_float(dt, 'dt', allow_none=False, min_bound=0.) - if t_end is None: t_end = duration - times = np.arange(0, t_end - t_start, dt) - sin_inputs = amplitude * _square(2 * np.pi * times * (frequency / 1000.0)) - if bias: sin_inputs += amplitude - currents = bm.zeros(int(duration / dt)) - start_i = int(t_start / dt) - end_i = int(t_end / dt) - currents[start_i:end_i] = bm.asarray(sin_inputs) - return currents + """Oscillatory square input. + + Parameters:: + + amplitude: float + Amplitude of the square oscillation. + frequency: float + Frequency of the square oscillation, in Hz. + duration: float + The input duration. + t_start: float + The start time. + t_end: float + The end time. + dt: float + The numerical precision. + bias: bool + Whether the sinusoid oscillates around 0 (False), or + has a positive DC bias, thus non-negative (True). + """ + dt = bm.get_dt() if dt is None else dt + is_float(dt, 'dt', allow_none=False, min_bound=0.) 
+ if t_end is None: t_end = duration + times = np.arange(0, t_end - t_start, dt) + sin_inputs = amplitude * _square(2 * np.pi * times * (frequency / 1000.0)) + if bias: sin_inputs += amplitude + currents = bm.zeros(int(duration / dt)) + start_i = int(t_start / dt) + end_i = int(t_end / dt) + currents[start_i:end_i] = bm.asarray(sin_inputs) + return currents diff --git a/brainpy/_src/inputs/tests/test_currents.py b/brainpy/_src/inputs/tests/test_currents.py index a9173c093..197c233f2 100644 --- a/brainpy/_src/inputs/tests/test_currents.py +++ b/brainpy/_src/inputs/tests/test_currents.py @@ -3,7 +3,6 @@ from unittest import TestCase - import numpy as np import brainpy as bp @@ -13,82 +12,82 @@ def show(current, duration, title=''): - global plt - if plt is None: - import matplotlib.pyplot as plt - ts = np.arange(0, duration, bp.math.get_dt()) - plt.plot(ts, current) - plt.title(title) - plt.xlabel('Time [ms]') - plt.ylabel('Current Value') - plt.show(block=block) + global plt + if plt is None: + import matplotlib.pyplot as plt + ts = np.arange(0, duration, bp.math.get_dt()) + plt.plot(ts, current) + plt.title(title) + plt.xlabel('Time [ms]') + plt.ylabel('Current Value') + plt.show(block=block) class TestCurrents(TestCase): - def test_section_input(self): - current1, duration = bp.inputs.section_input(values=[0, 1., 0.], - durations=[100, 300, 100], - return_length=True, - dt=0.1) - show(current1, duration, 'values=[0, 1, 0], durations=[100, 300, 100]') - - def test_constant_input(self): - current2, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)]) - show(current2, duration, '[(0, 100), (1, 300), (0, 100)]') - - def test_spike_input(self): - current3 = bp.inputs.spike_input( - sp_times=[10, 20, 30, 200, 300], - sp_lens=1., # can be a list to specify the spike length at each point - sp_sizes=0.5, # can be a list to specify the spike current size at each point - duration=400.) 
- - show(current3, 400, 'Spike Input Example') - - def test_ramp_input(self): - duration = 500 - current4 = bp.inputs.ramp_input(0, 1, duration) - - show(current4, duration, r'$c_{start}$=0, $c_{end}$=%d, duration, ' - r'$t_{start}$=0, $t_{end}$=None' % (duration)) - - def test_ramp_input2(self): - duration, t_start, t_end = 500, 100, 400 - current5 = bp.inputs.ramp_input(0, 1, duration, t_start, t_end) - - show(current5, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, ' - r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end)) - - def test_wiener_process(self): - duration = 200 - current6 = bp.inputs.wiener_process(duration, n=2, t_start=10., t_end=180.) - show(current6, duration, 'Wiener Process') - - def test_ou_process(self): - duration = 200 - current7 = bp.inputs.ou_process(mean=1., sigma=0.1, tau=10., duration=duration, n=2, t_start=10., t_end=180.) - show(current7, duration, 'Ornstein-Uhlenbeck Process') - - def test_sinusoidal_input(self): - duration = 2000 - current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0, duration=duration, t_start=100., ) - show(current8, duration, 'Sinusoidal Input') - - def test_square_input(self): - duration = 2000 - current9 = bp.inputs.square_input(amplitude=1., frequency=2.0, - duration=duration, t_start=100) - show(current9, duration, 'Square Input') - - def test_general1(self): - I1 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.1) - I2 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.01) - self.assertTrue(I1.shape[0] == 600) - self.assertTrue(I2.shape[0] == 6000) - - def test_general2(self): - bp.math.random.seed(123) - current = bp.inputs.section_input(values=[0, bp.math.ones(10), - bp.math.random.random((3, 10))], - durations=[100, 300, 100]) - self.assertTrue(current.shape == (5000, 3, 10)) + def test_section_input(self): + current1, duration = bp.inputs.section_input(values=[0, 1., 0.], + durations=[100, 300, 100], + return_length=True, + dt=0.1) 
+ show(current1, duration, 'values=[0, 1, 0], durations=[100, 300, 100]') + + def test_constant_input(self): + current2, duration = bp.inputs.constant_input([(0, 100), (1, 300), (0, 100)]) + show(current2, duration, '[(0, 100), (1, 300), (0, 100)]') + + def test_spike_input(self): + current3 = bp.inputs.spike_input( + sp_times=[10, 20, 30, 200, 300], + sp_lens=1., # can be a list to specify the spike length at each point + sp_sizes=0.5, # can be a list to specify the spike current size at each point + duration=400.) + + show(current3, 400, 'Spike Input Example') + + def test_ramp_input(self): + duration = 500 + current4 = bp.inputs.ramp_input(0, 1, duration) + + show(current4, duration, r'$c_{start}$=0, $c_{end}$=%d, duration, ' + r'$t_{start}$=0, $t_{end}$=None' % (duration)) + + def test_ramp_input2(self): + duration, t_start, t_end = 500, 100, 400 + current5 = bp.inputs.ramp_input(0, 1, duration, t_start, t_end) + + show(current5, duration, r'$c_{start}$=0, $c_{end}$=1, duration=%d, ' + r'$t_{start}$=%d, $t_{end}$=%d' % (duration, t_start, t_end)) + + def test_wiener_process(self): + duration = 200 + current6 = bp.inputs.wiener_process(duration, n=2, t_start=10., t_end=180.) + show(current6, duration, 'Wiener Process') + + def test_ou_process(self): + duration = 200 + current7 = bp.inputs.ou_process(mean=1., sigma=0.1, tau=10., duration=duration, n=2, t_start=10., t_end=180.) 
+ show(current7, duration, 'Ornstein-Uhlenbeck Process') + + def test_sinusoidal_input(self): + duration = 2000 + current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0, duration=duration, t_start=100., ) + show(current8, duration, 'Sinusoidal Input') + + def test_square_input(self): + duration = 2000 + current9 = bp.inputs.square_input(amplitude=1., frequency=2.0, + duration=duration, t_start=100) + show(current9, duration, 'Square Input') + + def test_general1(self): + I1 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.1) + I2 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.01) + self.assertTrue(I1.shape[0] == 600) + self.assertTrue(I2.shape[0] == 6000) + + def test_general2(self): + bp.math.random.seed(123) + current = bp.inputs.section_input(values=[0, bp.math.ones(10), + bp.math.random.random((3, 10))], + durations=[100, 300, 100]) + self.assertTrue(current.shape == (5000, 3, 10)) diff --git a/brainpy/_src/integrators/__init__.py b/brainpy/_src/integrators/__init__.py index a8ecacef8..4df1514fe 100644 --- a/brainpy/_src/integrators/__init__.py +++ b/brainpy/_src/integrators/__init__.py @@ -12,35 +12,31 @@ Details please see the following. """ +# FDE tools +from . import fde +# ODE tools +from . import ode +# PDE tools +from . import pde +# SDE tools +from . import sde # basic tools from .base import * from .constants import * +from .fde.base import FDEIntegrator +from .fde.generic import (fdeint, + get_default_fdeint, + set_default_fdeint, + register_fde_integrator) from .joint_eq import * -from .runner import * - -# ODE tools -from . import ode from .ode.base import ODEIntegrator from .ode.generic import (odeint, get_default_odeint, set_default_odeint, register_ode_integrator) - -# SDE tools -from . import sde +from .runner import * from .sde.base import SDEIntegrator from .sde.generic import (sdeint, get_default_sdeint, set_default_sdeint, register_sde_integrator) - -# FDE tools -from . 
import fde -from .fde.base import FDEIntegrator -from .fde.generic import (fdeint, - get_default_fdeint, - set_default_fdeint, - register_fde_integrator) - -# PDE tools -from . import pde diff --git a/brainpy/_src/integrators/_jaxpr_to_source_code.py b/brainpy/_src/integrators/_jaxpr_to_source_code.py index 6c8c0c9e8..a00bbe9cb 100644 --- a/brainpy/_src/integrators/_jaxpr_to_source_code.py +++ b/brainpy/_src/integrators/_jaxpr_to_source_code.py @@ -25,130 +25,131 @@ import jax.numpy as jnp import numpy as np from jax._src.sharding_impls import UNSPECIFIED + if jax.__version__ >= '0.5.0': - from jax.extend.core import Primitive, Literal, Var, Jaxpr + from jax.extend.core import Literal, Var, Jaxpr else: - from jax.core import Primitive, Literal, Var, Jaxpr + from jax.core import Primitive, Literal, Var, Jaxpr __all__ = [ - 'fn_to_python_code', - 'jaxpr_to_python_code', + 'fn_to_python_code', + 'jaxpr_to_python_code', ] class IdentitySet(MutableSet): - """Set that compares objects by identity. + """Set that compares objects by identity. - This is a set that compares objects by identity instead of equality. It is - useful for storing objects that are not hashable or that should be compared - by identity. + This is a set that compares objects by identity instead of equality. It is + useful for storing objects that are not hashable or that should be compared + by identity. - This is a mutable set, but it does not support the ``__hash__`` method and - therefore cannot be used as a dictionary key or as an element of another set. - """ + This is a mutable set, but it does not support the ``__hash__`` method and + therefore cannot be used as a dictionary key or as an element of another set. 
+ """ - def __init__(self, iterable=None): - self._data = {} - if iterable is not None: - self.update(iterable) + def __init__(self, iterable=None): + self._data = {} + if iterable is not None: + self.update(iterable) - def __contains__(self, value): - return id(value) in self._data + def __contains__(self, value): + return id(value) in self._data - def __iter__(self): - return iter(self._data.values()) + def __iter__(self): + return iter(self._data.values()) - def __len__(self): - return len(self._data) + def __len__(self): + return len(self._data) - def add(self, value): - self._data[id(value)] = value + def add(self, value): + self._data[id(value)] = value - def discard(self, value): - self._data.pop(id(value), None) + def discard(self, value): + self._data.pop(id(value), None) - def __repr__(self): - return f"IdentitySet({list(repr(x) for x in self._data.values())})" + def __repr__(self): + return f"IdentitySet({list(repr(x) for x in self._data.values())})" - def __str__(self): - return f"IdentitySet({list(str(x) for x in self._data.values())})" + def __str__(self): + return f"IdentitySet({list(str(x) for x in self._data.values())})" class IdentityMap(MutableMapping): - """Map that compares keys by identity. + """Map that compares keys by identity. - This is a map that compares keys by identity instead of equality. It is - useful for storing objects that are not hashable or that should be compared - by identity. + This is a map that compares keys by identity instead of equality. It is + useful for storing objects that are not hashable or that should be compared + by identity. - This is a mutable mapping, but it does not support the ``__hash__`` method - and therefore cannot be used as a dictionary key or as an element of another - set. - """ + This is a mutable mapping, but it does not support the ``__hash__`` method + and therefore cannot be used as a dictionary key or as an element of another + set. 
+ """ - def __init__(self, iterable=None): - self._data = {} - if iterable is not None: - self.update(iterable) + def __init__(self, iterable=None): + self._data = {} + if iterable is not None: + self.update(iterable) - def __contains__(self, key): - return id(key) in self._data + def __contains__(self, key): + return id(key) in self._data - def __getitem__(self, key): - return self._data[id(key)] + def __getitem__(self, key): + return self._data[id(key)] - def __setitem__(self, key, value): - self._data[id(key)] = value + def __setitem__(self, key, value): + self._data[id(key)] = value - def __delitem__(self, key): - del self._data[id(key)] + def __delitem__(self, key): + del self._data[id(key)] - def __iter__(self): - return iter(self._data.values()) + def __iter__(self): + return iter(self._data.values()) - def __len__(self): - return len(self._data) + def __len__(self): + return len(self._data) - def __repr__(self): - return f"IdentityMap({list(repr(x) for x in self._data.values())})" + def __repr__(self): + return f"IdentityMap({list(repr(x) for x in self._data.values())})" - def __str__(self): - return f"IdentityMap({list(str(x) for x in self._data.values())})" + def __str__(self): + return f"IdentityMap({list(str(x) for x in self._data.values())})" @dataclass class SourcerorState: - """State for the auto-minimizer. Basically just in charge of naming variables.""" - _var_names: IdentityMap[Var, str] = field(default_factory=IdentityMap) - _skolem_count: int = 0 - - def name(self, var, ctx=ast.Load()) -> ast.Name: - return ast.Name(id=self.str_name(var), ctx=ctx) - - def str_name(self, var: Var): - # Names things in a way vaguely compatible with - # JAX's naming scheme, which is 'a'-'z' followed - # by 'aa'-'az' etc. - if var in self._var_names: - return self._var_names[var] - else: - cur_count = len(self._var_names) - name = "" - while cur_count >= 26: - name += chr(ord('a') + cur_count % 26) - cur_count //= 26 + """State for the auto-minimizer. 
Basically just in charge of naming variables.""" + _var_names: IdentityMap[Var, str] = field(default_factory=IdentityMap) + _skolem_count: int = 0 + + def name(self, var, ctx=ast.Load()) -> ast.Name: + return ast.Name(id=self.str_name(var), ctx=ctx) + + def str_name(self, var: Var): + # Names things in a way vaguely compatible with + # JAX's naming scheme, which is 'a'-'z' followed + # by 'aa'-'az' etc. + if var in self._var_names: + return self._var_names[var] + else: + cur_count = len(self._var_names) + name = "" + while cur_count >= 26: + name += chr(ord('a') + cur_count % 26) + cur_count //= 26 - name += chr(ord('a') + cur_count) + name += chr(ord('a') + cur_count) - name = name[::-1] + name = name[::-1] - self._var_names[var] = name + self._var_names[var] = name - return name + return name - def skolem(self, prefix: str): - self._skolem_count += 1 - return f"{prefix}_{self._skolem_count}" + def skolem(self, prefix: str): + self._skolem_count += 1 + return f"{prefix}_{self._skolem_count}" prefix_imports = set() @@ -156,154 +157,154 @@ def skolem(self, prefix: str): @contextmanager def catch_imports(): - try: - prefix_imports.clear() - yield - finally: - prefix_imports.clear() + try: + prefix_imports.clear() + yield + finally: + prefix_imports.clear() def fn_to_python_code(fn, *args, **kwargs): - """ - Given a function which is defined by jax primitives and the function arguments, - return the Python code that would be generated by JAX for that function. 
- - :param fn: The function to generate code for - :param args: The positional arguments to the function - :param kwargs: The keyword arguments to the function - :return: The Python code that would be generated by JAX for that function - """ - closed_jaxpr = jax.make_jaxpr(fn)(*args, **kwargs) - jaxpr = constant_fold_jaxpr(closed_jaxpr.jaxpr) - state = SourcerorState() - try: - name = fn.__name__ - except AttributeError: - name = "unknown" - with catch_imports(): - node = jaxpr_to_py_ast(state, jaxpr, fn_name=name) - node = _maybe_wrap_fn_for_leaves(node, fn, len(args) + len(kwargs)) - ast.fix_missing_locations(node) - source = ast.unparse(node) - if len(prefix_imports): - source = "\n".join(prefix_imports) + "\n\n" + source - return source + """ + Given a function which is defined by jax primitives and the function arguments, + return the Python code that would be generated by JAX for that function. + + :param fn: The function to generate code for + :param args: The positional arguments to the function + :param kwargs: The keyword arguments to the function + :return: The Python code that would be generated by JAX for that function + """ + closed_jaxpr = jax.make_jaxpr(fn)(*args, **kwargs) + jaxpr = constant_fold_jaxpr(closed_jaxpr.jaxpr) + state = SourcerorState() + try: + name = fn.__name__ + except AttributeError: + name = "unknown" + with catch_imports(): + node = jaxpr_to_py_ast(state, jaxpr, fn_name=name) + node = _maybe_wrap_fn_for_leaves(node, fn, len(args) + len(kwargs)) + ast.fix_missing_locations(node) + source = ast.unparse(node) + if len(prefix_imports): + source = "\n".join(prefix_imports) + "\n\n" + source + return source def jaxpr_to_python_code(jaxpr: Jaxpr, fn_name: str = "generated_function"): - """ - Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr. - - :param jaxpr: The jaxpr to generate code. - :param fn_name: The name of the function to generate code. 
- :return: The Python code that would be generated by JAX for that jaxpr - """ - jaxpr = constant_fold_jaxpr(jaxpr) - state = SourcerorState() - with catch_imports(): - node = jaxpr_to_py_ast(state, jaxpr, fn_name=fn_name) - ast.fix_missing_locations(node) - source = ast.unparse(node) - if len(prefix_imports): - source = "\n".join(prefix_imports) + "\n\n" + source - return source + """ + Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr. + + :param jaxpr: The jaxpr to generate code. + :param fn_name: The name of the function to generate code. + :return: The Python code that would be generated by JAX for that jaxpr + """ + jaxpr = constant_fold_jaxpr(jaxpr) + state = SourcerorState() + with catch_imports(): + node = jaxpr_to_py_ast(state, jaxpr, fn_name=fn_name) + ast.fix_missing_locations(node) + source = ast.unparse(node) + if len(prefix_imports): + source = "\n".join(prefix_imports) + "\n\n" + source + return source def register_prim_handler(prim_name, handler): - """ - Register a handler for a primitive for automin - :param prim_name: - :param handler: - :return: - """ - if prim_name in prim_to_python: - warnings.warn(f"Overwriting handler for primitive {prim_name}") - prim_to_python[prim_name] = handler + """ + Register a handler for a primitive for automin + :param prim_name: + :param handler: + :return: + """ + if prim_name in prim_to_python: + warnings.warn(f"Overwriting handler for primitive {prim_name}") + prim_to_python[prim_name] = handler def register_prim_as(prim_name): - """ - Decorator to register a handler for a primitive. + """ + Decorator to register a handler for a primitive. 
- :param prim_name: - :return: - """ + :param prim_name: + :return: + """ - def decorator(fn): - register_prim_handler(prim_name, fn) - return fn + def decorator(fn): + register_prim_handler(prim_name, fn) + return fn - return decorator + return decorator def _assign_stmt(call_expr: Callable): - """ - Create a handler for a primitive that is a simple assignment. - :param call_expr: - :return: - """ - - def binop_fn(state, eqn): - invars = [_astify_atom(state, v) for v in eqn.invars] - outvars = _astify_outvars(state, eqn.outvars) - return ast.Assign( - outvars, - call_expr( - *invars, - **{k: _astify_value(v) for k, v in eqn.params.items()} - ) - ) + """ + Create a handler for a primitive that is a simple assignment. + :param call_expr: + :return: + """ + + def binop_fn(state, eqn): + invars = [_astify_atom(state, v) for v in eqn.invars] + outvars = _astify_outvars(state, eqn.outvars) + return ast.Assign( + outvars, + call_expr( + *invars, + **{k: _astify_value(v) for k, v in eqn.params.items()} + ) + ) - return binop_fn + return binop_fn def _binop_fn(op: ast.operator): - return _assign_stmt(lambda x, y: ast.BinOp(left=x, op=op, right=y)) + return _assign_stmt(lambda x, y: ast.BinOp(left=x, op=op, right=y)) def _cmpop_fn(op: ast.cmpop): - return _assign_stmt(lambda x, y: ast.Compare(left=x, ops=[op], comparators=[y])) + return _assign_stmt(lambda x, y: ast.Compare(left=x, ops=[op], comparators=[y])) def normal_fn(fn_name): - """ - Create a handler for a normal function call. - :param fn_name: - :return: - """ - return _assign_stmt( - lambda *args, **kwargs: ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=list(args), - keywords=[ast.keyword(arg=k, value=v) for k, v in kwargs.items()] + """ + Create a handler for a normal function call. 
+ :param fn_name: + :return: + """ + return _assign_stmt( + lambda *args, **kwargs: ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=list(args), + keywords=[ast.keyword(arg=k, value=v) for k, v in kwargs.items()] + ) ) - ) def _reduce_fn(fn_name: str): - def reduce_fn_inner(state: SourcerorState, eqn): - invars = [_astify_atom(state, v) for v in eqn.invars] - outvars = _astify_outvars(state, eqn.outvars) - if eqn.params: - params = eqn.params.copy() - params['axis'] = tuple(params['axes']) - del params['axes'] - call_op = ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=invars, - keywords=[ast.keyword(arg=k, value=_astify_value(v)) for k, v in params.items()] - ) - else: - call_op = ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=invars, - keywords=[] - ) + def reduce_fn_inner(state: SourcerorState, eqn): + invars = [_astify_atom(state, v) for v in eqn.invars] + outvars = _astify_outvars(state, eqn.outvars) + if eqn.params: + params = eqn.params.copy() + params['axis'] = tuple(params['axes']) + del params['axes'] + call_op = ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=invars, + keywords=[ast.keyword(arg=k, value=_astify_value(v)) for k, v in params.items()] + ) + else: + call_op = ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=invars, + keywords=[] + ) - return ast.Assign(outvars, call_op) + return ast.Assign(outvars, call_op) - return reduce_fn_inner + return reduce_fn_inner prim_to_python = dict() @@ -329,807 +330,808 @@ def reduce_fn_inner(state: SourcerorState, eqn): def _maybe_wrap_fn_for_leaves(node, f, num_args): - if len(node.args.args) == num_args: - return node - - wrapped_node = ast.FunctionDef( - name=f.__name__, - args=ast.arguments( - args=[], - vararg=ast.arg(arg="args", annotation=None), - kwarg=ast.arg(arg="kwargs", annotation=None), - kwonlyargs=[], kw_defaults=[], defaults=[], - posonlyargs=[] - ), - body=[ - node, - ast.Return( - ast.Call( - func=ast.Name(id=node.name, 
ctx=ast.Load()), - args=[ - ast.Starred( - ast.Call( - func=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), - attr="tree_leaves", - ctx=ast.Load()), - args=[ast.Tuple(elts=[ast.Name(id="args", ctx=ast.Load()), - ast.Name(id="kwargs", ctx=ast.Load())], - ctx=ast.Load())], - keywords=[] - ) - ) - ], - keywords=[] - ) - ), - ], - decorator_list=[] - ) + if len(node.args.args) == num_args: + return node + + wrapped_node = ast.FunctionDef( + name=f.__name__, + args=ast.arguments( + args=[], + vararg=ast.arg(arg="args", annotation=None), + kwarg=ast.arg(arg="kwargs", annotation=None), + kwonlyargs=[], kw_defaults=[], defaults=[], + posonlyargs=[] + ), + body=[ + node, + ast.Return( + ast.Call( + func=ast.Name(id=node.name, ctx=ast.Load()), + args=[ + ast.Starred( + ast.Call( + func=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), + attr="tree_leaves", + ctx=ast.Load()), + args=[ast.Tuple(elts=[ast.Name(id="args", ctx=ast.Load()), + ast.Name(id="kwargs", ctx=ast.Load())], + ctx=ast.Load())], + keywords=[] + ) + ) + ], + keywords=[] + ) + ), + ], + decorator_list=[] + ) - return wrapped_node + return wrapped_node def jaxpr_to_py_ast(state: SourcerorState, jaxpr: Jaxpr, fn_name: str = "function"): - # Generate argument declarations - ast_args = [ast.arg(arg=state.str_name(var), annotation=None) - for var in jaxpr.invars] - ast_args = ast.arguments(args=ast_args, - vararg=None, - kwonlyargs=[], - kw_defaults=[], - kwarg=None, - defaults=[], - posonlyargs=[]) - - stmts = [] - - # Generate body of the function - for eqn in jaxpr.eqns: - prim = str(eqn.primitive) - if prim in prim_to_python: - eqn_stmts = prim_to_python[prim](state, eqn) - else: - eqn_stmts = normal_fn(prim)(state, eqn) + # Generate argument declarations + ast_args = [ast.arg(arg=state.str_name(var), annotation=None) + for var in jaxpr.invars] + ast_args = ast.arguments(args=ast_args, + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + posonlyargs=[]) + + stmts = [] + 
+ # Generate body of the function + for eqn in jaxpr.eqns: + prim = str(eqn.primitive) + if prim in prim_to_python: + eqn_stmts = prim_to_python[prim](state, eqn) + else: + eqn_stmts = normal_fn(prim)(state, eqn) - if isinstance(eqn_stmts, list): - stmts.extend(eqn_stmts) - else: - stmts.append(eqn_stmts) + if isinstance(eqn_stmts, list): + stmts.extend(eqn_stmts) + else: + stmts.append(eqn_stmts) - # Generate return statement - if len(jaxpr.outvars) == 1: - returns = state.name(jaxpr.outvars[0]) - else: - returns = ast.Tuple(elts=[state.name(var) for var in jaxpr.outvars], ctx=ast.Load()) - stmts.append(ast.Return(value=returns)) + # Generate return statement + if len(jaxpr.outvars) == 1: + returns = state.name(jaxpr.outvars[0]) + else: + returns = ast.Tuple(elts=[state.name(var) for var in jaxpr.outvars], ctx=ast.Load()) + stmts.append(ast.Return(value=returns)) - return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[]) + return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[]) def constant_fold_jaxpr(jaxpr: Jaxpr): - """ - Given a jaxpr, return a new jaxpr with all constant folding done. - """ - return partial_eval_jaxpr(jaxpr, {}) + """ + Given a jaxpr, return a new jaxpr with all constant folding done. 
+ """ + return partial_eval_jaxpr(jaxpr, {}) def partial_eval_jaxpr(jaxpr, env): - env = env.copy() - new_eqns = [] + env = env.copy() + new_eqns = [] - def read(var): - if isinstance(var, Literal): - return var.val - else: - return env.get(var, None) - - def read_or_self(var): - out = read(var) - if out is None: - return var - elif isinstance(out, Var): - return out - elif isinstance(out, Literal): - return Literal(out.val, var.aval) - else: - assert not isinstance(out, Jaxpr) - return Literal(out, var.aval) - - for eqn in jaxpr.eqns: - vals = [read(var) for var in eqn.invars] - if eqn.primitive.name in constant_fold_blacklist: - new_eqns.append(eqn) - elif all(val is not None for val in vals): - # go ahead and eval it - out = _eval_eqn(eqn, vals) - - # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values - if isinstance(out, Jaxpr): - # we need to inline this - new_eqns.extend(out.eqns) - out = out.outvars - elif not isinstance(out, tuple) and not isinstance(out, list): - out = (out,) - - for var, val in zip(eqn.outvars, out): - assert not isinstance(val, Jaxpr) - if isinstance(val, Literal): - env[var] = val.val + def read(var): + if isinstance(var, Literal): + return var.val else: - env[var] = val - else: - new_eqns.append(eqn) + return env.get(var, None) + + def read_or_self(var): + out = read(var) + if out is None: + return var + elif isinstance(out, Var): + return out + elif isinstance(out, Literal): + return Literal(out.val, var.aval) + else: + assert not isinstance(out, Jaxpr) + return Literal(out, var.aval) + + for eqn in jaxpr.eqns: + vals = [read(var) for var in eqn.invars] + if eqn.primitive.name in constant_fold_blacklist: + new_eqns.append(eqn) + elif all(val is not None for val in vals): + # go ahead and eval it + out = _eval_eqn(eqn, vals) + + # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values + if isinstance(out, Jaxpr): + # we need to inline this + 
new_eqns.extend(out.eqns) + out = out.outvars + elif not isinstance(out, tuple) and not isinstance(out, list): + out = (out,) + + for var, val in zip(eqn.outvars, out): + assert not isinstance(val, Jaxpr) + if isinstance(val, Literal): + env[var] = val.val + else: + env[var] = val + else: + new_eqns.append(eqn) - # now that we've evaled everything, inline all the constants - out_eqns = [] - for eqn in new_eqns: - eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars)) - out_eqns.append(eqn) + # now that we've evaled everything, inline all the constants + out_eqns = [] + for eqn in new_eqns: + eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars)) + out_eqns.append(eqn) - invars_still_used = IdentitySet() - for eqn in out_eqns: - for var in eqn.invars: - invars_still_used.add(var) + invars_still_used = IdentitySet() + for eqn in out_eqns: + for var in eqn.invars: + invars_still_used.add(var) - invars = tuple(var for var in jaxpr.invars if var in invars_still_used) + invars = tuple(var for var in jaxpr.invars if var in invars_still_used) - # sub in any constants for outvars - outvars = tuple(read_or_self(var) for var in jaxpr.outvars) + # sub in any constants for outvars + outvars = tuple(read_or_self(var) for var in jaxpr.outvars) - return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars) + return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars) def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jnp.ndarray]: - if eqn.primitive.name == "closed_call": - assert eqn.primitive.call_primitive == True - assert eqn.primitive.map_primitive == False - - out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, - {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) - elif eqn.primitive.name == "scan": - out = eqn.primitive.bind(*vals, **eqn.params) - else: - out = eqn.primitive.bind(*vals, **eqn.params) - return out + if eqn.primitive.name == "closed_call": + assert 
eqn.primitive.call_primitive == True + assert eqn.primitive.map_primitive == False + + out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, + {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) + elif eqn.primitive.name == "scan": + out = eqn.primitive.bind(*vals, **eqn.params) + else: + out = eqn.primitive.bind(*vals, **eqn.params) + return out @register_prim_as('dot_general') def _astify_dot_general(state, eqn): - x, y = eqn.invars - d = eqn.params['dimension_numbers'] - precision = eqn.params['precision'] - preferred_element_type = eqn.params['preferred_element_type'] - - has_dtype = preferred_element_type is None or x.aval.dtype == y.aval.dtype == preferred_element_type - - # recognize simple matmul case - if d == (((1,), (0,)), ((), ())) and precision == None: - invars = [_astify_atom(state, x), _astify_atom(state, y)] + x, y = eqn.invars + d = eqn.params['dimension_numbers'] + precision = eqn.params['precision'] + preferred_element_type = eqn.params['preferred_element_type'] + + has_dtype = preferred_element_type is None or x.aval.dtype == y.aval.dtype == preferred_element_type + + # recognize simple matmul case + if d == (((1,), (0,)), ((), ())) and precision == None: + invars = [_astify_atom(state, x), _astify_atom(state, y)] + outvars = _astify_outvars(state, eqn.outvars) + out = ast.Assign(targets=outvars, value=ast.Call( + func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='matmul', ctx=ast.Load()), + args=invars, + keywords=[])) + if not has_dtype: + out = ast.Assign(targets=outvars, + value=ast.Call(func=ast.Attribute(value=out.value, attr='astype', ctx=ast.Load()), + args=[_astify_value(preferred_element_type)], keywords=[])) + + return out + + # TODO: convert to einsum? 
+ + invars = [_astify_atom(state, x), + _astify_atom(state, y), + _astify_value(d), + _astify_value(precision), + _astify_value(preferred_element_type)] outvars = _astify_outvars(state, eqn.outvars) - out = ast.Assign(targets=outvars, value=ast.Call( - func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='matmul', ctx=ast.Load()), args=invars, - keywords=[])) - if not has_dtype: - out = ast.Assign(targets=outvars, - value=ast.Call(func=ast.Attribute(value=out.value, attr='astype', ctx=ast.Load()), - args=[_astify_value(preferred_element_type)], keywords=[])) - - return out - - # TODO: convert to einsum? - - invars = [_astify_atom(state, x), - _astify_atom(state, y), - _astify_value(d), - _astify_value(precision), - _astify_value(preferred_element_type)] - outvars = _astify_outvars(state, eqn.outvars) - return ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()), attr='dot_general', ctx=ast.Load()), - args=invars, - keywords=[] + return ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()), attr='dot_general', ctx=ast.Load()), + args=invars, + keywords=[] + ) ) - ) @register_prim_as('dynamic_slice') def _sourcify_dynamic_slice(state, eqn): - sliced = eqn.invars[0] - invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) - outvars = _astify_outvars(state, eqn.outvars) - params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] - return ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='jax.lax', ctx=ast.Load()), - attr='dynamic_slice', - ctx=ast.Load() - ), - args=[_astify_atom(state, sliced), invars], - keywords=params + sliced = eqn.invars[0] + invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) + outvars = _astify_outvars(state, eqn.outvars) + params = [ast.keyword(arg=k, 
value=_astify_value(v)) for k, v in eqn.params.items()] + return ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='jax.lax', ctx=ast.Load()), + attr='dynamic_slice', + ctx=ast.Load() + ), + args=[_astify_atom(state, sliced), invars], + keywords=params + ) ) - ) @register_prim_as('slice') def _sourcify_slice(state, eqn): - sliced = eqn.invars[0] - # invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) - outvars = _astify_outvars(state, eqn.outvars) - start_indices = eqn.params['start_indices'] - limit_indices = eqn.params['limit_indices'] - strides = eqn.params['strides'] - if strides is None: - strides = (None,) * len(start_indices) - indices = [_astify_value(slice(s, e, stride)) - for s, e, stride in zip(start_indices, limit_indices, strides)] - # params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] - return ast.Assign( - targets=outvars, - value=ast.Subscript( - value=_astify_atom(state, sliced), - slice=ast.Tuple(elts=indices, ctx=ast.Load()), - ctx=ast.Load() + sliced = eqn.invars[0] + # invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) + outvars = _astify_outvars(state, eqn.outvars) + start_indices = eqn.params['start_indices'] + limit_indices = eqn.params['limit_indices'] + strides = eqn.params['strides'] + if strides is None: + strides = (None,) * len(start_indices) + indices = [_astify_value(slice(s, e, stride)) + for s, e, stride in zip(start_indices, limit_indices, strides)] + # params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] + return ast.Assign( + targets=outvars, + value=ast.Subscript( + value=_astify_atom(state, sliced), + slice=ast.Tuple(elts=indices, ctx=ast.Load()), + ctx=ast.Load() + ) ) - ) @register_prim_as('dynamic_update_slice') def _sourcify_dynamic_update_slice(state, eqn): - sliced = eqn.invars[0] - # the first two arguments are the sliced array and 
the update array - # the remaining are start indices and should be packaged into a tuple - target = _astify_atom(state, eqn.invars[0]) - update = _astify_atom(state, eqn.invars[1]) - start_indices = maybe_tuple_vars([_astify_atom(state, var) for var in eqn.invars[2:]]) - outvars = _astify_outvars(state, eqn.outvars) - - return ast.Assign(targets=outvars, value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='jax.lax', ctx=ast.Load()), - attr='dynamic_update_slice', - ctx=ast.Load() - ), - args=[target, update, start_indices], - keywords=[] - )) + sliced = eqn.invars[0] + # the first two arguments are the sliced array and the update array + # the remaining are start indices and should be packaged into a tuple + target = _astify_atom(state, eqn.invars[0]) + update = _astify_atom(state, eqn.invars[1]) + start_indices = maybe_tuple_vars([_astify_atom(state, var) for var in eqn.invars[2:]]) + outvars = _astify_outvars(state, eqn.outvars) + + return ast.Assign(targets=outvars, value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='jax.lax', ctx=ast.Load()), + attr='dynamic_update_slice', + ctx=ast.Load() + ), + args=[target, update, start_indices], + keywords=[] + )) @register_prim_as('convert_element_type') def _astify_convert_element_type(state, eqn): - # now we use ast - outvars = _astify_outvars(state, eqn.outvars) - assert len(eqn.invars) == 1 - invar = _astify_atom(state, eqn.invars[0]) - dtype = _astify_value(eqn.params['new_dtype']) - return ast.Assign(targets=outvars, value=ast.Call( - func=ast.Attribute( - value=invar, - attr='astype', - ctx=ast.Load() - ), - args=[dtype], - keywords=[] - )) + # now we use ast + outvars = _astify_outvars(state, eqn.outvars) + assert len(eqn.invars) == 1 + invar = _astify_atom(state, eqn.invars[0]) + dtype = _astify_value(eqn.params['new_dtype']) + return ast.Assign(targets=outvars, value=ast.Call( + func=ast.Attribute( + value=invar, + attr='astype', + ctx=ast.Load() + ), + args=[dtype], + keywords=[] + )) def 
is_array(arr): - return isinstance(arr, (np.ndarray, np.generic, jnp.ndarray)) + return isinstance(arr, (np.ndarray, np.generic, jnp.ndarray)) def _astify_array(value): - assert is_array(value) - if isinstance(value, np.int64): - return ast.Constant(value=int(value)) + assert is_array(value) + if isinstance(value, np.int64): + return ast.Constant(value=int(value)) + + if value.ndim == 0 and value.dtype in (jnp.float32, jnp.int32, jnp.bool_, jnp.int64): + return ast.Constant(value=value.item()) + + if value.ndim == 0: + dtype_value = _astify_value(value.dtype) + return ast.Call( + dtype_value, + args=[ast.Constant(value=value.item())], + keywords=[], + ) - if value.ndim == 0 and value.dtype in (jnp.float32, jnp.int32, jnp.bool_, jnp.int64): - return ast.Constant(value=value.item()) + values = value.tolist() + + def rec_astify_list(values): + if isinstance(values, list): + return ast.List(elts=[rec_astify_list(val) for val in values], ctx=ast.Load()) + else: + return ast.Constant(value=values) - if value.ndim == 0: - dtype_value = _astify_value(value.dtype) return ast.Call( - dtype_value, - args=[ast.Constant(value=value.item())], - keywords=[], + func=ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='array', + ctx=ast.Load() + ), + args=[rec_astify_list(values)], + keywords=[ast.keyword(arg='dtype', + value=_astify_value(value.dtype))] ) - values = value.tolist() - - def rec_astify_list(values): - if isinstance(values, list): - return ast.List(elts=[rec_astify_list(val) for val in values], ctx=ast.Load()) - else: - return ast.Constant(value=values) - - return ast.Call( - func=ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='array', - ctx=ast.Load() - ), - args=[rec_astify_list(values)], - keywords=[ast.keyword(arg='dtype', - value=_astify_value(value.dtype))] - ) - def _astify_atom(state: SourcerorState, var: Union[Literal, Var]): - if isinstance(var, Literal): - return _astify_value(var.val) - elif isinstance(var, 
Var): - return state.name(var) - else: - raise NotImplementedError() + if isinstance(var, Literal): + return _astify_value(var.val) + elif isinstance(var, Var): + return state.name(var) + else: + raise NotImplementedError() def _astify_value(value): - assert not isinstance(value, (Literal, Var)) - - if is_array(value): - return _astify_array(value) - elif isinstance(value, (int, bool, float, str, type(None))): - return ast.Constant(value=value) - elif isinstance(value, (tuple, list)): - return ast.Tuple(elts=[_astify_value(v) for v in value], ctx=ast.Load()) - elif isinstance(value, jnp.dtype): - # return ast.Call(func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='dtype', ctx=ast.Load()), args=[ast.Constant(value=str(value))], keywords=[]) - if value.name in ('float32', 'float64', 'int32', 'int64', 'bfloat16', 'float16'): - # return ast.Constant(value=getattr(jnp, value.name)) - return ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr=value.name, - ctx=ast.Load() - ) - elif value.name == 'bool': - return ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='bool_', - ctx=ast.Load() - ) - else: - return ast.Call( - func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='dtype', - ctx=ast.Load()), - args=[ast.Constant(value=str(value))], - keywords=[] - ) - elif value is UNSPECIFIED: - prefix_imports.add('from jax._src.sharding_impls import UNSPECIFIED') - return ast.Name(id='UNSPECIFIED', ctx=ast.Load()) - elif isinstance(value, enum.Enum): - return ast.Attribute( - value=ast.Name(id=value.__class__.__qualname__, ctx=ast.Load()), - attr=value.name, - ctx=ast.Load() - ) + assert not isinstance(value, (Literal, Var)) + + if is_array(value): + return _astify_array(value) + elif isinstance(value, (int, bool, float, str, type(None))): + return ast.Constant(value=value) + elif isinstance(value, (tuple, list)): + return ast.Tuple(elts=[_astify_value(v) for v in value], ctx=ast.Load()) + elif 
isinstance(value, jnp.dtype): + # return ast.Call(func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='dtype', ctx=ast.Load()), args=[ast.Constant(value=str(value))], keywords=[]) + if value.name in ('float32', 'float64', 'int32', 'int64', 'bfloat16', 'float16'): + # return ast.Constant(value=getattr(jnp, value.name)) + return ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr=value.name, + ctx=ast.Load() + ) + elif value.name == 'bool': + return ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='bool_', + ctx=ast.Load() + ) + else: + return ast.Call( + func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='dtype', + ctx=ast.Load()), + args=[ast.Constant(value=str(value))], + keywords=[] + ) + elif value is UNSPECIFIED: + prefix_imports.add('from jax._src.sharding_impls import UNSPECIFIED') + return ast.Name(id='UNSPECIFIED', ctx=ast.Load()) + elif isinstance(value, enum.Enum): + return ast.Attribute( + value=ast.Name(id=value.__class__.__qualname__, ctx=ast.Load()), + attr=value.name, + ctx=ast.Load() + ) - else: - warnings.warn(f"Unknown value type {type(value)}") - return ast.parse(repr(value)).body[0] + else: + warnings.warn(f"Unknown value type {type(value)}") + return ast.parse(repr(value)).body[0] def _astify_outvars(state, outvars): - out = [state.name(v, ctx=ast.Store()) for v in outvars] - if len(out) == 1: - return out - else: - return [ast.Tuple(elts=out, ctx=ast.Store())] + out = [state.name(v, ctx=ast.Store()) for v in outvars] + if len(out) == 1: + return out + else: + return [ast.Tuple(elts=out, ctx=ast.Store())] def maybe_tuple_vars(vars): - if len(vars) == 1: - return vars[0] - else: - return ast.Tuple(elts=vars, ctx=ast.Load()) + if len(vars) == 1: + return vars[0] + else: + return ast.Tuple(elts=vars, ctx=ast.Load()) def maybe_untuple_vars(var, is_tuple): - if is_tuple: - return ast.Starred(value=var, ctx=ast.Load()) - else: - return var + if is_tuple: + return 
ast.Starred(value=var, ctx=ast.Load()) + else: + return var @register_prim_as('scan') def _astify_scan(state, eqn): - assert eqn.primitive.name == 'scan' + assert eqn.primitive.name == 'scan' - # the args to scan are [constants, carry, xs] - # constants aren't exposed in the Python API, so we need to handle them specially (we use a lambda) - num_consts = eqn.params['num_consts'] - num_carry = eqn.params['num_carry'] + # the args to scan are [constants, carry, xs] + # constants aren't exposed in the Python API, so we need to handle them specially (we use a lambda) + num_consts = eqn.params['num_consts'] + num_carry = eqn.params['num_carry'] - # TODO: bring back map - # if num_carry == 0: - # this is a map - # return _astify_map(eqn) + # TODO: bring back map + # if num_carry == 0: + # this is a map + # return _astify_map(eqn) - constant_args = eqn.invars[:num_consts] - carries = eqn.invars[num_consts:num_consts + num_carry] - xs = eqn.invars[num_consts + num_carry:] + constant_args = eqn.invars[:num_consts] + carries = eqn.invars[num_consts:num_consts + num_carry] + xs = eqn.invars[num_consts + num_carry:] - jaxpr = eqn.params['jaxpr'] + jaxpr = eqn.params['jaxpr'] - if num_consts != 0: - # we want to construct an environment where we partial eval the function using the constants as the env - env = dict(zip(jaxpr.jaxpr.invars, constant_args)) - jaxpr = partial_eval_jaxpr(jaxpr.jaxpr, env) - else: - jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) - - fn_name = state.skolem('fn') - fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) - - length = _astify_value(eqn.params['length']) - unroll = _astify_value(eqn.params['unroll']) - reverse = _astify_value(eqn.params['reverse']) + if num_consts != 0: + # we want to construct an environment where we partial eval the function using the constants as the env + env = dict(zip(jaxpr.jaxpr.invars, constant_args)) + jaxpr = partial_eval_jaxpr(jaxpr.jaxpr, env) + else: + jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) + + fn_name = 
state.skolem('fn') + fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) + + length = _astify_value(eqn.params['length']) + unroll = _astify_value(eqn.params['unroll']) + reverse = _astify_value(eqn.params['reverse']) + + stmts = [] + + if num_carry != 1 or len(jaxpr.invars) != 2: + # what we want is something like: + # fn_name = lambda carry, xs: fn_name(*carry, *xs) + # jax.lax.scan(fn_name, (carries...), (xs...)) + + modified_signature = ast.arguments( + args=[ast.arg(arg='carry'), ast.arg(arg='x')], + vararg=None, + kwonlyargs=[], + kw_defaults=[], + kwarg=None, + defaults=[], + posonlyargs=[] + ) - stmts = [] + initial_assign = ast.Assign( + targets=[ast.Tuple(elts=[ast.Name(a.arg) for a in fn_ast.args.args], + ctx=ast.Store())], + value=ast.Tuple( + elts=[maybe_untuple_vars(ast.Name(id='carry', ctx=ast.Load()), num_carry != 1), + maybe_untuple_vars(ast.Name(id='x', ctx=ast.Load()), len(xs) != 1)] + ) + ) - if num_carry != 1 or len(jaxpr.invars) != 2: - # what we want is something like: - # fn_name = lambda carry, xs: fn_name(*carry, *xs) - # jax.lax.scan(fn_name, (carries...), (xs...)) + fn_return = fn_ast.body[-1] + assert isinstance(fn_return, ast.Return) - modified_signature = ast.arguments( - args=[ast.arg(arg='carry'), ast.arg(arg='x')], - vararg=None, - kwonlyargs=[], - kw_defaults=[], - kwarg=None, - defaults=[], - posonlyargs=[] - ) + fn_return_value = fn_return.value - initial_assign = ast.Assign( - targets=[ast.Tuple(elts=[ast.Name(a.arg) for a in fn_ast.args.args], - ctx=ast.Store())], - value=ast.Tuple( - elts=[maybe_untuple_vars(ast.Name(id='carry', ctx=ast.Load()), num_carry != 1), - maybe_untuple_vars(ast.Name(id='x', ctx=ast.Load()), len(xs) != 1)] - ) - ) + if isinstance(fn_return_value, ast.Tuple): + fn_return_value = fn_return_value.elts + ret_carries = maybe_tuple_vars(fn_return_value[:num_carry]) + ret_ys = maybe_tuple_vars(fn_return_value[num_carry:]) + elif num_carry == 0: + ret_carries = _astify_value(()) + ret_ys = fn_return_value + else: 
+ ret_carries = fn_return_value + ret_ys = _astify_value(()) - fn_return = fn_ast.body[-1] - assert isinstance(fn_return, ast.Return) + scan_return = ast.Return( + value=ast.Tuple(elts=[ret_carries, ret_ys], ctx=ast.Load()) + ) - fn_return_value = fn_return.value + new_body = [initial_assign] + list(fn_ast.body[:-1]) + [scan_return] - if isinstance(fn_return_value, ast.Tuple): - fn_return_value = fn_return_value.elts - ret_carries = maybe_tuple_vars(fn_return_value[:num_carry]) - ret_ys = maybe_tuple_vars(fn_return_value[num_carry:]) - elif num_carry == 0: - ret_carries = _astify_value(()) - ret_ys = fn_return_value - else: - ret_carries = fn_return_value - ret_ys = _astify_value(()) + fn_ast = ast.FunctionDef( + name=fn_name, + args=modified_signature, + body=new_body, + decorator_list=[] + ) - scan_return = ast.Return( - value=ast.Tuple(elts=[ret_carries, ret_ys], ctx=ast.Load()) - ) + stmts.append(fn_ast) + + scan_call = ast.Assign( + # targets=_astify_outvars(eqn.outvars), + targets=[ + ast.Tuple( + elts=[ast.Name(id='final_carry', ctx=ast.Store()), + ast.Name(id='ys', ctx=ast.Store())], + ctx=ast.Store() + ) + ], + value=ast.Call( + func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), + args=[ast.Name(id=fn_name, ctx=ast.Load()), + maybe_tuple_vars([_astify_atom(state, v) for v in carries]), + maybe_tuple_vars([_astify_atom(state, v) for v in xs])], + keywords=[ast.keyword(arg='length', value=length), + ast.keyword(arg='unroll', value=unroll), + ast.keyword(arg='reverse', value=reverse)] + ) + ) + stmts.append(scan_call) - new_body = [initial_assign] + list(fn_ast.body[:-1]) + [scan_return] + if num_carry > 0: + assign_carry = ast.Assign( + targets=_astify_outvars(state, eqn.outvars[:num_carry]), + value=ast.Name(id='final_carry', ctx=ast.Load()) + ) - fn_ast = ast.FunctionDef( - name=fn_name, - args=modified_signature, - body=new_body, - decorator_list=[] - ) + stmts.append(assign_carry) - stmts.append(fn_ast) + if num_carry < len(eqn.outvars): + assign_ys = 
ast.Assign( + targets=_astify_outvars(state, eqn.outvars[num_carry:]), + value=ast.Name(id='ys', ctx=ast.Load()) + ) - scan_call = ast.Assign( - # targets=_astify_outvars(eqn.outvars), - targets=[ - ast.Tuple( - elts=[ast.Name(id='final_carry', ctx=ast.Store()), - ast.Name(id='ys', ctx=ast.Store())], - ctx=ast.Store() + stmts.append(assign_ys) + else: + stmts.append(fn_ast) + + scan_call = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), + args=[ast.Name(id=fn_name, ctx=ast.Load())] + [_astify_atom(state, v) for v in eqn.invars], + keywords=[ast.keyword(arg='length', value=length), + ast.keyword(arg='unroll', value=unroll), + ast.keyword(arg='reverse', value=reverse)] + ) ) - ], - value=ast.Call( - func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), - args=[ast.Name(id=fn_name, ctx=ast.Load()), - maybe_tuple_vars([_astify_atom(state, v) for v in carries]), - maybe_tuple_vars([_astify_atom(state, v) for v in xs])], - keywords=[ast.keyword(arg='length', value=length), - ast.keyword(arg='unroll', value=unroll), - ast.keyword(arg='reverse', value=reverse)] - ) - ) - stmts.append(scan_call) - - if num_carry > 0: - assign_carry = ast.Assign( - targets=_astify_outvars(state, eqn.outvars[:num_carry]), - value=ast.Name(id='final_carry', ctx=ast.Load()) - ) - - stmts.append(assign_carry) - - if num_carry < len(eqn.outvars): - assign_ys = ast.Assign( - targets=_astify_outvars(state, eqn.outvars[num_carry:]), - value=ast.Name(id='ys', ctx=ast.Load()) - ) - - stmts.append(assign_ys) - else: - stmts.append(fn_ast) - - scan_call = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), - args=[ast.Name(id=fn_name, ctx=ast.Load())] + [_astify_atom(state, v) for v in eqn.invars], - keywords=[ast.keyword(arg='length', value=length), - ast.keyword(arg='unroll', value=unroll), - ast.keyword(arg='reverse', value=reverse)] - ) - ) - 
stmts.append(scan_call) + stmts.append(scan_call) - return stmts + return stmts def _astify_map(state, eqn): - assert eqn.primitive.name == 'scan' - assert eqn.params['num_carry'] == 0 - - jaxpr = eqn.params['jaxpr'] - jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) - - fn_name = state.skolem('fn') - fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) - - # map is a bit funny, because the jaxpr takes K args, but the jax.lax.map function takes a single tuple arg - # so we need to use a lambda to redirect the call - lam = ast.parse(f"lambda args: {fn_name}(*args)").body[0] - - assign = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=ast.Name(id='jax.lax.map', ctx=ast.Load()), - args=[lam, - ast.Tuple(elts=[_astify_atom(state, v) for v in eqn.invars], - ctx=ast.Load())], - keywords=[] + assert eqn.primitive.name == 'scan' + assert eqn.params['num_carry'] == 0 + + jaxpr = eqn.params['jaxpr'] + jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) + + fn_name = state.skolem('fn') + fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) + + # map is a bit funny, because the jaxpr takes K args, but the jax.lax.map function takes a single tuple arg + # so we need to use a lambda to redirect the call + lam = ast.parse(f"lambda args: {fn_name}(*args)").body[0] + + assign = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=ast.Name(id='jax.lax.map', ctx=ast.Load()), + args=[lam, + ast.Tuple(elts=[_astify_atom(state, v) for v in eqn.invars], + ctx=ast.Load())], + keywords=[] + ) ) - ) - return [fn_ast, assign] + return [fn_ast, assign] @register_prim_as('closed_call') def _astify_closed_call(state, eqn): - # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, - # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) - raw_jaxpr = eqn.params['call_jaxpr'].jaxpr - literal_args = {k: v.val - for k, v in zip(raw_jaxpr.invars, eqn.invars) - if isinstance(v, Literal)} - call_japr = partial_eval_jaxpr(raw_jaxpr, 
literal_args) - fn_name = state.skolem('fn') - - fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) - - invars = [_astify_atom(state, v) - for v in eqn.invars - if not isinstance(v, Literal)] - outvars = _astify_outvars(state, eqn.outvars) - - assign = ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=invars, - keywords=[] + # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, + # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) + raw_jaxpr = eqn.params['call_jaxpr'].jaxpr + literal_args = {k: v.val + for k, v in zip(raw_jaxpr.invars, eqn.invars) + if isinstance(v, Literal)} + call_japr = partial_eval_jaxpr(raw_jaxpr, literal_args) + fn_name = state.skolem('fn') + + fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) + + invars = [_astify_atom(state, v) + for v in eqn.invars + if not isinstance(v, Literal)] + outvars = _astify_outvars(state, eqn.outvars) + + assign = ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Name(id=fn_name, ctx=ast.Load()), + args=invars, + keywords=[] + ) ) - ) - return [fn_ast, assign] + return [fn_ast, assign] @register_prim_as('pjit') def _astify_pjit(state, eqn): - # this one's a real pain. 
- # pjit's params are : - # jaxpr - # donated_invars: - # in_shardings, out_shardings - # resource env - # name (yay) - # keep_unused, inline (which we won't use) - - jaxpr = eqn.params['jaxpr'] - donated_invars = eqn.params['donated_invars'] - in_shardings = eqn.params['in_shardings'] - out_shardings = eqn.params['out_shardings'] - resource_env = eqn.params['resource_env'] - name = eqn.params['name'] - - can_ignore_donated = not any(donated_invars) - - # preprocess the function - jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) - fn_name = state.skolem(name) - fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) - - in_shardings = _astify_value(in_shardings) - out_shardings = _astify_value(out_shardings) - - keywords = [ - ast.keyword(arg='in_shardings', value=in_shardings), - ast.keyword(arg='out_shardings', value=out_shardings), - ] - - if not can_ignore_donated: - donated_invars = _astify_value(donated_invars) - keywords.append(ast.keyword(arg='donated_invars', value=donated_invars)) - - jitted_fn = ast.Call( - func=ast.Attribute( - ast.Name(id='jax', ctx=ast.Load()), - attr='jit' - ), - args=[ast.Name(id=fn_name, ctx=ast.Load())], - keywords=keywords - ) - - assign = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=jitted_fn, - args=[_astify_atom(state, v) for v in eqn.invars], - keywords=[] + # this one's a real pain. 
+ # pjit's params are : + # jaxpr + # donated_invars: + # in_shardings, out_shardings + # resource env + # name (yay) + # keep_unused, inline (which we won't use) + + jaxpr = eqn.params['jaxpr'] + donated_invars = eqn.params['donated_invars'] + in_shardings = eqn.params['in_shardings'] + out_shardings = eqn.params['out_shardings'] + resource_env = eqn.params['resource_env'] + name = eqn.params['name'] + + can_ignore_donated = not any(donated_invars) + + # preprocess the function + jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) + fn_name = state.skolem(name) + fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) + + in_shardings = _astify_value(in_shardings) + out_shardings = _astify_value(out_shardings) + + keywords = [ + ast.keyword(arg='in_shardings', value=in_shardings), + ast.keyword(arg='out_shardings', value=out_shardings), + ] + + if not can_ignore_donated: + donated_invars = _astify_value(donated_invars) + keywords.append(ast.keyword(arg='donated_invars', value=donated_invars)) + + jitted_fn = ast.Call( + func=ast.Attribute( + ast.Name(id='jax', ctx=ast.Load()), + attr='jit' + ), + args=[ast.Name(id=fn_name, ctx=ast.Load())], + keywords=keywords ) - ) - return [fn_ast, assign] + assign = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=jitted_fn, + args=[_astify_atom(state, v) for v in eqn.invars], + keywords=[] + ) + ) + + return [fn_ast, assign] @register_prim_as('remat2') def _astify_remat(state: SourcerorState, eqn): - # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, - # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) - call_japr = constant_fold_jaxpr(eqn.params['jaxpr']) - fn_name = state.skolem('fn') - - fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) - - invars = [_astify_atom(state, v) for v in eqn.invars] - outvars = _astify_outvars(state, eqn.outvars) - - lam = ast.Assign( - targets=[ast.Name(id=f"ckpt_{fn_name}", ctx=ast.Store())], - # 
value=ast.parse(f"jax.checkpoint({fn_name})").body[0] - value=ast.Call( - func=ast.Name(id='jax.checkpoint', ctx=ast.Load()), - args=[ast.Name(id=fn_name, ctx=ast.Load())], - keywords=[]) - ) - - assign = ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Name(id=f"ckpt_{fn_name}"), - args=invars, - keywords=[] - )) + # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, + # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) + call_japr = constant_fold_jaxpr(eqn.params['jaxpr']) + fn_name = state.skolem('fn') + + fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) + + invars = [_astify_atom(state, v) for v in eqn.invars] + outvars = _astify_outvars(state, eqn.outvars) + + lam = ast.Assign( + targets=[ast.Name(id=f"ckpt_{fn_name}", ctx=ast.Store())], + # value=ast.parse(f"jax.checkpoint({fn_name})").body[0] + value=ast.Call( + func=ast.Name(id='jax.checkpoint', ctx=ast.Load()), + args=[ast.Name(id=fn_name, ctx=ast.Load())], + keywords=[]) + ) + + assign = ast.Assign( + targets=outvars, + value=ast.Call( + func=ast.Name(id=f"ckpt_{fn_name}"), + args=invars, + keywords=[] + )) - return [fn_ast, lam, assign] + return [fn_ast, lam, assign] @register_prim_as('reshape') def _astify_reshape(state, eqn): - # the lax reshape is a bit different, because it can combine a transpose and reshape into one. - # np.reshape(np.transpose(operand, dimensions), new_sizes) - dimensions = eqn.params['dimensions'] - new_sizes = eqn.params['new_sizes'] - - source = _astify_atom(state, eqn.invars[0]) - - if dimensions is not None: - source = ast.Call( - func=ast.Name(id='jax.numpy.transpose', ctx=ast.Load()), - args=[source, _astify_value(dimensions)], - keywords=[] - ) + # the lax reshape is a bit different, because it can combine a transpose and reshape into one. 
+ # np.reshape(np.transpose(operand, dimensions), new_sizes) + dimensions = eqn.params['dimensions'] + new_sizes = eqn.params['new_sizes'] + + source = _astify_atom(state, eqn.invars[0]) + + if dimensions is not None: + source = ast.Call( + func=ast.Name(id='jax.numpy.transpose', ctx=ast.Load()), + args=[source, _astify_value(dimensions)], + keywords=[] + ) - assign = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=ast.Name(id='jax.numpy.reshape', ctx=ast.Load()), - args=[source, _astify_value(new_sizes)], - keywords=[] - )) + assign = ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=ast.Call( + func=ast.Name(id='jax.numpy.reshape', ctx=ast.Load()), + args=[source, _astify_value(new_sizes)], + keywords=[] + )) - return [assign] + return [assign] @register_prim_as('add_any') def _astify_add_any(state, eqn): - # add_any is a weird undocumented jax primitive. best guess is it adds? - return _binop_fn(ast.Add())(state, eqn) + # add_any is a weird undocumented jax primitive. best guess is it adds? 
+ return _binop_fn(ast.Add())(state, eqn) @register_prim_as('broadcast_in_dim') def _astify_broadcast_in_dim(state, eqn): - # broadcast_in_dim is how zeros, ones, full, etc are implemented, - # so we prefer to use those where possible - assert len(eqn.invars) == 1 - value = eqn.invars[0] - shape = eqn.params['shape'] - broadcast_dimensions = eqn.params['broadcast_dimensions'] - - if not isinstance(value, Literal) or broadcast_dimensions != (): - return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) - - if not isinstance(value.val, np.ndarray) or value.val.ndim != 0: - return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) - else: - constant_value = value.val.item() - if constant_value == 0: - call = ast.Call( - ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='zeros', - ctx=ast.Load() - ), - args=[_astify_value(shape), - _astify_value(value.val.dtype)], - keywords=[] - ) - elif constant_value == 1: - call = ast.Call( - ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='ones', - ctx=ast.Load() - ), - args=[_astify_value(shape), - _astify_value(value.val.dtype)], - keywords=[] - ) + # broadcast_in_dim is how zeros, ones, full, etc are implemented, + # so we prefer to use those where possible + assert len(eqn.invars) == 1 + value = eqn.invars[0] + shape = eqn.params['shape'] + broadcast_dimensions = eqn.params['broadcast_dimensions'] + + if not isinstance(value, Literal) or broadcast_dimensions != (): + return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) + + if not isinstance(value.val, np.ndarray) or value.val.ndim != 0: + return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) else: - call = ast.Call( - ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='full', - ctx=ast.Load() - ), - args=[_astify_value(shape), - _astify_value(constant_value), - _astify_value(value.val.dtype)], - keywords=[] - ) + constant_value = value.val.item() + if constant_value == 0: + call = ast.Call( + 
ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='zeros', + ctx=ast.Load() + ), + args=[_astify_value(shape), + _astify_value(value.val.dtype)], + keywords=[] + ) + elif constant_value == 1: + call = ast.Call( + ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='ones', + ctx=ast.Load() + ), + args=[_astify_value(shape), + _astify_value(value.val.dtype)], + keywords=[] + ) + else: + call = ast.Call( + ast.Attribute( + value=ast.Name(id='jax.numpy', ctx=ast.Load()), + attr='full', + ctx=ast.Load() + ), + args=[_astify_value(shape), + _astify_value(constant_value), + _astify_value(value.val.dtype)], + keywords=[] + ) - return [ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=call - )] + return [ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=call + )] @register_prim_as('random_wrap') def _astify_random_wrap(state, eqn): - # we treat this as a noop - return ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=_astify_atom(state, eqn.invars[0]) - ) + # we treat this as a noop + return ast.Assign( + targets=_astify_outvars(state, eqn.outvars), + value=_astify_atom(state, eqn.invars[0]) + ) constant_fold_blacklist = { - 'broadcast_in_dim', - 'broadcast', + 'broadcast_in_dim', + 'broadcast', } diff --git a/brainpy/_src/integrators/base.py b/brainpy/_src/integrators/base.py index 7853123bc..b5145dc73 100644 --- a/brainpy/_src/integrators/base.py +++ b/brainpy/_src/integrators/base.py @@ -1,178 +1,177 @@ # -*- coding: utf-8 -*- +from contextlib import contextmanager from typing import Dict, Sequence, Union, Callable import jax -from brainpy._src.math.object_transform.base import BrainPyObject from brainpy._src.math import TimeDelay, LengthDelay +from brainpy._src.math.object_transform.base import BrainPyObject from brainpy.check import is_float, is_dict_data from brainpy.errors import DiffEqError -from .constants import DT - from ._jaxpr_to_source_code import jaxpr_to_python_code 
-from contextlib import contextmanager +from .constants import DT __all__ = [ - 'Integrator', + 'Integrator', ] class AbstractIntegrator(BrainPyObject): - """Basic Integrator Class.""" + """Basic Integrator Class.""" - # func_name - # derivative - # code_scope - # + # func_name + # derivative + # code_scope + # - def __call__(self, *args, **kwargs): - raise NotImplementedError + def __call__(self, *args, **kwargs): + raise NotImplementedError class Integrator(AbstractIntegrator): - """Basic Integrator Class.""" - - def __init__( - self, - variables: Sequence[str], - parameters: Sequence[str], - arguments: Sequence[str], - dt: float, - name: str = None, - state_delays: Dict[str, Union[TimeDelay, LengthDelay]] = None, - ): - super(Integrator, self).__init__(name=name) - - self._dt = dt - is_float(dt, 'dt', allow_none=False, allow_int=True) - self._variables = list(variables) # variables - self._parameters = list(parameters) # parameters - self._arguments = list(arguments) + [f'{DT}={self._dt}', ] # arguments - self._integral = None # integral function - self.arg_names = self._variables + self._parameters + [DT] - - # state delays - self._state_delays = dict() - if state_delays is not None: - is_dict_data(state_delays, key_type=str, val_type=(TimeDelay, LengthDelay)) - for key, delay in state_delays.items(): - if key not in self.variables: - raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') - self._state_delays[key] = delay - self.register_implicit_nodes(self._state_delays) - - # math expression - self._math_expr = None - - @property - def dt(self): - """The numerical integration precision.""" - return self._dt - - @dt.setter - def dt(self, value): - raise ValueError('Cannot set "dt" by users.') - - @property - def variables(self): - """The variables defined in the differential equation.""" - return self._variables - - @variables.setter - def variables(self, values): - raise ValueError('Cannot set "variables" by users.') - - @property - 
def parameters(self): - """The parameters defined in the differential equation.""" - return self._parameters - - @parameters.setter - def parameters(self, values): - raise ValueError('Cannot set "parameters" by users.') - - @property - def arguments(self): - """All arguments when calling the numer integrator of the differential equation.""" - return self._arguments - - @arguments.setter - def arguments(self, values): - raise ValueError('Cannot set "arguments" by users.') - - @property - def integral(self): - """The integral function.""" - return self._integral - - @integral.setter - def integral(self, f): - self.set_integral(f) - - def set_integral(self, f): - """Set the integral function.""" - if not callable(f): - raise ValueError(f'integral function must be a callable function, ' - f'but we got {type(f)}: {f}') - self._integral = f - - @property - def state_delays(self): - """State delays.""" - return self._state_delays - - @state_delays.setter - def state_delays(self, value): - raise ValueError('Cannot set "state_delays" by users.') - - def _call_integral(self, *args, **kwargs): - if _during_compile: - jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs) - outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs)) - _, tree = jax.tree.flatten(out_shapes) - new_vars = tree.unflatten(outs) - self._math_expr = jaxpr_to_python_code(jaxpr.jaxpr) - - else: - new_vars = self.integral(**kwargs) - return new_vars - - def __call__(self, *args, **kwargs): - assert self.integral is not None, 'Please build the integrator first.' 
- - # check arguments - for i, arg in enumerate(args): - kwargs[self.arg_names[i]] = arg - - # integral - new_vars = self._call_integral(**kwargs) - - # post-process - if len(self.variables) == 1: - dict_vars = {self.variables[0]: new_vars} - else: - dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} - - # update state delay variables - dt = kwargs.pop(DT, self.dt) - for key, delay in self.state_delays.items(): - if isinstance(delay, LengthDelay): - delay.update(dict_vars[key]) - elif isinstance(delay, TimeDelay): - delay.update(dict_vars[key]) - else: - raise ValueError('Unknown delay variable. We only supports ' - 'brainpy.math.LengthDelay, brainpy.math.TimeDelay. ' - f'While we got {delay}') - - return new_vars - - def to_math_expr(self): - if self._math_expr is None: - raise ValueError('Please call ``brainpy.integrators.compile_integrators`` first.') - return self._math_expr + """Basic Integrator Class.""" + + def __init__( + self, + variables: Sequence[str], + parameters: Sequence[str], + arguments: Sequence[str], + dt: float, + name: str = None, + state_delays: Dict[str, Union[TimeDelay, LengthDelay]] = None, + ): + super(Integrator, self).__init__(name=name) + + self._dt = dt + is_float(dt, 'dt', allow_none=False, allow_int=True) + self._variables = list(variables) # variables + self._parameters = list(parameters) # parameters + self._arguments = list(arguments) + [f'{DT}={self._dt}', ] # arguments + self._integral = None # integral function + self.arg_names = self._variables + self._parameters + [DT] + + # state delays + self._state_delays = dict() + if state_delays is not None: + is_dict_data(state_delays, key_type=str, val_type=(TimeDelay, LengthDelay)) + for key, delay in state_delays.items(): + if key not in self.variables: + raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') + self._state_delays[key] = delay + self.register_implicit_nodes(self._state_delays) + + # math expression + self._math_expr = None + + 
@property + def dt(self): + """The numerical integration precision.""" + return self._dt + + @dt.setter + def dt(self, value): + raise ValueError('Cannot set "dt" by users.') + + @property + def variables(self): + """The variables defined in the differential equation.""" + return self._variables + + @variables.setter + def variables(self, values): + raise ValueError('Cannot set "variables" by users.') + + @property + def parameters(self): + """The parameters defined in the differential equation.""" + return self._parameters + + @parameters.setter + def parameters(self, values): + raise ValueError('Cannot set "parameters" by users.') + + @property + def arguments(self): + """All arguments when calling the numer integrator of the differential equation.""" + return self._arguments + + @arguments.setter + def arguments(self, values): + raise ValueError('Cannot set "arguments" by users.') + + @property + def integral(self): + """The integral function.""" + return self._integral + + @integral.setter + def integral(self, f): + self.set_integral(f) + + def set_integral(self, f): + """Set the integral function.""" + if not callable(f): + raise ValueError(f'integral function must be a callable function, ' + f'but we got {type(f)}: {f}') + self._integral = f + + @property + def state_delays(self): + """State delays.""" + return self._state_delays + + @state_delays.setter + def state_delays(self, value): + raise ValueError('Cannot set "state_delays" by users.') + + def _call_integral(self, *args, **kwargs): + if _during_compile: + jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs) + outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs)) + _, tree = jax.tree.flatten(out_shapes) + new_vars = tree.unflatten(outs) + self._math_expr = jaxpr_to_python_code(jaxpr.jaxpr) + + else: + new_vars = self.integral(**kwargs) + return new_vars + + def __call__(self, *args, **kwargs): + assert self.integral is not None, 'Please build the 
integrator first.' + + # check arguments + for i, arg in enumerate(args): + kwargs[self.arg_names[i]] = arg + + # integral + new_vars = self._call_integral(**kwargs) + + # post-process + if len(self.variables) == 1: + dict_vars = {self.variables[0]: new_vars} + else: + dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} + + # update state delay variables + dt = kwargs.pop(DT, self.dt) + for key, delay in self.state_delays.items(): + if isinstance(delay, LengthDelay): + delay.update(dict_vars[key]) + elif isinstance(delay, TimeDelay): + delay.update(dict_vars[key]) + else: + raise ValueError('Unknown delay variable. We only supports ' + 'brainpy.math.LengthDelay, brainpy.math.TimeDelay. ' + f'While we got {delay}') + + return new_vars + + def to_math_expr(self): + if self._math_expr is None: + raise ValueError('Please call ``brainpy.integrators.compile_integrators`` first.') + return self._math_expr _during_compile = False @@ -180,19 +179,17 @@ def to_math_expr(self): @contextmanager def _during_compile_context(): - global _during_compile - try: - _during_compile = True - yield - finally: - _during_compile = False + global _during_compile + try: + _during_compile = True + yield + finally: + _during_compile = False def compile_integrators(f: Callable, *args, **kwargs): - """ - Compile integrators in the given function. - """ - with _during_compile_context(): - return f(*args, **kwargs) - - + """ + Compile integrators in the given function. 
+ """ + with _during_compile_context(): + return f(*args, **kwargs) diff --git a/brainpy/_src/integrators/constants.py b/brainpy/_src/integrators/constants.py index d1daf5d3a..b20c18680 100644 --- a/brainpy/_src/integrators/constants.py +++ b/brainpy/_src/integrators/constants.py @@ -4,41 +4,37 @@ from brainpy._src.math.object_transform import naming __all__ = [ - 'DT', - 'F', - - 'SUPPORTED_VAR_TYPE', - 'SCALAR_VAR', - 'POP_VAR', - 'SYSTEM_VAR', - - 'SUPPORTED_WIENER_TYPE', - 'SCALAR_WIENER', - 'VECTOR_WIENER', - - 'SUPPORTED_INTG_TYPE', - 'ITO_SDE', - 'STRA_SDE', - - 'DE_INT', - 'ODE_INT', - 'SDE_INT', - 'DDE_INT', - 'FDE_INT', - 'PDE_INT', - - 'unique_name', - 'multi_vars_msg', + 'DT', + 'F', + + 'SUPPORTED_VAR_TYPE', + 'SCALAR_VAR', + 'POP_VAR', + 'SYSTEM_VAR', + + 'SUPPORTED_WIENER_TYPE', + 'SCALAR_WIENER', + 'VECTOR_WIENER', + + 'SUPPORTED_INTG_TYPE', + 'ITO_SDE', + 'STRA_SDE', + + 'DE_INT', + 'ODE_INT', + 'SDE_INT', + 'DDE_INT', + 'FDE_INT', + 'PDE_INT', + + 'unique_name', + 'multi_vars_msg', ] - DT = 'dt' F = 'f' G = 'g' - - - # Ito SDE_INT # --- # The SDE_INT integral proposed by Ito in 1940s. 
@@ -50,8 +46,8 @@ STRA_SDE = 'Stratonovich' SUPPORTED_INTG_TYPE = [ - ITO_SDE, - STRA_SDE + ITO_SDE, + STRA_SDE ] # ------------------------------------------------------ @@ -67,8 +63,8 @@ VECTOR_WIENER = 'vector' SUPPORTED_WIENER_TYPE = [ - SCALAR_WIENER, - VECTOR_WIENER + SCALAR_WIENER, + VECTOR_WIENER ] # ------------------------------------------------------ @@ -112,9 +108,9 @@ SYSTEM_VAR = 'system' SUPPORTED_VAR_TYPE = [ - SCALAR_VAR, - POP_VAR, - SYSTEM_VAR, + SCALAR_VAR, + POP_VAR, + SYSTEM_VAR, ] # ------------------------------------------------------ @@ -131,18 +127,18 @@ def unique_name(type): - if type == 'ode': - return naming.get_unique_name(ODE_INT) - elif type == 'sde': - return naming.get_unique_name(SDE_INT) - elif type == 'dde': - return naming.get_unique_name(DDE_INT) - elif type == 'fde': - return naming.get_unique_name(FDE_INT) - elif type == 'pde': - return naming.get_unique_name(PDE_INT) - else: - raise ValueError(f'Unknown differential equation type: {type}') + if type == 'ode': + return naming.get_unique_name(ODE_INT) + elif type == 'sde': + return naming.get_unique_name(SDE_INT) + elif type == 'dde': + return naming.get_unique_name(DDE_INT) + elif type == 'fde': + return naming.get_unique_name(FDE_INT) + elif type == 'pde': + return naming.get_unique_name(PDE_INT) + else: + raise ValueError(f'Unknown differential equation type: {type}') # Errors diff --git a/brainpy/_src/integrators/fde/Caputo.py b/brainpy/_src/integrators/fde/Caputo.py index fdff337ab..1a4decd58 100644 --- a/brainpy/_src/integrators/fde/Caputo.py +++ b/brainpy/_src/integrators/fde/Caputo.py @@ -20,400 +20,400 @@ from .generic import register_fde_integrator __all__ = [ - 'CaputoEuler', - 'CaputoL1Schema', + 'CaputoEuler', + 'CaputoL1Schema', ] class CaputoEuler(FDEIntegrator): - r"""One-step Euler method for Caputo fractional differential equations. - - Given a fractional initial value problem, - - .. 
math:: - - D_{*}^{\alpha} y(t)=f(t, y(t)), \quad y^{(k)}(0)=y_{0}^{(k)}, \quad k=0,1, \ldots,\lceil\alpha\rceil-1 - - where the :math:`y_0^{(k)}` ay be arbitrary real numbers and where :math:`\alpha>0`. - :math:`D_{*}^{\alpha}` denotes the differential operator in the sense of Caputo, defined - by - - .. math:: - - D_{*}^{\alpha} z(t)=J^{n-\alpha} D^{n} z(t) - - where :math:`n:=\lceil\alpha\rceil` is the smallest integer :math:`\geqslant \alpha`, - Here :math:`D^n` is the usual differential operator of (integer) order :math:`n`, - and for :math:`\mu > 0`, :math:`J^{\mu}` is the Riemann–Liouville integral operator - of order :math:`\mu`, defined by - - .. math:: - - J^{\mu} z(t)=\frac{1}{\Gamma(\mu)} \int_{0}^{t}(t-u)^{\mu-1} z(u) \mathrm{d} u - - The one-step Euler method for fractional differential equation is defined as - - .. math:: - - y_{k+1} = y_0 + \frac{1}{\Gamma(\alpha)} \sum_{j=0}^{k} b_{j, k+1} f\left(t_{j}, y_{j}\right). - - where - - .. math:: - - b_{j, k+1}=\frac{h^{\alpha}}{\alpha}\left((k+1-j)^{\alpha}-(k-j)^{\alpha}\right). - - - Examples:: - - >>> import brainpy as bp - >>> - >>> a, b, c = 10, 28, 8 / 3 - >>> def lorenz(x, y, z, t): - >>> dx = a * (y - x) - >>> dy = x * (b - z) - y - >>> dz = x * y - c * z - >>> return dx, dy, dz - >>> - >>> duration = 30. - >>> dt = 0.005 - >>> inits = [1., 0., 1.] - >>> f = bp.fde.CaputoEuler(lorenz, alpha=0.97, num_memory=int(duration / dt), inits=inits) - >>> runner = bp.integrators.IntegratorRunner(f, monitors=list('xyz'), dt=dt, inits=inits) - >>> runner.run(duration) - >>> - >>> import matplotlib.pyplot as plt - >>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) - >>> plt.show() - - - Parameters:: - - f : callable - The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence - The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. - num_memory: int - The total time step of the simulation. 
- inits: sequence - A sequence of the initial values for variables. - dt: float, int - The numerical precision. - name: str - The integrator name. - - References:: - - .. [1] Li, Changpin, and Fanhai Zeng. "The finite difference methods for fractional - ordinary differential equations." Numerical Functional Analysis and - Optimization 34.2 (2013): 149-179. - .. [2] Diethelm, Kai, Neville J. Ford, and Alan D. Freed. "Detailed error analysis - for a fractional Adams method." Numerical algorithms 36.1 (2004): 31-52. - """ - - def __init__( - self, - f: Callable, - alpha: Union[float, Sequence[float], ArrayType], - num_memory: int, - inits: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], - dt: float = None, - name: str = None, - state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, - ): - super(CaputoEuler, self).__init__(f=f, - alpha=alpha, - dt=dt, - name=name, - num_memory=num_memory, - state_delays=state_delays) - - # fractional order - if not bm.all(bm.logical_and(self.alpha < 1, self.alpha > 0)): - raise UnsupportedError(f'Only support the fractional order in (0, 1), ' - f'but we got {self.alpha}.') - - # initial values - self.inits = check_inits(inits, self.variables) - - # coefficients - rgamma_alpha = bm.asarray(rgamma(bm.as_numpy(self.alpha))) - ranges = bm.asarray([bm.arange(num_memory + 1) for _ in self.variables]).T - coef = rgamma_alpha * bm.diff(bm.power(ranges, self.alpha), axis=0) - self.coef = bm.flip(coef, axis=0) - - # variable states - self.f_states = {v: bm.Variable(bm.zeros((num_memory,) + self.inits[v].shape)) - for v in self.variables} - self.register_implicit_vars(self.f_states) - self.idx = bm.Variable(bm.asarray([1])) - - self.set_integral(self._integral_func) - - def _check_step(self, args): - dt, t = args - raise ValueError(f'The maximum number of step is {self.num_memory}, ' - f'however, the current time {t} require a time ' - f'step number {t / dt}.') - - def _integral_func(self, *args, **kwargs): - # format 
arguments - all_args = format_args(args, kwargs, self.arg_names) - t = all_args['t'] - dt = all_args.pop(DT, self.dt) - if check.is_checking(): - check.jit_error(self.num_memory * dt < t, self._check_step, (dt, t)) - - # derivative values - devs = self.f(**all_args) - if len(self.variables) == 1: - if not isinstance(devs, (bm.ndarray, jax.Array)): - raise ValueError('Derivative values must be a tensor when there ' - 'is only one variable in the equation.') - devs = {self.variables[0]: devs} - else: - if not isinstance(devs, (tuple, list)): - raise ValueError('Derivative values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - devs = {var: devs[i] for i, var in enumerate(self.variables)} - - # function states - for key in self.variables: - self.f_states[key][self.idx[0]] = devs[key] - - # integral results - integrals = [] - idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory - for i, key in enumerate(self.variables): - integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key] - integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i])) - self.idx.value = (self.idx + 1) % self.num_memory - - # return integrals - if len(self.variables) == 1: - return integrals[0] - else: - return integrals + r"""One-step Euler method for Caputo fractional differential equations. + + Given a fractional initial value problem, + + .. math:: + + D_{*}^{\alpha} y(t)=f(t, y(t)), \quad y^{(k)}(0)=y_{0}^{(k)}, \quad k=0,1, \ldots,\lceil\alpha\rceil-1 + + where the :math:`y_0^{(k)}` ay be arbitrary real numbers and where :math:`\alpha>0`. + :math:`D_{*}^{\alpha}` denotes the differential operator in the sense of Caputo, defined + by + + .. 
math:: + + D_{*}^{\alpha} z(t)=J^{n-\alpha} D^{n} z(t) + + where :math:`n:=\lceil\alpha\rceil` is the smallest integer :math:`\geqslant \alpha`, + Here :math:`D^n` is the usual differential operator of (integer) order :math:`n`, + and for :math:`\mu > 0`, :math:`J^{\mu}` is the Riemann–Liouville integral operator + of order :math:`\mu`, defined by + + .. math:: + + J^{\mu} z(t)=\frac{1}{\Gamma(\mu)} \int_{0}^{t}(t-u)^{\mu-1} z(u) \mathrm{d} u + + The one-step Euler method for fractional differential equation is defined as + + .. math:: + + y_{k+1} = y_0 + \frac{1}{\Gamma(\alpha)} \sum_{j=0}^{k} b_{j, k+1} f\left(t_{j}, y_{j}\right). + + where + + .. math:: + + b_{j, k+1}=\frac{h^{\alpha}}{\alpha}\left((k+1-j)^{\alpha}-(k-j)^{\alpha}\right). + + + Examples:: + + >>> import brainpy as bp + >>> + >>> a, b, c = 10, 28, 8 / 3 + >>> def lorenz(x, y, z, t): + >>> dx = a * (y - x) + >>> dy = x * (b - z) - y + >>> dz = x * y - c * z + >>> return dx, dy, dz + >>> + >>> duration = 30. + >>> dt = 0.005 + >>> inits = [1., 0., 1.] + >>> f = bp.fde.CaputoEuler(lorenz, alpha=0.97, num_memory=int(duration / dt), inits=inits) + >>> runner = bp.integrators.IntegratorRunner(f, monitors=list('xyz'), dt=dt, inits=inits) + >>> runner.run(duration) + >>> + >>> import matplotlib.pyplot as plt + >>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) + >>> plt.show() + + + Parameters:: + + f : callable + The derivative function. + alpha: int, float, jnp.ndarray, bm.ndarray, sequence + The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. + num_memory: int + The total time step of the simulation. + inits: sequence + A sequence of the initial values for variables. + dt: float, int + The numerical precision. + name: str + The integrator name. + + References:: + + .. [1] Li, Changpin, and Fanhai Zeng. "The finite difference methods for fractional + ordinary differential equations." Numerical Functional Analysis and + Optimization 34.2 (2013): 149-179. 
+ .. [2] Diethelm, Kai, Neville J. Ford, and Alan D. Freed. "Detailed error analysis + for a fractional Adams method." Numerical algorithms 36.1 (2004): 31-52. + """ + + def __init__( + self, + f: Callable, + alpha: Union[float, Sequence[float], ArrayType], + num_memory: int, + inits: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], + dt: float = None, + name: str = None, + state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, + ): + super(CaputoEuler, self).__init__(f=f, + alpha=alpha, + dt=dt, + name=name, + num_memory=num_memory, + state_delays=state_delays) + + # fractional order + if not bm.all(bm.logical_and(self.alpha < 1, self.alpha > 0)): + raise UnsupportedError(f'Only support the fractional order in (0, 1), ' + f'but we got {self.alpha}.') + + # initial values + self.inits = check_inits(inits, self.variables) + + # coefficients + rgamma_alpha = bm.asarray(rgamma(bm.as_numpy(self.alpha))) + ranges = bm.asarray([bm.arange(num_memory + 1) for _ in self.variables]).T + coef = rgamma_alpha * bm.diff(bm.power(ranges, self.alpha), axis=0) + self.coef = bm.flip(coef, axis=0) + + # variable states + self.f_states = {v: bm.Variable(bm.zeros((num_memory,) + self.inits[v].shape)) + for v in self.variables} + self.register_implicit_vars(self.f_states) + self.idx = bm.Variable(bm.asarray([1])) + + self.set_integral(self._integral_func) + + def _check_step(self, args): + dt, t = args + raise ValueError(f'The maximum number of step is {self.num_memory}, ' + f'however, the current time {t} require a time ' + f'step number {t / dt}.') + + def _integral_func(self, *args, **kwargs): + # format arguments + all_args = format_args(args, kwargs, self.arg_names) + t = all_args['t'] + dt = all_args.pop(DT, self.dt) + if check.is_checking(): + check.jit_error(self.num_memory * dt < t, self._check_step, (dt, t)) + + # derivative values + devs = self.f(**all_args) + if len(self.variables) == 1: + if not isinstance(devs, (bm.ndarray, jax.Array)): + raise 
ValueError('Derivative values must be a tensor when there ' + 'is only one variable in the equation.') + devs = {self.variables[0]: devs} + else: + if not isinstance(devs, (tuple, list)): + raise ValueError('Derivative values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + devs = {var: devs[i] for i, var in enumerate(self.variables)} + + # function states + for key in self.variables: + self.f_states[key][self.idx[0]] = devs[key] + + # integral results + integrals = [] + idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory + for i, key in enumerate(self.variables): + integral = self.inits[key] + self.coef[idx, i] @ self.f_states[key] + integrals.append(integral * (dt ** self.alpha[i] / self.alpha[i])) + self.idx.value = (self.idx + 1) % self.num_memory + + # return integrals + if len(self.variables) == 1: + return integrals[0] + else: + return integrals register_fde_integrator(name='euler', integrator=CaputoEuler) class CaputoL1Schema(FDEIntegrator): - r"""The L1 scheme method for the numerical approximation of the Caputo - fractional-order derivative equations [3]_. + r"""The L1 scheme method for the numerical approximation of the Caputo + fractional-order derivative equations [3]_. + + For the fractional order :math:`0<\alpha<1`, let the fractional derivative of variable + :math:`x(t)` be + + .. math:: + + \frac{d^{\alpha} x}{d t^{\alpha}}=F(x, t) + + The Caputo definition of the fractional derivative for variable :math:`x` is + + .. math:: + + \frac{d^{\alpha} x}{d t^{\alpha}}=\frac{1}{\Gamma(1-\alpha)} \int_{0}^{t} \frac{x^{\prime}(u)}{(t-u)^{\alpha}} d u + + where :math:`\Gamma` is the Gamma function. + + The fractional-order derivative is capable of integrating the activity of the + function over all past activities weighted by a function that follows a power-law. 
+ Using one of the numerical methods, the L1 scheme method [3]_, the numerical + approximation of the fractional-order derivative of :math:`x` is + + .. math:: + + \frac{d^{\alpha} \chi}{d t^{\alpha}} \approx \frac{(d t)^{-\alpha}}{\Gamma(2-\alpha)}\left[\sum_{k=0}^{N-1}\left[x\left(t_{k+1}\right)- + \mathrm{x}\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right] + + Therefore, the numerical solution of original system is given by + + .. math:: + + x\left(t_{N}\right) \approx d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right)- + \left[\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right] - For the fractional order :math:`0<\alpha<1`, let the fractional derivative of variable - :math:`x(t)` be + Hence, the solution of the fractional-order derivative can be described as the + difference between the *Markov term* and the *memory trace*. The *Markov term* + weighted by the gamma function is + + .. math:: + + \text { Markov term }=d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right) - .. math:: - - \frac{d^{\alpha} x}{d t^{\alpha}}=F(x, t) - - The Caputo definition of the fractional derivative for variable :math:`x` is - - .. math:: - - \frac{d^{\alpha} x}{d t^{\alpha}}=\frac{1}{\Gamma(1-\alpha)} \int_{0}^{t} \frac{x^{\prime}(u)}{(t-u)^{\alpha}} d u - - where :math:`\Gamma` is the Gamma function. - - The fractional-order derivative is capable of integrating the activity of the - function over all past activities weighted by a function that follows a power-law. - Using one of the numerical methods, the L1 scheme method [3]_, the numerical - approximation of the fractional-order derivative of :math:`x` is - - .. 
math:: - - \frac{d^{\alpha} \chi}{d t^{\alpha}} \approx \frac{(d t)^{-\alpha}}{\Gamma(2-\alpha)}\left[\sum_{k=0}^{N-1}\left[x\left(t_{k+1}\right)- - \mathrm{x}\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right] - - Therefore, the numerical solution of original system is given by - - .. math:: - - x\left(t_{N}\right) \approx d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right)- - \left[\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right] - - Hence, the solution of the fractional-order derivative can be described as the - difference between the *Markov term* and the *memory trace*. The *Markov term* - weighted by the gamma function is - - .. math:: - - \text { Markov term }=d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right) - - The memory trace (:math:`x`-memory trace since it is related to variable :math:`x`) is - - .. math:: - - \text { Memory trace }=\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-(k+1))^{1-\alpha}\right] - - The memory trace integrates all the past activity and captures the long-term - history of the system. For :math:`\alpha=1`, the memory trace is 0 for any - time :math:`t`. When the fractional order :math:`\alpha` is decreased from 1, - the memory trace non-linearly increases from 0, and its dynamics strongly - depends on time. Thus, the fractional order dynamics strongly deviates - from the first order dynamics. - - - Examples:: - - >>> import brainpy as bp - >>> - >>> a, b, c = 10, 28, 8 / 3 - >>> def lorenz(x, y, z, t): - >>> dx = a * (y - x) - >>> dy = x * (b - z) - y - >>> dz = x * y - c * z - >>> return dx, dy, dz - >>> - >>> duration = 30. - >>> dt = 0.005 - >>> inits = [1., 0., 1.] 
- >>> f = bp.fde.CaputoL1Schema(lorenz, alpha=0.99, num_memory=int(duration / dt), inits=inits) - >>> runner = bp.IntegratorRunner(f, monitors=list('xz'), dt=dt, inits=inits) - >>> runner.run(duration) - >>> - >>> import matplotlib.pyplot as plt - >>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) - >>> plt.show() - - - Parameters:: - - f : callable - The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence - The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``. - num_memory: int - The total time step of the simulation. - inits: sequence - A sequence of the initial values for variables. - dt: float, int - The numerical precision. - name: str - The integrator name. - - References:: - - .. [3] Oldham, K., & Spanier, J. (1974). The fractional calculus theory - and applications of differentiation and integration to arbitrary - order. Elsevier. - """ - - def __init__( - self, - f: Callable, - alpha: Union[float, Sequence[float], ArrayType], - num_memory: int, - inits: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], - state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, - dt: float = None, - name: str = None, - ): - super(CaputoL1Schema, self).__init__(f=f, - alpha=alpha, - dt=dt, - name=name, - num_memory=num_memory, - state_delays=state_delays) - - # fractional order - if not bm.all(bm.logical_and(self.alpha <= 1, self.alpha > 0)): - raise UnsupportedError(f'Only support the fractional order in (0, 1), ' - f'but we got {self.alpha}.') - self.gamma_alpha = bm.asarray(gamma(bm.as_numpy(2 - self.alpha))) - - # initial values - inits = check_inits(inits, self.variables) - self.inits = bm.VarDict({v: bm.Variable(inits[v]) for v in self.variables}) - - # coefficients - ranges = bm.asarray([bm.arange(1, num_memory + 2) for _ in self.variables]).T - coef = bm.diff(bm.power(ranges, 1 - self.alpha), axis=0) - self.coef = bm.flip(coef, axis=0) - - # used to save the difference of two 
adjacent states - self.diff_states = bm.VarDict({v + "_diff": bm.Variable(bm.zeros((num_memory,) + self.inits[v].shape, - dtype=self.inits[v].dtype)) - for v in self.variables}) - self.idx = bm.Variable(bm.asarray([self.num_memory - 1])) - - # integral function - self.set_integral(self._integral_func) - - def reset(self, inits): - """Reset function.""" - self.idx.value = bm.asarray([self.num_memory - 1]) - inits = check_inits(inits, self.variables) - for key, value in inits.items(): - self.inits[key] = value - for key, val in inits.items(): - self.diff_states[key + "_diff"] = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype) - - def hists(self, var=None, numpy=True): - """Get the recorded history values.""" - if var is None: - hists_ = {k: bm.vstack([self.inits[k], self.diff_states[k + '_diff']]) - for k in self.variables} - hists_ = {k: bm.cumsum(v, axis=0) for k, v in hists_.items()} - if numpy: - hists_ = {k: v.numpy() for k, v in hists_} - return hists_ - else: - assert var in self.variables, (f'"{var}" is not defined in equation ' - f'variables: {self.variables}') - hists_ = bm.vstack([self.inits[var], self.diff_states[var + '_diff']]) - hists_ = bm.cumsum(hists_, axis=0) - if numpy: - hists_ = hists_.numpy() - return hists_ - - def _check_step(self, args): - dt, t = args - raise ValueError(f'The maximum number of step is {self.num_memory}, ' - f'however, the current time {t} require a time ' - f'step number {t / dt}.') - - def _integral_func(self, *args, **kwargs): - # format arguments - all_args = format_args(args, kwargs, self.arg_names) - t = all_args['t'] - dt = all_args.pop(DT, self.dt) - if check.is_checking(): - check.jit_error(self.num_memory * dt < t, self._check_step, (dt, t)) - - # derivative values - devs = self.f(**all_args) - if len(self.variables) == 1: - if not isinstance(devs, (bm.BaseArray, jax.Array)): - raise ValueError('Derivative values must be a tensor when there ' - 'is only one variable in the equation.') - devs = 
{self.variables[0]: devs} - else: - if not isinstance(devs, (tuple, list)): - raise ValueError('Derivative values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - devs = {var: devs[i] for i, var in enumerate(self.variables)} - - # integral results - integrals = [] - idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory - for i, key in enumerate(self.variables): - self.diff_states[key + '_diff'][self.idx[0]] = all_args[key] - self.inits[key] - self.inits[key].value = all_args[key] - markov_term = dt ** self.alpha[i] * self.gamma_alpha[i] * devs[key] + all_args[key] - memory_trace = self.coef[idx, i] @ self.diff_states[key + '_diff'] - integral = markov_term - memory_trace - integrals.append(integral) - self.idx.value = (self.idx + 1) % self.num_memory - - # return integrals - if len(self.variables) == 1: - return integrals[0] - else: - return integrals + The memory trace (:math:`x`-memory trace since it is related to variable :math:`x`) is + + .. math:: + + \text { Memory trace }=\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-(k+1))^{1-\alpha}\right] + + The memory trace integrates all the past activity and captures the long-term + history of the system. For :math:`\alpha=1`, the memory trace is 0 for any + time :math:`t`. When the fractional order :math:`\alpha` is decreased from 1, + the memory trace non-linearly increases from 0, and its dynamics strongly + depends on time. Thus, the fractional order dynamics strongly deviates + from the first order dynamics. + + + Examples:: + + >>> import brainpy as bp + >>> + >>> a, b, c = 10, 28, 8 / 3 + >>> def lorenz(x, y, z, t): + >>> dx = a * (y - x) + >>> dy = x * (b - z) - y + >>> dz = x * y - c * z + >>> return dx, dy, dz + >>> + >>> duration = 30. + >>> dt = 0.005 + >>> inits = [1., 0., 1.] 
+ >>> f = bp.fde.CaputoL1Schema(lorenz, alpha=0.99, num_memory=int(duration / dt), inits=inits) + >>> runner = bp.IntegratorRunner(f, monitors=list('xz'), dt=dt, inits=inits) + >>> runner.run(duration) + >>> + >>> import matplotlib.pyplot as plt + >>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) + >>> plt.show() + + + Parameters:: + + f : callable + The derivative function. + alpha: int, float, jnp.ndarray, bm.ndarray, sequence + The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``. + num_memory: int + The total time step of the simulation. + inits: sequence + A sequence of the initial values for variables. + dt: float, int + The numerical precision. + name: str + The integrator name. + + References:: + + .. [3] Oldham, K., & Spanier, J. (1974). The fractional calculus theory + and applications of differentiation and integration to arbitrary + order. Elsevier. + """ + + def __init__( + self, + f: Callable, + alpha: Union[float, Sequence[float], ArrayType], + num_memory: int, + inits: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], + state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, + dt: float = None, + name: str = None, + ): + super(CaputoL1Schema, self).__init__(f=f, + alpha=alpha, + dt=dt, + name=name, + num_memory=num_memory, + state_delays=state_delays) + + # fractional order + if not bm.all(bm.logical_and(self.alpha <= 1, self.alpha > 0)): + raise UnsupportedError(f'Only support the fractional order in (0, 1), ' + f'but we got {self.alpha}.') + self.gamma_alpha = bm.asarray(gamma(bm.as_numpy(2 - self.alpha))) + + # initial values + inits = check_inits(inits, self.variables) + self.inits = bm.VarDict({v: bm.Variable(inits[v]) for v in self.variables}) + + # coefficients + ranges = bm.asarray([bm.arange(1, num_memory + 2) for _ in self.variables]).T + coef = bm.diff(bm.power(ranges, 1 - self.alpha), axis=0) + self.coef = bm.flip(coef, axis=0) + + # used to save the difference of two 
adjacent states + self.diff_states = bm.VarDict({v + "_diff": bm.Variable(bm.zeros((num_memory,) + self.inits[v].shape, + dtype=self.inits[v].dtype)) + for v in self.variables}) + self.idx = bm.Variable(bm.asarray([self.num_memory - 1])) + + # integral function + self.set_integral(self._integral_func) + + def reset(self, inits): + """Reset function.""" + self.idx.value = bm.asarray([self.num_memory - 1]) + inits = check_inits(inits, self.variables) + for key, value in inits.items(): + self.inits[key] = value + for key, val in inits.items(): + self.diff_states[key + "_diff"] = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype) + + def hists(self, var=None, numpy=True): + """Get the recorded history values.""" + if var is None: + hists_ = {k: bm.vstack([self.inits[k], self.diff_states[k + '_diff']]) + for k in self.variables} + hists_ = {k: bm.cumsum(v, axis=0) for k, v in hists_.items()} + if numpy: + hists_ = {k: v.numpy() for k, v in hists_} + return hists_ + else: + assert var in self.variables, (f'"{var}" is not defined in equation ' + f'variables: {self.variables}') + hists_ = bm.vstack([self.inits[var], self.diff_states[var + '_diff']]) + hists_ = bm.cumsum(hists_, axis=0) + if numpy: + hists_ = hists_.numpy() + return hists_ + + def _check_step(self, args): + dt, t = args + raise ValueError(f'The maximum number of step is {self.num_memory}, ' + f'however, the current time {t} require a time ' + f'step number {t / dt}.') + + def _integral_func(self, *args, **kwargs): + # format arguments + all_args = format_args(args, kwargs, self.arg_names) + t = all_args['t'] + dt = all_args.pop(DT, self.dt) + if check.is_checking(): + check.jit_error(self.num_memory * dt < t, self._check_step, (dt, t)) + + # derivative values + devs = self.f(**all_args) + if len(self.variables) == 1: + if not isinstance(devs, (bm.BaseArray, jax.Array)): + raise ValueError('Derivative values must be a tensor when there ' + 'is only one variable in the equation.') + devs = 
{self.variables[0]: devs} + else: + if not isinstance(devs, (tuple, list)): + raise ValueError('Derivative values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + devs = {var: devs[i] for i, var in enumerate(self.variables)} + + # integral results + integrals = [] + idx = ((self.num_memory - 1 - self.idx) + bm.arange(self.num_memory)) % self.num_memory + for i, key in enumerate(self.variables): + self.diff_states[key + '_diff'][self.idx[0]] = all_args[key] - self.inits[key] + self.inits[key].value = all_args[key] + markov_term = dt ** self.alpha[i] * self.gamma_alpha[i] * devs[key] + all_args[key] + memory_trace = self.coef[idx, i] @ self.diff_states[key + '_diff'] + integral = markov_term - memory_trace + integrals.append(integral) + self.idx.value = (self.idx + 1) % self.num_memory + + # return integrals + if len(self.variables) == 1: + return integrals[0] + else: + return integrals register_fde_integrator(name='l1', integrator=CaputoL1Schema) diff --git a/brainpy/_src/integrators/fde/GL.py b/brainpy/_src/integrators/fde/GL.py index 3c5ab0607..948e3946c 100644 --- a/brainpy/_src/integrators/fde/GL.py +++ b/brainpy/_src/integrators/fde/GL.py @@ -16,201 +16,201 @@ from .generic import register_fde_integrator __all__ = [ - 'GLShortMemory' + 'GLShortMemory' ] class GLShortMemory(FDEIntegrator): - r"""Efficient Computation of the Short-Memory Principle in Grünwald-Letnikov Method [1]_. - - According to the explicit numerical approximation of Grünwald-Letnikov, the - fractional-order derivative :math:`q` for a discrete function :math:`f(t_K)` - can be described as follows: - - .. math:: - - {{}_{k-\frac{L_{m}}{h}}D_{t_{k}}^{q}}f(t_{k})\approx h^{-q} - \sum\limits_{j=0}^{k}C_{j}^{q}f(t_{k-j}) - - where :math:`L_{m}` is the memory lenght, :math:`h` is the integration step size, - and :math:`C_{j}^{q}` are the binomial coefficients which are calculated recursively with - - .. 
math:: - - C_{0}^{q}=1,\ C_{j}^{q}=\left(1- \frac{1+q}{j}\right)C_{j-1}^{q},\ j=1,2, \ldots k. - - Then, the numerical solution for a fractional-order differential equation (FODE) expressed - in the form - - .. math:: - - D_{t_{k}}^{q}x(t_{k})=f(x(t_{k})) - - can be obtained by - - .. math:: - - x(t_{k})=f(x(t_{k-1}))h^{q}- \sum\limits_{j=1}^{k}C_{j}^{q}x(t_{k-j}). - - for :math:`0 < q < 1`. The above expression requires infinity memory length - for numerical solution since the summation term depends on the discritized - time :math:`t_k`. This implies relatively high simulation times. - - To reduce the computational time, the upper bound of summation needs to be modified by - :math:`k=v`, where - - .. math:: - - v=\begin{cases} k, & k\leq M,\\ L_{m}, & k > M. \end{cases} - - This is known as the short-memory principle, where :math:`M` - is the memory window with a width defined by :math:`M=\frac{L_{m}}{h}`. - As was reported in [2]_, the accuracy increases by increaing the width of memory window. - - Examples:: - - >>> import brainpy as bp - >>> - >>> a, b, c = 10, 28, 8 / 3 - >>> def lorenz(x, y, z, t): - >>> dx = a * (y - x) - >>> dy = x * (b - z) - y - >>> dz = x * y - c * z - >>> return dx, dy, dz - >>> - >>> integral = bp.fde.GLShortMemory(lorenz, - >>> alpha=0.96, - >>> num_memory=500, - >>> inits=[1., 0., 1.]) - >>> runner = bp.IntegratorRunner(integral, - >>> monitors=list('xyz'), - >>> inits=[1., 0., 1.], - >>> dt=0.005) - >>> runner.run(100.) - >>> - >>> import matplotlib.pyplot as plt - >>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) - >>> plt.show() - - - Parameters:: - - f : callable - The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence - The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. - num_memory: int - The length of the short memory. - - .. versionchanged:: 2.1.11 - - inits: sequence - A sequence of the initial values for variables. 
- dt: float, int - The numerical precision. - name: str - The integrator name. - - References:: - - .. [1] Clemente-López, D., et al. "Efficient computation of the - Grünwald-Letnikov method for arm-based implementations of - fractional-order chaotic systems." 2019 8th International - Conference on Modern Circuits and Systems Technologies (MOCAST). IEEE, 2019. - .. [2] M. F. Tolba, A. M. AbdelAty, N. S. Soliman, L. A. Said, A. H. - Madian, A. T. Azar, et al., "FPGA implementation of two fractional - order chaotic systems", International Journal of Electronics and - Communications, vol. 78, pp. 162-172, 2017. - """ - - def __init__( - self, - f: Callable, - alpha: Any, - inits: Any, - num_memory: int, - dt: float = None, - name: str = None, - state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, - ): - super(GLShortMemory, self).__init__(f=f, - alpha=alpha, - dt=dt, - name=name, - num_memory=num_memory, - state_delays=state_delays) - - # fractional order - if not bm.all(bm.logical_and(self.alpha <= 1, self.alpha > 0)): - raise UnsupportedError(f'Only support the fractional order in (0, 1), ' - f'but we got {self.alpha}.') - - # initial values - inits = check_inits(inits, self.variables) - - # delays - self.delays = bm.VarDict() - for key, val in inits.items(): - delay = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype) - delay[0] = val - self.delays[key+'_delay'] = bm.Variable(delay) - self._idx = bm.Variable(bm.asarray([1])) - - # binomial coefficients - bc = (1 - (1 + self.alpha.reshape((-1, 1))) / bm.arange(1, num_memory + 1)) - bc = bm.cumprod(bm.vstack([bm.ones_like(self.alpha), bc.T]), axis=0) - self._binomial_coef = bm.flip(bc[1:], axis=0) - - # integral function - self.set_integral(self._integral_func) - - def reset(self, inits): - """Reset function of the delay variables.""" - self._idx.value = bm.asarray([1]) - inits = check_inits(inits, self.variables) - for key, val in inits.items(): - delay = bm.zeros((self.num_memory,) + 
val.shape, dtype=val.dtype) - delay[0] = val - self.delays[key].value = delay - - @property - def binomial_coef(self): - return bm.as_numpy(bm.flip(self._binomial_coef, axis=0)) - - def _integral_func(self, *args, **kwargs): - # format arguments - all_args = format_args(args, kwargs, self.arg_names) - dt = all_args.pop(DT, self.dt) - - # derivative values - devs = self.f(**all_args) - if len(self.variables) == 1: - if not isinstance(devs, (bm.BaseArray, jax.Array)): - raise ValueError('Derivative values must be a tensor when there ' - 'is only one variable in the equation.') - devs = {self.variables[0]: devs} - else: - if not isinstance(devs, (tuple, list)): - raise ValueError('Derivative values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - devs = {var: devs[i] for i, var in enumerate(self.variables)} - - # integral results - integrals = [] - idx = (self._idx + bm.arange(self.num_memory)) % self.num_memory - for i, var in enumerate(self.variables): - delay_var = var + '_delay' - summation = self._binomial_coef[:, i] @ self.delays[delay_var][idx] - integral = (dt ** self.alpha[i]) * devs[var] - summation - self.delays[delay_var][self._idx[0]] = integral - integrals.append(integral) - self._idx.value = (self._idx + 1) % self.num_memory - - # return integrals - if len(self.variables) == 1: - return integrals[0] - else: - return integrals + r"""Efficient Computation of the Short-Memory Principle in Grünwald-Letnikov Method [1]_. + + According to the explicit numerical approximation of Grünwald-Letnikov, the + fractional-order derivative :math:`q` for a discrete function :math:`f(t_K)` + can be described as follows: + + .. math:: + + {{}_{k-\frac{L_{m}}{h}}D_{t_{k}}^{q}}f(t_{k})\approx h^{-q} + \sum\limits_{j=0}^{k}C_{j}^{q}f(t_{k-j}) + + where :math:`L_{m}` is the memory lenght, :math:`h` is the integration step size, + and :math:`C_{j}^{q}` are the binomial coefficients which are calculated recursively with + + .. 
math:: + + C_{0}^{q}=1,\ C_{j}^{q}=\left(1- \frac{1+q}{j}\right)C_{j-1}^{q},\ j=1,2, \ldots k. + + Then, the numerical solution for a fractional-order differential equation (FODE) expressed + in the form + + .. math:: + + D_{t_{k}}^{q}x(t_{k})=f(x(t_{k})) + + can be obtained by + + .. math:: + + x(t_{k})=f(x(t_{k-1}))h^{q}- \sum\limits_{j=1}^{k}C_{j}^{q}x(t_{k-j}). + + for :math:`0 < q < 1`. The above expression requires infinity memory length + for numerical solution since the summation term depends on the discritized + time :math:`t_k`. This implies relatively high simulation times. + + To reduce the computational time, the upper bound of summation needs to be modified by + :math:`k=v`, where + + .. math:: + + v=\begin{cases} k, & k\leq M,\\ L_{m}, & k > M. \end{cases} + + This is known as the short-memory principle, where :math:`M` + is the memory window with a width defined by :math:`M=\frac{L_{m}}{h}`. + As was reported in [2]_, the accuracy increases by increaing the width of memory window. + + Examples:: + + >>> import brainpy as bp + >>> + >>> a, b, c = 10, 28, 8 / 3 + >>> def lorenz(x, y, z, t): + >>> dx = a * (y - x) + >>> dy = x * (b - z) - y + >>> dz = x * y - c * z + >>> return dx, dy, dz + >>> + >>> integral = bp.fde.GLShortMemory(lorenz, + >>> alpha=0.96, + >>> num_memory=500, + >>> inits=[1., 0., 1.]) + >>> runner = bp.IntegratorRunner(integral, + >>> monitors=list('xyz'), + >>> inits=[1., 0., 1.], + >>> dt=0.005) + >>> runner.run(100.) + >>> + >>> import matplotlib.pyplot as plt + >>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) + >>> plt.show() + + + Parameters:: + + f : callable + The derivative function. + alpha: int, float, jnp.ndarray, bm.ndarray, sequence + The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. + num_memory: int + The length of the short memory. + + .. versionchanged:: 2.1.11 + + inits: sequence + A sequence of the initial values for variables. 
+ dt: float, int + The numerical precision. + name: str + The integrator name. + + References:: + + .. [1] Clemente-López, D., et al. "Efficient computation of the + Grünwald-Letnikov method for arm-based implementations of + fractional-order chaotic systems." 2019 8th International + Conference on Modern Circuits and Systems Technologies (MOCAST). IEEE, 2019. + .. [2] M. F. Tolba, A. M. AbdelAty, N. S. Soliman, L. A. Said, A. H. + Madian, A. T. Azar, et al., "FPGA implementation of two fractional + order chaotic systems", International Journal of Electronics and + Communications, vol. 78, pp. 162-172, 2017. + """ + + def __init__( + self, + f: Callable, + alpha: Any, + inits: Any, + num_memory: int, + dt: float = None, + name: str = None, + state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, + ): + super(GLShortMemory, self).__init__(f=f, + alpha=alpha, + dt=dt, + name=name, + num_memory=num_memory, + state_delays=state_delays) + + # fractional order + if not bm.all(bm.logical_and(self.alpha <= 1, self.alpha > 0)): + raise UnsupportedError(f'Only support the fractional order in (0, 1), ' + f'but we got {self.alpha}.') + + # initial values + inits = check_inits(inits, self.variables) + + # delays + self.delays = bm.VarDict() + for key, val in inits.items(): + delay = bm.zeros((self.num_memory,) + val.shape, dtype=val.dtype) + delay[0] = val + self.delays[key + '_delay'] = bm.Variable(delay) + self._idx = bm.Variable(bm.asarray([1])) + + # binomial coefficients + bc = (1 - (1 + self.alpha.reshape((-1, 1))) / bm.arange(1, num_memory + 1)) + bc = bm.cumprod(bm.vstack([bm.ones_like(self.alpha), bc.T]), axis=0) + self._binomial_coef = bm.flip(bc[1:], axis=0) + + # integral function + self.set_integral(self._integral_func) + + def reset(self, inits): + """Reset function of the delay variables.""" + self._idx.value = bm.asarray([1]) + inits = check_inits(inits, self.variables) + for key, val in inits.items(): + delay = bm.zeros((self.num_memory,) + 
val.shape, dtype=val.dtype) + delay[0] = val + self.delays[key].value = delay + + @property + def binomial_coef(self): + return bm.as_numpy(bm.flip(self._binomial_coef, axis=0)) + + def _integral_func(self, *args, **kwargs): + # format arguments + all_args = format_args(args, kwargs, self.arg_names) + dt = all_args.pop(DT, self.dt) + + # derivative values + devs = self.f(**all_args) + if len(self.variables) == 1: + if not isinstance(devs, (bm.BaseArray, jax.Array)): + raise ValueError('Derivative values must be a tensor when there ' + 'is only one variable in the equation.') + devs = {self.variables[0]: devs} + else: + if not isinstance(devs, (tuple, list)): + raise ValueError('Derivative values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + devs = {var: devs[i] for i, var in enumerate(self.variables)} + + # integral results + integrals = [] + idx = (self._idx + bm.arange(self.num_memory)) % self.num_memory + for i, var in enumerate(self.variables): + delay_var = var + '_delay' + summation = self._binomial_coef[:, i] @ self.delays[delay_var][idx] + integral = (dt ** self.alpha[i]) * devs[var] - summation + self.delays[delay_var][self._idx[0]] = integral + integrals.append(integral) + self._idx.value = (self._idx + 1) % self.num_memory + + # return integrals + if len(self.variables) == 1: + return integrals[0] + else: + return integrals register_fde_integrator(name='short-memory', integrator=GLShortMemory) diff --git a/brainpy/_src/integrators/fde/__init__.py b/brainpy/_src/integrators/fde/__init__.py index df31e4f3a..2ac3f88cd 100644 --- a/brainpy/_src/integrators/fde/__init__.py +++ b/brainpy/_src/integrators/fde/__init__.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- +from .Caputo import * +from .GL import * from .base import * from .generic import * -from .GL import * -from .Caputo import * - - diff --git a/brainpy/_src/integrators/fde/base.py b/brainpy/_src/integrators/fde/base.py index 2eb29c86e..c4084cd4b 100644 --- 
a/brainpy/_src/integrators/fde/base.py +++ b/brainpy/_src/integrators/fde/base.py @@ -5,86 +5,85 @@ import jax.numpy as jnp import brainpy.math as bm -from brainpy.errors import UnsupportedError from brainpy._src.integrators.base import Integrator from brainpy._src.integrators.utils import get_args from brainpy.check import is_integer +from brainpy.errors import UnsupportedError __all__ = [ - 'FDEIntegrator' + 'FDEIntegrator' ] class FDEIntegrator(Integrator): - """Numerical integrator for fractional differential equations (FEDs). - - Parameters:: - - f : callable - The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence - The fractional-order of the derivative function. - dt: float, int - The numerical precision. - name: str - The integrator name. - """ - - alpha: bm.Array - """The fraction order for each variable.""" - - dt: Union[float, int] - """The numerical integration precision.""" - - f: Callable - """The fraction derivative function.""" - - def __init__( - self, - f: Callable, - alpha, - num_memory: int, - dt: float = None, - name: str = None, - state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, - ): - dt = bm.get_dt() if dt is None else dt - parses = get_args(f) - variables = parses[0] # variable names, (before 't') - parameters = parses[1] # parameter names, (after 't') - arguments = parses[2] # function arguments - - # memory length - is_integer(num_memory, 'num_memory', allow_none=False, min_bound=1) - self.num_memory = num_memory - - # super initialization - super(FDEIntegrator, self).__init__(name=name, - variables=variables, - parameters=parameters, - arguments=arguments, - dt=dt, - state_delays=state_delays) - - # derivative function - self.f = f - - # fractional-order - if isinstance(alpha, (int, float)): - alpha = jnp.ones(len(self.variables)) * alpha - elif isinstance(alpha, (jnp.ndarray, bm.ndarray)): - alpha = bm.as_jax(alpha) - elif isinstance(alpha, (list, tuple)): - for a in alpha: - assert 
isinstance(a, (float, int)), (f'Must be a tuple/list of int/float, ' - f'but we got {type(a)}: {a}') - alpha = bm.as_jax(alpha) - else: - raise UnsupportedError(f'Do not support {type(alpha)}, please ' - f'set fractional-order as number/tuple/list/tensor.') - if len(alpha) != len(self.variables): - raise ValueError(f'There are {len(self.variables)} variables, ' - f'while we only got {len(alpha)} fractional-order ' - f'settings: {alpha}') - self.alpha = alpha - + """Numerical integrator for fractional differential equations (FEDs). + + Parameters:: + + f : callable + The derivative function. + alpha: int, float, jnp.ndarray, bm.ndarray, sequence + The fractional-order of the derivative function. + dt: float, int + The numerical precision. + name: str + The integrator name. + """ + + alpha: bm.Array + """The fraction order for each variable.""" + + dt: Union[float, int] + """The numerical integration precision.""" + + f: Callable + """The fraction derivative function.""" + + def __init__( + self, + f: Callable, + alpha, + num_memory: int, + dt: float = None, + name: str = None, + state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, + ): + dt = bm.get_dt() if dt is None else dt + parses = get_args(f) + variables = parses[0] # variable names, (before 't') + parameters = parses[1] # parameter names, (after 't') + arguments = parses[2] # function arguments + + # memory length + is_integer(num_memory, 'num_memory', allow_none=False, min_bound=1) + self.num_memory = num_memory + + # super initialization + super(FDEIntegrator, self).__init__(name=name, + variables=variables, + parameters=parameters, + arguments=arguments, + dt=dt, + state_delays=state_delays) + + # derivative function + self.f = f + + # fractional-order + if isinstance(alpha, (int, float)): + alpha = jnp.ones(len(self.variables)) * alpha + elif isinstance(alpha, (jnp.ndarray, bm.ndarray)): + alpha = bm.as_jax(alpha) + elif isinstance(alpha, (list, tuple)): + for a in alpha: + assert 
isinstance(a, (float, int)), (f'Must be a tuple/list of int/float, ' + f'but we got {type(a)}: {a}') + alpha = bm.as_jax(alpha) + else: + raise UnsupportedError(f'Do not support {type(alpha)}, please ' + f'set fractional-order as number/tuple/list/tensor.') + if len(alpha) != len(self.variables): + raise ValueError(f'There are {len(self.variables)} variables, ' + f'while we only got {len(alpha)} fractional-order ' + f'settings: {alpha}') + self.alpha = alpha diff --git a/brainpy/_src/integrators/fde/generic.py b/brainpy/_src/integrators/fde/generic.py index d7d351cfd..94c577326 100644 --- a/brainpy/_src/integrators/fde/generic.py +++ b/brainpy/_src/integrators/fde/generic.py @@ -3,10 +3,10 @@ from .base import FDEIntegrator __all__ = [ - 'set_default_fdeint', - 'get_default_fdeint', - 'register_fde_integrator', - 'get_supported_methods', + 'set_default_fdeint', + 'get_default_fdeint', + 'register_fde_integrator', + 'get_supported_methods', ] name2method = {} @@ -23,86 +23,86 @@ def fdeint( dt: str = None, name: str = None ): - """Numerical integration for FDEs. - - Parameters:: - - f : callable, function - The derivative function. - method : str - The shortcut name of the numerical integrator. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence - The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``. - num_memory: int - The number of the memory length. - inits: sequence - A sequence of the initial values for variables. - dt: float, int - The numerical precision. - name: str - The integrator name. - - Returns:: - - integral : FDEIntegrator - The numerical solver of `f`. - """ - method = _DEFAULT_DDE_METHOD if method is None else method - if method not in name2method: - raise ValueError(f'Unknown FDE numerical method "{method}". 
Currently ' - f'BrainPy supports: {list(name2method.keys())}') - - if f is None: - return lambda f: name2method[method](f, dt=dt, name=name, inits=inits, num_memory=num_memory, alpha=alpha) - else: - return name2method[method](f, dt=dt, name=name, inits=inits, num_memory=num_memory, alpha=alpha) + """Numerical integration for FDEs. + + Parameters:: + + f : callable, function + The derivative function. + method : str + The shortcut name of the numerical integrator. + alpha: int, float, jnp.ndarray, bm.ndarray, sequence + The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``. + num_memory: int + The number of the memory length. + inits: sequence + A sequence of the initial values for variables. + dt: float, int + The numerical precision. + name: str + The integrator name. + + Returns:: + + integral : FDEIntegrator + The numerical solver of `f`. + """ + method = _DEFAULT_DDE_METHOD if method is None else method + if method not in name2method: + raise ValueError(f'Unknown FDE numerical method "{method}". Currently ' + f'BrainPy supports: {list(name2method.keys())}') + + if f is None: + return lambda f: name2method[method](f, dt=dt, name=name, inits=inits, num_memory=num_memory, alpha=alpha) + else: + return name2method[method](f, dt=dt, name=name, inits=inits, num_memory=num_memory, alpha=alpha) def set_default_fdeint(method): - """Set the default ODE numerical integrator method for differential equations. + """Set the default ODE numerical integrator method for differential equations. - Parameters:: + Parameters:: - method : str, callable - Numerical integrator method. - """ - if not isinstance(method, str): - raise ValueError(f'Only support string, not {type(method)}.') - if method not in name2method: - raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') + method : str, callable + Numerical integrator method. 
+ """ + if not isinstance(method, str): + raise ValueError(f'Only support string, not {type(method)}.') + if method not in name2method: + raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') - global _DEFAULT_DDE_METHOD - _DEFAULT_ODE_METHOD = method + global _DEFAULT_DDE_METHOD + _DEFAULT_ODE_METHOD = method def get_default_fdeint(): - """Get the default ODE numerical integrator method. + """Get the default ODE numerical integrator method. - Returns:: + Returns:: - method : str - The default numerical integrator method. - """ - return _DEFAULT_DDE_METHOD + method : str + The default numerical integrator method. + """ + return _DEFAULT_DDE_METHOD def register_fde_integrator(name, integrator): - """Register a new ODE integrator. + """Register a new ODE integrator. - Parameters:: + Parameters:: - name: ste - The integrator name. - integrator: type - The integrator. - """ - if name in name2method: - raise ValueError(f'"{name}" has been registered in FDE integrators.') - if not issubclass(integrator, FDEIntegrator): - raise ValueError(f'"integrator" must be an instance of {FDEIntegrator.__name__}') - name2method[name] = integrator + name: ste + The integrator name. + integrator: type + The integrator. 
+ """ + if name in name2method: + raise ValueError(f'"{name}" has been registered in FDE integrators.') + if not issubclass(integrator, FDEIntegrator): + raise ValueError(f'"integrator" must be an instance of {FDEIntegrator.__name__}') + name2method[name] = integrator def get_supported_methods(): - """Get all supported numerical methods for DDEs.""" - return list(name2method.keys()) + """Get all supported numerical methods for DDEs.""" + return list(name2method.keys()) diff --git a/brainpy/_src/integrators/fde/tests/test_Caputo.py b/brainpy/_src/integrators/fde/tests/test_Caputo.py index 06b8b05a5..94c9cc801 100644 --- a/brainpy/_src/integrators/fde/tests/test_Caputo.py +++ b/brainpy/_src/integrators/fde/tests/test_Caputo.py @@ -9,28 +9,28 @@ class TestCaputoL1(unittest.TestCase): - def test1(self): - bp.math.random.seed() - bp.math.enable_x64() - alpha = 0.9 - intg = bp.fde.CaputoL1Schema(lambda a, t: a, - alpha=alpha, - num_memory=10, - inits=[1., ]) - for N in [2, 3, 4, 5, 6, 7, 8]: - diff = np.random.rand(N - 1, 1) - memory_trace = 0 - for i in range(N - 1): - c = (N - i) ** (1 - alpha) - (N - i - 1) ** (1 - alpha) - memory_trace += c * diff[i] - - intg.idx[0] = N - 1 - intg.diff_states['a_diff'][:N - 1] = bp.math.asarray(diff) - idx = ((intg.num_memory - intg.idx) + np.arange(intg.num_memory)) % intg.num_memory - memory_trace2 = intg.coef[idx, 0] @ intg.diff_states['a_diff'] - - print() - print(memory_trace[0], ) - print(memory_trace2[0], bp.math.array_equal(memory_trace[0], memory_trace2[0])) - - bp.math.disable_x64() + def test1(self): + bp.math.random.seed() + bp.math.enable_x64() + alpha = 0.9 + intg = bp.fde.CaputoL1Schema(lambda a, t: a, + alpha=alpha, + num_memory=10, + inits=[1., ]) + for N in [2, 3, 4, 5, 6, 7, 8]: + diff = np.random.rand(N - 1, 1) + memory_trace = 0 + for i in range(N - 1): + c = (N - i) ** (1 - alpha) - (N - i - 1) ** (1 - alpha) + memory_trace += c * diff[i] + + intg.idx[0] = N - 1 + intg.diff_states['a_diff'][:N - 1] = 
bp.math.asarray(diff) + idx = ((intg.num_memory - intg.idx) + np.arange(intg.num_memory)) % intg.num_memory + memory_trace2 = intg.coef[idx, 0] @ intg.diff_states['a_diff'] + + print() + print(memory_trace[0], ) + print(memory_trace2[0], bp.math.array_equal(memory_trace[0], memory_trace2[0])) + + bp.math.disable_x64() diff --git a/brainpy/_src/integrators/fde/tests/test_GL.py b/brainpy/_src/integrators/fde/tests/test_GL.py index fb9056207..8787c9c8e 100644 --- a/brainpy/_src/integrators/fde/tests/test_GL.py +++ b/brainpy/_src/integrators/fde/tests/test_GL.py @@ -3,35 +3,33 @@ import unittest -import brainpy as bp import matplotlib.pyplot as plt +import brainpy as bp + block = False class TestGLShortMemory(unittest.TestCase): - def test_lorenz(self): - - a, b, c = 10, 28, 8 / 3 - - def lorenz(x, y, z, t): - dx = a * (y - x) - dy = x * (b - z) - y - dz = x * y - c * z - return dx, dy, dz - - bp.math.random.seed() - integral = bp.fde.GLShortMemory(lorenz, - alpha=0.99, - num_memory=500, - inits=[1., 0., 1.]) - runner = bp.IntegratorRunner(integral, - monitors=list('xyz'), - inits=[1., 0., 1.], - dt=0.005) - runner.run(100.) - - plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) - plt.show(block=block) - - + def test_lorenz(self): + a, b, c = 10, 28, 8 / 3 + + def lorenz(x, y, z, t): + dx = a * (y - x) + dy = x * (b - z) - y + dz = x * y - c * z + return dx, dy, dz + + bp.math.random.seed() + integral = bp.fde.GLShortMemory(lorenz, + alpha=0.99, + num_memory=500, + inits=[1., 0., 1.]) + runner = bp.IntegratorRunner(integral, + monitors=list('xyz'), + inits=[1., 0., 1.], + dt=0.005) + runner.run(100.) 
+ + plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) + plt.show(block=block) diff --git a/brainpy/_src/integrators/joint_eq.py b/brainpy/_src/integrators/joint_eq.py index f9d5a1aca..8d40bdd32 100644 --- a/brainpy/_src/integrators/joint_eq.py +++ b/brainpy/_src/integrators/joint_eq.py @@ -6,213 +6,213 @@ from brainpy._src.math.object_transform.base import Collector __all__ = [ - 'JointEq', + 'JointEq', ] def _get_args(f): - """Get the function arguments""" - args = [] - kwargs = {} - for name, par in inspect.signature(f).parameters.items(): - if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: - if par.default is inspect._empty: - args.append(par.name) - else: - kwargs[par.name] = par.default - elif par.kind is inspect.Parameter.VAR_POSITIONAL: - raise errors.DiffEqError(f'{JointEq.__name__} does not support VAR_POSITIONAL parameters ' - f'*{par.name} (error in {f}).') - elif par.kind is inspect.Parameter.KEYWORD_ONLY: - raise errors.DiffEqError(f'{JointEq.__name__} does not support KEYWORD_ONLY parameters, ' - f'e.g., * (error in {f}).') - elif par.kind is inspect.Parameter.POSITIONAL_ONLY: - raise errors.DiffEqError(f'{JointEq.__name__} does not support POSITIONAL_ONLY parameters, ' - 'e.g., / (error in {f}).') - elif par.kind is inspect.Parameter.VAR_KEYWORD: - raise errors.DiffEqError(f'{JointEq.__name__} does not support VAR_KEYWORD ' - f'arguments **{par.name} (error in {f}).') + """Get the function arguments""" + args = [] + kwargs = {} + for name, par in inspect.signature(f).parameters.items(): + if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + if par.default is inspect._empty: + args.append(par.name) + else: + kwargs[par.name] = par.default + elif par.kind is inspect.Parameter.VAR_POSITIONAL: + raise errors.DiffEqError(f'{JointEq.__name__} does not support VAR_POSITIONAL parameters ' + f'*{par.name} (error in {f}).') + elif par.kind is inspect.Parameter.KEYWORD_ONLY: + raise errors.DiffEqError(f'{JointEq.__name__} does not support 
KEYWORD_ONLY parameters, ' + f'e.g., * (error in {f}).') + elif par.kind is inspect.Parameter.POSITIONAL_ONLY: + raise errors.DiffEqError(f'{JointEq.__name__} does not support POSITIONAL_ONLY parameters, ' + 'e.g., / (error in {f}).') + elif par.kind is inspect.Parameter.VAR_KEYWORD: + raise errors.DiffEqError(f'{JointEq.__name__} does not support VAR_KEYWORD ' + f'arguments **{par.name} (error in {f}).') + else: + raise errors.DiffEqError(f'Unknown argument type: {par.kind}') + + # variables + vars = [] + for a in args: + if a == 't': + break + vars.append(a) else: - raise errors.DiffEqError(f'Unknown argument type: {par.kind}') - - # variables - vars = [] - for a in args: - if a == 't': - break - vars.append(a) - else: - raise ValueError('Do not find time variable "t".') + raise ValueError('Do not find time variable "t".') - return vars, args, kwargs + return vars, args, kwargs def _std_func(f, all_vars: list): - f_vars, f_args, f_kwargs = _get_args(f) - - def call(t, *vars, **args_and_kwargs): - params = dict(t=t) - for var in f_vars: - params[var] = vars[all_vars.index(var)] - for par in f_args[len(f_vars) + 1:]: - if par in args_and_kwargs: - params[par] = args_and_kwargs[par] - else: - if par not in all_vars: - raise errors.DiffEqError(f'Missing {par} during the functional call of {f}.') - params[par] = vars[all_vars.index(par)] - for par, value in f_kwargs.items(): - if par in args_and_kwargs: - params[par] = args_and_kwargs[par] - return f(**params) - - return call + f_vars, f_args, f_kwargs = _get_args(f) + + def call(t, *vars, **args_and_kwargs): + params = dict(t=t) + for var in f_vars: + params[var] = vars[all_vars.index(var)] + for par in f_args[len(f_vars) + 1:]: + if par in args_and_kwargs: + params[par] = args_and_kwargs[par] + else: + if par not in all_vars: + raise errors.DiffEqError(f'Missing {par} during the functional call of {f}.') + params[par] = vars[all_vars.index(par)] + for par, value in f_kwargs.items(): + if par in args_and_kwargs: + 
params[par] = args_and_kwargs[par] + return f(**params) + + return call class JointEq(object): - """Make a joint equation from multiple derivation functions. - - For example, we have an Izhikevich neuron model, - - >>> a, b = 0.02, 0.20 - >>> dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext - >>> du = lambda u, t, V: a * (b * V - u) - - If we make numerical solver for each derivative function, they will be solved independently. - - >>> import brainpy as bp - >>> bp.odeint(dV, method='rk2', show_code=True) - def brainpy_itg_of_ode0(V, t, u, Iext, dt=0.1): - dV_k1 = f(V, t, u, Iext) - k2_V_arg = V + dt * dV_k1 * 0.6666666666666666 - k2_t_arg = t + dt * 0.6666666666666666 - dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext) - V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75 - return V_new - - As you see in the output code, "dV_k2" is evaluated by :math:`f(V_{k2}, u)`. - If you want to solve the above coupled equation jointly, i.e., evalute "dV_k2" - with :math:`f(V_{k2}, u_{k2})`, you can use :py:class:`brainpy.JointEq` - to emerge the above two derivative equations into a joint equation, so that - they will be numerically solved together. Let's see the difference: - - >>> eq = bp.JointEq(eqs=(dV, du)) - >>> bp.odeint(eq, method='rk2', show_code=True) - def brainpy_itg_of_ode0_joint_eq(V, u, t, Iext, dt=0.1): - dV_k1, du_k1 = f(V, u, t, Iext) - k2_V_arg = V + dt * dV_k1 * 0.6666666666666666 - k2_u_arg = u + dt * du_k1 * 0.6666666666666666 - k2_t_arg = t + dt * 0.6666666666666666 - dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext) - V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75 - u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75 - return V_new, u_new - - :py:class:`brainpy.JointEq` supports make nested ``JointEq``, which means - the instance of ``JointEq`` can be an element to compose a new ``JointEq``. 
- - >>> dw = lambda w, t, V: a * (b * V - w) - >>> eq2 = bp.JointEq(eqs=(eq, dw)) - - - Parameters:: - - *eqs : - The elements of derivative function to compose. - """ - - def _check_eqs(self, eqs): - for eq in eqs: - if isinstance(eq, (list, tuple)): - for a in self._check_eqs(eq): - yield a - elif callable(eq): - yield eq - else: - raise errors.DiffEqError(f'Elements in "eqs" only supports callable function, but got {eq}.') - - def __init__(self, *eqs): - eqs = list(self._check_eqs(eqs)) - - # variables in equations - self.vars_in_eqs = [] - vars_in_eqs = [] - for eq in eqs: - vars, _, _ = _get_args(eq) - for var in vars: - if var in vars_in_eqs: - raise errors.DiffEqError(f'Variable "{var}" has been used, however we got a same ' - f'variable name in {eq}. Please change another name.') - vars_in_eqs.extend(vars) - self.vars_in_eqs.append(vars) - - # arguments in equations - self.args_in_eqs = [] - all_arg_pars = [] - all_kwarg_pars = dict() - for eq in eqs: - vars, args, kwargs = _get_args(eq) - self.args_in_eqs.append(args + list(kwargs.keys())) - for par in args[len(vars) + 1:]: - if (par not in vars_in_eqs) and (par not in all_arg_pars) and (par not in all_kwarg_pars): - all_arg_pars.append(par) - for key, value in kwargs.items(): - if key in all_kwarg_pars and value != all_kwarg_pars[key]: - raise errors.DiffEqError(f'We got two different default value of "{key}": ' - f'{all_kwarg_pars[key]} != {value}') - elif (key not in vars_in_eqs) and (key not in all_arg_pars): - all_kwarg_pars[key] = value - else: - raise errors.DiffEqError - - # # variable names provided - # if not isinstance(variables, (tuple, list)): - # raise errors.DiffEqError(f'"variables" must be a list/tuple of str, but we got {variables}') - # for v in variables: - # if not isinstance(v, str): - # raise errors.DiffEqError(f'"variables" must be a list/tuple of str, but we got {v} in "variables"') - # if len(vars_in_eqs) != len(variables): - # raise errors.DiffEqError(f'We detect 
{len(vars_in_eqs)} variables "{vars_in_eqs}" ' - # f'in the provided equations. However, the used provided ' - # f'"variables" have {len(variables)} variables ' - # f'"{variables}".') - # if len(set(vars_in_eqs) - set(variables)) != 0: - # raise errors.DiffEqError(f'We detect there are variable "{vars_in_eqs}" in the provided ' - # f'equations, while the user provided variables "{variables}" ' - # f'is not the same.') - - # finally - self.eqs = eqs - # self.variables = variables - self.arg_keys = vars_in_eqs + ['t'] + all_arg_pars - self.kwarg_keys = list(all_kwarg_pars.keys()) - self.kwargs = all_kwarg_pars - parameters = [inspect.Parameter(vp, inspect.Parameter.POSITIONAL_OR_KEYWORD) - for vp in self.arg_keys] - parameters.extend([inspect.Parameter(k, - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=all_kwarg_pars[k]) - for k in self.kwarg_keys]) - signature = inspect.signature(eqs[0]) - self.__signature__ = signature.replace(parameters=parameters) - self.__name__ = 'joint_eq' - - def __call__(self, *args, **kwargs): - # format arguments - params_in = Collector() - for i, arg in enumerate(args): - if i < len(self.arg_keys): - params_in[self.arg_keys[i]] = arg - else: - params_in[self.kwarg_keys[i - len(self.arg_keys)]] = arg - params_in.update(kwargs) - - # call equations - results = [] - for i, eq in enumerate(self.eqs): - r = eq(**{arg: params_in[arg] for arg in self.args_in_eqs[i]}) - if isinstance(r, (list, tuple)): - results.extend(list(r)) - else: - results.append(r) - return results + """Make a joint equation from multiple derivation functions. + + For example, we have an Izhikevich neuron model, + + >>> a, b = 0.02, 0.20 + >>> dV = lambda V, t, u, Iext: 0.04 * V * V + 5 * V + 140 - u + Iext + >>> du = lambda u, t, V: a * (b * V - u) + + If we make numerical solver for each derivative function, they will be solved independently. 
+ + >>> import brainpy as bp + >>> bp.odeint(dV, method='rk2', show_code=True) + def brainpy_itg_of_ode0(V, t, u, Iext, dt=0.1): + dV_k1 = f(V, t, u, Iext) + k2_V_arg = V + dt * dV_k1 * 0.6666666666666666 + k2_t_arg = t + dt * 0.6666666666666666 + dV_k2 = f(k2_V_arg, k2_t_arg, u, Iext) + V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75 + return V_new + + As you see in the output code, "dV_k2" is evaluated by :math:`f(V_{k2}, u)`. + If you want to solve the above coupled equation jointly, i.e., evalute "dV_k2" + with :math:`f(V_{k2}, u_{k2})`, you can use :py:class:`brainpy.JointEq` + to emerge the above two derivative equations into a joint equation, so that + they will be numerically solved together. Let's see the difference: + + >>> eq = bp.JointEq(eqs=(dV, du)) + >>> bp.odeint(eq, method='rk2', show_code=True) + def brainpy_itg_of_ode0_joint_eq(V, u, t, Iext, dt=0.1): + dV_k1, du_k1 = f(V, u, t, Iext) + k2_V_arg = V + dt * dV_k1 * 0.6666666666666666 + k2_u_arg = u + dt * du_k1 * 0.6666666666666666 + k2_t_arg = t + dt * 0.6666666666666666 + dV_k2, du_k2 = f(k2_V_arg, k2_u_arg, k2_t_arg, Iext) + V_new = V + dV_k1 * dt * 0.25 + dV_k2 * dt * 0.75 + u_new = u + du_k1 * dt * 0.25 + du_k2 * dt * 0.75 + return V_new, u_new + + :py:class:`brainpy.JointEq` supports make nested ``JointEq``, which means + the instance of ``JointEq`` can be an element to compose a new ``JointEq``. + + >>> dw = lambda w, t, V: a * (b * V - w) + >>> eq2 = bp.JointEq(eqs=(eq, dw)) + + + Parameters:: + + *eqs : + The elements of derivative function to compose. 
+ """ + + def _check_eqs(self, eqs): + for eq in eqs: + if isinstance(eq, (list, tuple)): + for a in self._check_eqs(eq): + yield a + elif callable(eq): + yield eq + else: + raise errors.DiffEqError(f'Elements in "eqs" only supports callable function, but got {eq}.') + + def __init__(self, *eqs): + eqs = list(self._check_eqs(eqs)) + + # variables in equations + self.vars_in_eqs = [] + vars_in_eqs = [] + for eq in eqs: + vars, _, _ = _get_args(eq) + for var in vars: + if var in vars_in_eqs: + raise errors.DiffEqError(f'Variable "{var}" has been used, however we got a same ' + f'variable name in {eq}. Please change another name.') + vars_in_eqs.extend(vars) + self.vars_in_eqs.append(vars) + + # arguments in equations + self.args_in_eqs = [] + all_arg_pars = [] + all_kwarg_pars = dict() + for eq in eqs: + vars, args, kwargs = _get_args(eq) + self.args_in_eqs.append(args + list(kwargs.keys())) + for par in args[len(vars) + 1:]: + if (par not in vars_in_eqs) and (par not in all_arg_pars) and (par not in all_kwarg_pars): + all_arg_pars.append(par) + for key, value in kwargs.items(): + if key in all_kwarg_pars and value != all_kwarg_pars[key]: + raise errors.DiffEqError(f'We got two different default value of "{key}": ' + f'{all_kwarg_pars[key]} != {value}') + elif (key not in vars_in_eqs) and (key not in all_arg_pars): + all_kwarg_pars[key] = value + else: + raise errors.DiffEqError + + # # variable names provided + # if not isinstance(variables, (tuple, list)): + # raise errors.DiffEqError(f'"variables" must be a list/tuple of str, but we got {variables}') + # for v in variables: + # if not isinstance(v, str): + # raise errors.DiffEqError(f'"variables" must be a list/tuple of str, but we got {v} in "variables"') + # if len(vars_in_eqs) != len(variables): + # raise errors.DiffEqError(f'We detect {len(vars_in_eqs)} variables "{vars_in_eqs}" ' + # f'in the provided equations. 
However, the used provided ' + # f'"variables" have {len(variables)} variables ' + # f'"{variables}".') + # if len(set(vars_in_eqs) - set(variables)) != 0: + # raise errors.DiffEqError(f'We detect there are variable "{vars_in_eqs}" in the provided ' + # f'equations, while the user provided variables "{variables}" ' + # f'is not the same.') + + # finally + self.eqs = eqs + # self.variables = variables + self.arg_keys = vars_in_eqs + ['t'] + all_arg_pars + self.kwarg_keys = list(all_kwarg_pars.keys()) + self.kwargs = all_kwarg_pars + parameters = [inspect.Parameter(vp, inspect.Parameter.POSITIONAL_OR_KEYWORD) + for vp in self.arg_keys] + parameters.extend([inspect.Parameter(k, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=all_kwarg_pars[k]) + for k in self.kwarg_keys]) + signature = inspect.signature(eqs[0]) + self.__signature__ = signature.replace(parameters=parameters) + self.__name__ = 'joint_eq' + + def __call__(self, *args, **kwargs): + # format arguments + params_in = Collector() + for i, arg in enumerate(args): + if i < len(self.arg_keys): + params_in[self.arg_keys[i]] = arg + else: + params_in[self.kwarg_keys[i - len(self.arg_keys)]] = arg + params_in.update(kwargs) + + # call equations + results = [] + for i, eq in enumerate(self.eqs): + r = eq(**{arg: params_in[arg] for arg in self.args_in_eqs[i]}) + if isinstance(r, (list, tuple)): + results.extend(list(r)) + else: + results.append(r) + return results diff --git a/brainpy/_src/integrators/ode/__init__.py b/brainpy/_src/integrators/ode/__init__.py index eafed5d35..da4c5cef2 100644 --- a/brainpy/_src/integrators/ode/__init__.py +++ b/brainpy/_src/integrators/ode/__init__.py @@ -4,8 +4,8 @@ Numerical methods for ordinary differential equations (ODEs). 
""" -from .base import * -from .generic import * from .adaptive_rk import * +from .base import * from .explicit_rk import * from .exponential import * +from .generic import * diff --git a/brainpy/_src/integrators/ode/adaptive_rk.py b/brainpy/_src/integrators/ode/adaptive_rk.py index a864cf8ee..a24f7be29 100644 --- a/brainpy/_src/integrators/ode/adaptive_rk.py +++ b/brainpy/_src/integrators/ode/adaptive_rk.py @@ -54,470 +54,471 @@ """ import jax.numpy as jnp + from brainpy import errors -from brainpy._src.integrators.ode.generic import register_ode_integrator from brainpy._src.integrators import constants as C, utils from brainpy._src.integrators.ode import common from brainpy._src.integrators.ode.base import ODEIntegrator +from brainpy._src.integrators.ode.generic import register_ode_integrator __all__ = [ - 'AdaptiveRKIntegrator', - 'RKF12', - 'RKF45', - 'DormandPrince', - 'CashKarp', - 'BogackiShampine', - 'HeunEuler', + 'AdaptiveRKIntegrator', + 'RKF12', + 'RKF45', + 'DormandPrince', + 'CashKarp', + 'BogackiShampine', + 'HeunEuler', ] class AdaptiveRKIntegrator(ODEIntegrator): - r"""Adaptive Runge-Kutta method for ordinary differential equations. - - The embedded methods are designed to produce an estimate of the local - truncation error of a single Runge-Kutta step, and as result, allow to - control the error with adaptive step-size. This is done by having two - methods in the tableau, one with order p and one with order :math:`p-1`. - - The lower-order step is given by - - .. math:: - - y^*_{n+1} = y_n + h\sum_{i=1}^s b^*_i k_i, - - where the :math:`k_{i}` are the same as for the higher order method. Then the error is - - .. math:: - - e_{n+1} = y_{n+1} - y^*_{n+1} = h\sum_{i=1}^s (b_i - b^*_i) k_i, - - - which is :math:`O(h^{p})`. The Butcher Tableau for this kind of method is extended to - give the values of :math:`b_{i}^{*}` - - .. 
math:: - - \begin{array}{c|cccc} - c_1 & a_{11} & a_{12}& \dots & a_{1s}\\ - c_2 & a_{21} & a_{22}& \dots & a_{2s}\\ - \vdots & \vdots & \vdots& \ddots& \vdots\\ - c_s & a_{s1} & a_{s2}& \dots & a_{ss} \\ - \hline & b_1 & b_2 & \dots & b_s\\ - & b_1^* & b_2^* & \dots & b_s^*\\ - \end{array} - - Parameters:: - - f : callable - The derivative function. - show_code : bool - Whether show the formatted code. - dt : float - The numerical precision. - adaptive : bool - Whether use the adaptive updating. - tol : float - The error tolerence. - var_type : str - The variable type. - """ - - A = [] # The A matrix in the Butcher tableau. - B1 = [] # The B1 vector in the Butcher tableau. - B2 = [] # The B2 vector in the Butcher tableau. - C = [] # The C vector in the Butcher tableau. - - def __init__(self, - f, - var_type=None, - dt=None, - name=None, - adaptive=None, - tol=None, - show_code=False, - state_delays=None, - neutral_delays=None): - super(AdaptiveRKIntegrator, self).__init__(f=f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays) - - # check parameters - self.adaptive = False if (adaptive is None) else adaptive - self.tol = 0.1 if tol is None else tol - self.var_type = C.POP_VAR if var_type is None else var_type - if self.var_type not in C.SUPPORTED_VAR_TYPE: - raise errors.IntegratorError(f'"var_type" only supports {C.SUPPORTED_VAR_TYPE}, ' - f'not {self.var_type}.') - - # integrator keywords - keywords = { - C.F: 'the derivative function', - # C.DT: 'the precision of numerical integration' - } - for v in self.variables: - keywords[f'{v}_new'] = 'the intermediate value' - for i in range(1, len(self.A) + 1): - keywords[f'd{v}_k{i}'] = 'the intermediate value' - for i in range(2, len(self.A) + 1): - keywords[f'k{i}_{v}_arg'] = 'the intermediate value' - keywords[f'k{i}_t_arg'] = 'the intermediate value' - if adaptive: - keywords['dt_new'] = 'the new numerical precision "dt"' - keywords['tol'] 
= 'the tolerance for the local truncation error' - keywords['error'] = 'the local truncation error' - for v in self.variables: - keywords[f'{v}_te'] = 'the local truncation error' - self.code_scope['tol'] = tol - self.code_scope['math'] = jnp - utils.check_kws(self.arg_names, keywords) - - # build the integrator - self.build() - - def build(self): - # step stage - common.step(self.variables, C.DT, - self.A, self.C, self.code_lines, self.parameters) - # variable update - return_args = common.update(self.variables, C.DT, self.B1, self.code_lines) - # error adaptive item - if self.adaptive: - errors_ = [] - for v in self.variables: - result = [] - for i, (b1, b2) in enumerate(zip(self.B1, self.B2)): - if isinstance(b1, str): - b1 = eval(b1) - if isinstance(b2, str): - b2 = eval(b2) - diff = b1 - b2 - if diff != 0.: - result.append(f'd{v}_k{i + 1} * {C.DT} * {diff}') - if len(result) > 0: - if self.var_type == C.SCALAR_VAR: - self.code_lines.append(f' {v}_te = abs({" + ".join(result)})') - else: - self.code_lines.append(f' {v}_te = sum(abs({" + ".join(result)}))') - errors_.append(f'{v}_te') - if len(errors_) > 0: - self.code_lines.append(f' error = {" + ".join(errors_)}') - self.code_lines.append(f' {C.DT}_new = math.where(error > tol, 0.9*{C.DT}*(tol/error)**0.2, {C.DT})') - return_args.append(f'{C.DT}_new') - # returns - self.code_lines.append(f' return {", ".join(return_args)}') - # compile - self.integral = utils.compile_code( - code_scope={k: v for k, v in self.code_scope.items()}, - code_lines=self.code_lines, - show_code=self.show_code, - func_name=self.func_name) + r"""Adaptive Runge-Kutta method for ordinary differential equations. + + The embedded methods are designed to produce an estimate of the local + truncation error of a single Runge-Kutta step, and as result, allow to + control the error with adaptive step-size. This is done by having two + methods in the tableau, one with order p and one with order :math:`p-1`. 
+ + The lower-order step is given by + + .. math:: + + y^*_{n+1} = y_n + h\sum_{i=1}^s b^*_i k_i, + + where the :math:`k_{i}` are the same as for the higher order method. Then the error is + + .. math:: + + e_{n+1} = y_{n+1} - y^*_{n+1} = h\sum_{i=1}^s (b_i - b^*_i) k_i, + + + which is :math:`O(h^{p})`. The Butcher Tableau for this kind of method is extended to + give the values of :math:`b_{i}^{*}` + + .. math:: + + \begin{array}{c|cccc} + c_1 & a_{11} & a_{12}& \dots & a_{1s}\\ + c_2 & a_{21} & a_{22}& \dots & a_{2s}\\ + \vdots & \vdots & \vdots& \ddots& \vdots\\ + c_s & a_{s1} & a_{s2}& \dots & a_{ss} \\ + \hline & b_1 & b_2 & \dots & b_s\\ + & b_1^* & b_2^* & \dots & b_s^*\\ + \end{array} + + Parameters:: + + f : callable + The derivative function. + show_code : bool + Whether show the formatted code. + dt : float + The numerical precision. + adaptive : bool + Whether use the adaptive updating. + tol : float + The error tolerence. + var_type : str + The variable type. + """ + + A = [] # The A matrix in the Butcher tableau. + B1 = [] # The B1 vector in the Butcher tableau. + B2 = [] # The B2 vector in the Butcher tableau. + C = [] # The C vector in the Butcher tableau. 
+ + def __init__(self, + f, + var_type=None, + dt=None, + name=None, + adaptive=None, + tol=None, + show_code=False, + state_delays=None, + neutral_delays=None): + super(AdaptiveRKIntegrator, self).__init__(f=f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays) + + # check parameters + self.adaptive = False if (adaptive is None) else adaptive + self.tol = 0.1 if tol is None else tol + self.var_type = C.POP_VAR if var_type is None else var_type + if self.var_type not in C.SUPPORTED_VAR_TYPE: + raise errors.IntegratorError(f'"var_type" only supports {C.SUPPORTED_VAR_TYPE}, ' + f'not {self.var_type}.') + + # integrator keywords + keywords = { + C.F: 'the derivative function', + # C.DT: 'the precision of numerical integration' + } + for v in self.variables: + keywords[f'{v}_new'] = 'the intermediate value' + for i in range(1, len(self.A) + 1): + keywords[f'd{v}_k{i}'] = 'the intermediate value' + for i in range(2, len(self.A) + 1): + keywords[f'k{i}_{v}_arg'] = 'the intermediate value' + keywords[f'k{i}_t_arg'] = 'the intermediate value' + if adaptive: + keywords['dt_new'] = 'the new numerical precision "dt"' + keywords['tol'] = 'the tolerance for the local truncation error' + keywords['error'] = 'the local truncation error' + for v in self.variables: + keywords[f'{v}_te'] = 'the local truncation error' + self.code_scope['tol'] = tol + self.code_scope['math'] = jnp + utils.check_kws(self.arg_names, keywords) + + # build the integrator + self.build() + + def build(self): + # step stage + common.step(self.variables, C.DT, + self.A, self.C, self.code_lines, self.parameters) + # variable update + return_args = common.update(self.variables, C.DT, self.B1, self.code_lines) + # error adaptive item + if self.adaptive: + errors_ = [] + for v in self.variables: + result = [] + for i, (b1, b2) in enumerate(zip(self.B1, self.B2)): + if isinstance(b1, str): + b1 = eval(b1) + if isinstance(b2, str): + b2 = 
eval(b2) + diff = b1 - b2 + if diff != 0.: + result.append(f'd{v}_k{i + 1} * {C.DT} * {diff}') + if len(result) > 0: + if self.var_type == C.SCALAR_VAR: + self.code_lines.append(f' {v}_te = abs({" + ".join(result)})') + else: + self.code_lines.append(f' {v}_te = sum(abs({" + ".join(result)}))') + errors_.append(f'{v}_te') + if len(errors_) > 0: + self.code_lines.append(f' error = {" + ".join(errors_)}') + self.code_lines.append(f' {C.DT}_new = math.where(error > tol, 0.9*{C.DT}*(tol/error)**0.2, {C.DT})') + return_args.append(f'{C.DT}_new') + # returns + self.code_lines.append(f' return {", ".join(return_args)}') + # compile + self.integral = utils.compile_code( + code_scope={k: v for k, v in self.code_scope.items()}, + code_lines=self.code_lines, + show_code=self.show_code, + func_name=self.func_name) class RKF12(AdaptiveRKIntegrator): - r"""The Fehlberg RK1(2) method for ODEs. + r"""The Fehlberg RK1(2) method for ODEs. - The Fehlberg method has two methods of orders 1 and 2. + The Fehlberg method has two methods of orders 1 and 2. - It has the characteristics of: + It has the characteristics of: - - method stage = 2 - - method order = 1 - - Butcher Tables: + - method stage = 2 + - method order = 1 + - Butcher Tables: - .. math:: + .. math:: - \begin{array}{l|ll} - 0 & & \\ - 1 / 2 & 1 / 2 & \\ - 1 & 1 / 256 & 255 / 256 & \\ - \hline & 1 / 512 & 255 / 256 & 1 / 512 \\ - & 1 / 256 & 255 / 256 & 0 - \end{array} + \begin{array}{l|ll} + 0 & & \\ + 1 / 2 & 1 / 2 & \\ + 1 & 1 / 256 & 255 / 256 & \\ + \hline & 1 / 512 & 255 / 256 & 1 / 512 \\ + & 1 / 256 & 255 / 256 & 0 + \end{array} - References:: + References:: - .. [1] Fehlberg, E. (1969-07-01). "Low-order classical Runge-Kutta - formulas with stepsize control and their application to some heat - transfer problems" + .. [1] Fehlberg, E. (1969-07-01). 
"Low-order classical Runge-Kutta + formulas with stepsize control and their application to some heat + transfer problems" - """ + """ - A = [(), - (0.5,), - ('1/256', '255/256')] - B1 = ['1/512', '255/256', '1/512'] - B2 = ['1/256', '255/256', 0] - C = [0, 0.5, 1] + A = [(), + (0.5,), + ('1/256', '255/256')] + B1 = ['1/512', '255/256', '1/512'] + B2 = ['1/256', '255/256', 0] + C = [0, 0.5, 1] register_ode_integrator('rkf12', RKF12) class RKF45(AdaptiveRKIntegrator): - r"""The Runge–Kutta–Fehlberg method for ODEs. - - The method presented in Fehlberg's 1969 paper has been dubbed the - RKF45 method, and is a method of order :math:`O(h^4)` with an error - estimator of order :math:`O(h^5)`. The novelty of Fehlberg's method is - that it is an embedded method from the Runge–Kutta family, meaning that - identical function evaluations are used in conjunction with each other - to create methods of varying order and similar error constants. - - Its Butcher table is: - - .. math:: - - \begin{array}{l|lllll} - 0 & & & & & & \\ - 1 / 4 & 1 / 4 & & & & \\ - 3 / 8 & 3 / 32 & 9 / 32 & & \\ - 12 / 13 & 1932 / 2197 & -7200 / 2197 & 7296 / 2197 & \\ - 1 & 439 / 216 & -8 & 3680 / 513 & -845 / 4104 & & \\ - 1 / 2 & -8 / 27 & 2 & -3544 / 2565 & 1859 / 4104 & -11 / 40 & \\ - \hline & 16 / 135 & 0 & 6656 / 12825 & 28561 / 56430 & -9 / 50 & 2 / 55 \\ - & 25 / 216 & 0 & 1408 / 2565 & 2197 / 4104 & -1 / 5 & 0 - \end{array} - - References:: - - .. [1] https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method - .. [2] Erwin Fehlberg (1969). Low-order classical Runge-Kutta formulas with step - size control and their application to some heat transfer problems . NASA Technical Report 315. 
- https://ntrs.nasa.gov/api/citations/19690021375/downloads/19690021375.pdf - - """ - - A = [(), - (0.25,), - (0.09375, 0.28125), - ('1932/2197', '-7200/2197', '7296/2197'), - ('439/216', -8, '3680/513', '-845/4104'), - ('-8/27', 2, '-3544/2565', '1859/4104', -0.275)] - B1 = ['16/135', 0, '6656/12825', '28561/56430', -0.18, '2/55'] - B2 = ['25/216', 0, '1408/2565', '2197/4104', -0.2, 0] - C = [0, 0.25, 0.375, '12/13', 1, '1/3'] + r"""The Runge–Kutta–Fehlberg method for ODEs. + + The method presented in Fehlberg's 1969 paper has been dubbed the + RKF45 method, and is a method of order :math:`O(h^4)` with an error + estimator of order :math:`O(h^5)`. The novelty of Fehlberg's method is + that it is an embedded method from the Runge–Kutta family, meaning that + identical function evaluations are used in conjunction with each other + to create methods of varying order and similar error constants. + + Its Butcher table is: + + .. math:: + + \begin{array}{l|lllll} + 0 & & & & & & \\ + 1 / 4 & 1 / 4 & & & & \\ + 3 / 8 & 3 / 32 & 9 / 32 & & \\ + 12 / 13 & 1932 / 2197 & -7200 / 2197 & 7296 / 2197 & \\ + 1 & 439 / 216 & -8 & 3680 / 513 & -845 / 4104 & & \\ + 1 / 2 & -8 / 27 & 2 & -3544 / 2565 & 1859 / 4104 & -11 / 40 & \\ + \hline & 16 / 135 & 0 & 6656 / 12825 & 28561 / 56430 & -9 / 50 & 2 / 55 \\ + & 25 / 216 & 0 & 1408 / 2565 & 2197 / 4104 & -1 / 5 & 0 + \end{array} + + References:: + + .. [1] https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method + .. [2] Erwin Fehlberg (1969). Low-order classical Runge-Kutta formulas with step + size control and their application to some heat transfer problems . NASA Technical Report 315. 
+ https://ntrs.nasa.gov/api/citations/19690021375/downloads/19690021375.pdf + + """ + + A = [(), + (0.25,), + (0.09375, 0.28125), + ('1932/2197', '-7200/2197', '7296/2197'), + ('439/216', -8, '3680/513', '-845/4104'), + ('-8/27', 2, '-3544/2565', '1859/4104', -0.275)] + B1 = ['16/135', 0, '6656/12825', '28561/56430', -0.18, '2/55'] + B2 = ['25/216', 0, '1408/2565', '2197/4104', -0.2, 0] + C = [0, 0.25, 0.375, '12/13', 1, '1/3'] register_ode_integrator('rkf45', RKF45) class DormandPrince(AdaptiveRKIntegrator): - r"""The Dormand–Prince method for ODEs. - - The DOPRI method, is an explicit method for solving ordinary differential equations - (Dormand & Prince 1980). The Dormand–Prince method has seven stages, but it uses only - six function evaluations per step because it has the FSAL (First Same As Last) property: - the last stage is evaluated at the same point as the first stage of the next step. - Dormand and Prince chose the coefficients of their method to minimize the error of - the fifth-order solution. This is the main difference with the Fehlberg method, which - was constructed so that the fourth-order solution has a small error. For this reason, - the Dormand–Prince method is more suitable when the higher-order solution is used to - continue the integration, a practice known as local extrapolation - (Shampine 1986; Hairer, Nørsett & Wanner 2008, pp. 178–179). - - Its Butcher table is: - - .. 
math:: - - \begin{array}{l|llllll} - 0 & \\ - 1 / 5 & 1 / 5 & & & \\ - 3 / 10 & 3 / 40 & 9 / 40 & & & \\ - 4 / 5 & 44 / 45 & -56 / 15 & 32 / 9 & & \\ - 8 / 9 & 19372 / 6561 & -25360 / 2187 & 64448 / 6561 & -212 / 729 & \\ - 1 & 9017 / 3168 & -355 / 33 & 46732 / 5247 & 49 / 176 & -5103 / 18656 & \\ - 1 & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & \\ - \hline & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & 0 \\ - & 5179 / 57600 & 0 & 7571 / 16695 & 393 / 640 & -92097 / 339200 & 187 / 2100 & 1 / 40 - \end{array} - - References:: - - .. [1] https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method - .. [2] Dormand, J. R.; Prince, P. J. (1980), "A family of embedded Runge-Kutta formulae", - Journal of Computational and Applied Mathematics, 6 (1): 19–26, - doi:10.1016/0771-050X(80)90013-3. - """ - - A = [(), - (0.2,), - (0.075, 0.225), - ('44/45', '-56/15', '32/9'), - ('19372/6561', '-25360/2187', '64448/6561', '-212/729'), - ('9017/3168', '-355/33', '46732/5247', '49/176', '-5103/18656'), - ('35/384', 0, '500/1113', '125/192', '-2187/6784', '11/84')] - B1 = ['35/384', 0, '500/1113', '125/192', '-2187/6784', '11/84', 0] - B2 = ['5179/57600', 0, '7571/16695', '393/640', '-92097/339200', '187/2100', 0.025] - C = [0, 0.2, 0.3, 0.8, '8/9', 1, 1] + r"""The Dormand–Prince method for ODEs. + + The DOPRI method, is an explicit method for solving ordinary differential equations + (Dormand & Prince 1980). The Dormand–Prince method has seven stages, but it uses only + six function evaluations per step because it has the FSAL (First Same As Last) property: + the last stage is evaluated at the same point as the first stage of the next step. + Dormand and Prince chose the coefficients of their method to minimize the error of + the fifth-order solution. This is the main difference with the Fehlberg method, which + was constructed so that the fourth-order solution has a small error. 
For this reason, + the Dormand–Prince method is more suitable when the higher-order solution is used to + continue the integration, a practice known as local extrapolation + (Shampine 1986; Hairer, Nørsett & Wanner 2008, pp. 178–179). + + Its Butcher table is: + + .. math:: + + \begin{array}{l|llllll} + 0 & \\ + 1 / 5 & 1 / 5 & & & \\ + 3 / 10 & 3 / 40 & 9 / 40 & & & \\ + 4 / 5 & 44 / 45 & -56 / 15 & 32 / 9 & & \\ + 8 / 9 & 19372 / 6561 & -25360 / 2187 & 64448 / 6561 & -212 / 729 & \\ + 1 & 9017 / 3168 & -355 / 33 & 46732 / 5247 & 49 / 176 & -5103 / 18656 & \\ + 1 & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & \\ + \hline & 35 / 384 & 0 & 500 / 1113 & 125 / 192 & -2187 / 6784 & 11 / 84 & 0 \\ + & 5179 / 57600 & 0 & 7571 / 16695 & 393 / 640 & -92097 / 339200 & 187 / 2100 & 1 / 40 + \end{array} + + References:: + + .. [1] https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method + .. [2] Dormand, J. R.; Prince, P. J. (1980), "A family of embedded Runge-Kutta formulae", + Journal of Computational and Applied Mathematics, 6 (1): 19–26, + doi:10.1016/0771-050X(80)90013-3. + """ + + A = [(), + (0.2,), + (0.075, 0.225), + ('44/45', '-56/15', '32/9'), + ('19372/6561', '-25360/2187', '64448/6561', '-212/729'), + ('9017/3168', '-355/33', '46732/5247', '49/176', '-5103/18656'), + ('35/384', 0, '500/1113', '125/192', '-2187/6784', '11/84')] + B1 = ['35/384', 0, '500/1113', '125/192', '-2187/6784', '11/84', 0] + B2 = ['5179/57600', 0, '7571/16695', '393/640', '-92097/339200', '187/2100', 0.025] + C = [0, 0.2, 0.3, 0.8, '8/9', 1, 1] register_ode_integrator('rkdp', DormandPrince) class CashKarp(AdaptiveRKIntegrator): - r"""The Cash–Karp method for ODEs. - - The Cash–Karp method was proposed by Professor Jeff R. Cash from Imperial College London - and Alan H. Karp from IBM Scientific Center. it uses six function evaluations to calculate - fourth- and fifth-order accurate solutions. 
The difference between these solutions is then - taken to be the error of the (fourth order) solution. This error estimate is very convenient - for adaptive stepsize integration algorithms. - - It has the characteristics of: - - - method stage = 6 - - method order = 4 - - Butcher Tables: - - .. math:: - - \begin{array}{l|lllll} - 0 & & & & & & \\ - 1 / 5 & 1 / 5 & & & & & \\ - 3 / 10 & 3 / 40 & 9 / 40 & & & \\ - 3 / 5 & 3 / 10 & -9 / 10 & 6 / 5 & & \\ - 1 & -11 / 54 & 5 / 2 & -70 / 27 & 35 / 27 & & \\ - 7 / 8 & 1631 / 55296 & 175 / 512 & 575 / 13824 & 44275 / 110592 & 253 / 4096 & \\ - \hline & 37 / 378 & 0 & 250 / 621 & 125 / 594 & 0 & 512 / 1771 \\ - & 2825 / 27648 & 0 & 18575 / 48384 & 13525 / 55296 & 277 / 14336 & 1 / 4 - \end{array} - - References:: - - .. [1] https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method - .. [2] J. R. Cash, A. H. Karp. "A variable order Runge-Kutta method for initial value - problems with rapidly varying right-hand sides", ACM Transactions on Mathematical - Software 16: 201-222, 1990. doi:10.1145/79505.79507 - """ - - A = [(), - (0.2,), - (0.075, 0.225), - (0.3, -0.9, 1.2), - ('-11/54', 2.5, '-70/27', '35/27'), - ('1631/55296', '175/512', '575/13824', '44275/110592', '253/4096')] - B1 = ['37/378', 0, '250/621', '125/594', 0, '512/1771'] - B2 = ['2825/27648', 0, '18575/48384', '13525/55296', '277/14336', 0.25] - C = [0, 0.2, 0.3, 0.6, 1, 0.875] + r"""The Cash–Karp method for ODEs. + + The Cash–Karp method was proposed by Professor Jeff R. Cash from Imperial College London + and Alan H. Karp from IBM Scientific Center. it uses six function evaluations to calculate + fourth- and fifth-order accurate solutions. The difference between these solutions is then + taken to be the error of the (fourth order) solution. This error estimate is very convenient + for adaptive stepsize integration algorithms. + + It has the characteristics of: + + - method stage = 6 + - method order = 4 + - Butcher Tables: + + .. 
math:: + + \begin{array}{l|lllll} + 0 & & & & & & \\ + 1 / 5 & 1 / 5 & & & & & \\ + 3 / 10 & 3 / 40 & 9 / 40 & & & \\ + 3 / 5 & 3 / 10 & -9 / 10 & 6 / 5 & & \\ + 1 & -11 / 54 & 5 / 2 & -70 / 27 & 35 / 27 & & \\ + 7 / 8 & 1631 / 55296 & 175 / 512 & 575 / 13824 & 44275 / 110592 & 253 / 4096 & \\ + \hline & 37 / 378 & 0 & 250 / 621 & 125 / 594 & 0 & 512 / 1771 \\ + & 2825 / 27648 & 0 & 18575 / 48384 & 13525 / 55296 & 277 / 14336 & 1 / 4 + \end{array} + + References:: + + .. [1] https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method + .. [2] J. R. Cash, A. H. Karp. "A variable order Runge-Kutta method for initial value + problems with rapidly varying right-hand sides", ACM Transactions on Mathematical + Software 16: 201-222, 1990. doi:10.1145/79505.79507 + """ + + A = [(), + (0.2,), + (0.075, 0.225), + (0.3, -0.9, 1.2), + ('-11/54', 2.5, '-70/27', '35/27'), + ('1631/55296', '175/512', '575/13824', '44275/110592', '253/4096')] + B1 = ['37/378', 0, '250/621', '125/594', 0, '512/1771'] + B2 = ['2825/27648', 0, '18575/48384', '13525/55296', '277/14336', 0.25] + C = [0, 0.2, 0.3, 0.6, 1, 0.875] register_ode_integrator('ck', CashKarp) class BogackiShampine(AdaptiveRKIntegrator): - r"""The Bogacki–Shampine method for ODEs. + r"""The Bogacki–Shampine method for ODEs. - The Bogacki–Shampine method was proposed by Przemysław Bogacki and Lawrence F. - Shampine in 1989 (Bogacki & Shampine 1989). The Bogacki–Shampine method is a - Runge–Kutta method of order three with four stages with the First Same As Last - (FSAL) property, so that it uses approximately three function evaluations per - step. It has an embedded second-order method which can be used to implement adaptive step size. + The Bogacki–Shampine method was proposed by Przemysław Bogacki and Lawrence F. + Shampine in 1989 (Bogacki & Shampine 1989). 
The Bogacki–Shampine method is a + Runge–Kutta method of order three with four stages with the First Same As Last + (FSAL) property, so that it uses approximately three function evaluations per + step. It has an embedded second-order method which can be used to implement adaptive step size. - It has the characteristics of: + It has the characteristics of: - - method stage = 4 - - method order = 3 - - Butcher Tables: + - method stage = 4 + - method order = 3 + - Butcher Tables: - .. math:: + .. math:: - \begin{array}{l|lll} - 0 & & & \\ - 1 / 2 & 1 / 2 & & \\ - 3 / 4 & 0 & 3 / 4 & \\ - 1 & 2 / 9 & 1 / 3 & 4 / 9 \\ - \hline & 2 / 9 & 1 / 3 & 4 / 90 \\ - & 7 / 24 & 1 / 4 & 1 / 3 & 1 / 8 - \end{array} + \begin{array}{l|lll} + 0 & & & \\ + 1 / 2 & 1 / 2 & & \\ + 3 / 4 & 0 & 3 / 4 & \\ + 1 & 2 / 9 & 1 / 3 & 4 / 9 \\ + \hline & 2 / 9 & 1 / 3 & 4 / 90 \\ + & 7 / 24 & 1 / 4 & 1 / 3 & 1 / 8 + \end{array} - References:: + References:: - .. [1] https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method - .. [2] Bogacki, Przemysław; Shampine, Lawrence F. (1989), "A 3(2) pair of Runge–Kutta - formulas", Applied Mathematics Letters, 2 (4): 321–325, doi:10.1016/0893-9659(89)90079-7 - """ + .. [1] https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method + .. [2] Bogacki, Przemysław; Shampine, Lawrence F. (1989), "A 3(2) pair of Runge–Kutta + formulas", Applied Mathematics Letters, 2 (4): 321–325, doi:10.1016/0893-9659(89)90079-7 + """ - A = [(), - (0.5,), - (0., 0.75), - ('2/9', '1/3', '4/0'), ] - B1 = ['2/9', '1/3', '4/9', 0] - B2 = ['7/24', 0.25, '1/3', 0.125] - C = [0, 0.5, 0.75, 1] + A = [(), + (0.5,), + (0., 0.75), + ('2/9', '1/3', '4/0'), ] + B1 = ['2/9', '1/3', '4/9', 0] + B2 = ['7/24', 0.25, '1/3', 0.125] + C = [0, 0.5, 0.75, 1] register_ode_integrator('bs', BogackiShampine) class HeunEuler(AdaptiveRKIntegrator): - r"""The Heun–Euler method for ODEs. + r"""The Heun–Euler method for ODEs. 
- The simplest adaptive Runge–Kutta method involves combining Heun's method, - which is order 2, with the Euler method, which is order 1. + The simplest adaptive Runge–Kutta method involves combining Heun's method, + which is order 2, with the Euler method, which is order 1. - It has the characteristics of: + It has the characteristics of: - - method stage = 2 - - method order = 1 - - Butcher Tables: + - method stage = 2 + - method order = 1 + - Butcher Tables: - .. math:: + .. math:: - \begin{array}{c|cc} - 0&\\ - 1& 1 \\ - \hline - & 1/2& 1/2\\ - & 1 & 0 - \end{array} + \begin{array}{c|cc} + 0&\\ + 1& 1 \\ + \hline + & 1/2& 1/2\\ + & 1 & 0 + \end{array} - """ + """ - A = [(), (1,)] - B1 = [0.5, 0.5] - B2 = [1, 0] - C = [0, 1] + A = [(), (1,)] + B1 = [0.5, 0.5] + B2 = [1, 0] + C = [0, 1] register_ode_integrator('heun_euler', HeunEuler) class DOP853(AdaptiveRKIntegrator): - # def DOP853(f=None, tol=None, adaptive=None, dt=None, show_code=None, each_var_is_scalar=None): - r"""The DOP853 method for ODEs. + # def DOP853(f=None, tol=None, adaptive=None, dt=None, show_code=None, each_var_is_scalar=None): + r"""The DOP853 method for ODEs. - DOP853 is an explicit Runge-Kutta method of order 8(5,3) due to Dormand & Prince - (with stepsize control and dense output). + DOP853 is an explicit Runge-Kutta method of order 8(5,3) due to Dormand & Prince + (with stepsize control and dense output). - References:: + References:: - .. [1] E. Hairer, S.P. Norsett and G. Wanner, "Solving ordinary Differential Equations - I. Nonstiff Problems", 2nd edition. Springer Series in Computational Mathematics, - Springer-Verlag (1993). - .. [2] http://www.unige.ch/~hairer/software.html - """ - pass + .. [1] E. Hairer, S.P. Norsett and G. Wanner, "Solving ordinary Differential Equations + I. Nonstiff Problems", 2nd edition. Springer Series in Computational Mathematics, + Springer-Verlag (1993). + .. 
[2] http://www.unige.ch/~hairer/software.html + """ + pass class BoSh3(AdaptiveRKIntegrator): - """ - Bogacki--Shampine's 3/2 method. - - 3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for - adaptive step sizing. - - """ - A = [(), - (0.5,), - (0.0, 0.75), - ('2/9', '1/3', '4/9')] - B1 = ['2/9', '1/3', '4/9', 0.0] - B2 = ['-5/72', 1 / 12, '1/9', '-1/8'] - C = [0., 0.5, 0.75, 1.0] + """ + Bogacki--Shampine's 3/2 method. + + 3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for + adaptive step sizing. + + """ + A = [(), + (0.5,), + (0.0, 0.75), + ('2/9', '1/3', '4/9')] + B1 = ['2/9', '1/3', '4/9', 0.0] + B2 = ['-5/72', 1 / 12, '1/9', '-1/8'] + C = [0., 0.5, 0.75, 1.0] register_ode_integrator('BoSh3', BoSh3) diff --git a/brainpy/_src/integrators/ode/base.py b/brainpy/_src/integrators/ode/base.py index 09494caab..02e923c36 100644 --- a/brainpy/_src/integrators/ode/base.py +++ b/brainpy/_src/integrators/ode/base.py @@ -3,148 +3,148 @@ from typing import Dict, Callable, Union -from brainpy.errors import DiffEqError, CodeError from brainpy._src import math as bm from brainpy._src.integrators import constants, utils from brainpy._src.integrators.base import Integrator from brainpy._src.integrators.constants import DT from brainpy.check import is_dict_data +from brainpy.errors import DiffEqError, CodeError __all__ = [ - 'ODEIntegrator', + 'ODEIntegrator', ] def f_names(f): - func_name = constants.unique_name('ode') - if f.__name__.isidentifier(): - func_name += '_' + f.__name__ - return func_name + func_name = constants.unique_name('ode') + if f.__name__.isidentifier(): + func_name += '_' + f.__name__ + return func_name class ODEIntegrator(Integrator): - """Numerical Integrator for Ordinary Differential Equations (ODEs). - - Parameters:: - - f : callable - The derivative function. - var_type: str - The type for each variable. - dt: float, int - The numerical precision. - name: str - The integrator name. 
- """ - - def __init__( - self, - f: Callable, - var_type: str = None, - dt: float = None, - name: str = None, - show_code: bool = False, - state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, - neutral_delays: Dict[str, Union[bm.NeuTimeDelay, bm.NeuLenDelay]] = None - ): - - dt = bm.get_dt() if dt is None else dt - parses = utils.get_args(f) - variables = parses[0] # variable names, (before 't') - parameters = parses[1] # parameter names, (after 't') - arguments = parses[2] # function arguments - - for p in tuple(variables) + tuple(parameters): - if p == DT: - raise CodeError(f'{DT} is a system keyword denotes the ' - f'precision of numerical integration. ' - f'It cannot be used as a variable or parameter, ' - f'please change an another name.') - - # super initialization - super(ODEIntegrator, self).__init__(name=name, - variables=variables, - parameters=parameters, - arguments=arguments, - dt=dt, - state_delays=state_delays) - - # others - self.show_code = show_code - self.var_type = var_type # variable type - - # derivative function - self.derivative = {constants.F: f} - self.f = f - - # code scope - self.code_scope = {constants.F: f} - - # code lines - self.func_name = f_names(f) - self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):'] - - # neutral delays - self._neutral_delays = dict() - if neutral_delays is not None: - is_dict_data(neutral_delays, key_type=str, val_type=bm.NeuTimeDelay) - for key, delay in neutral_delays.items(): - if key not in self.variables: - raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') - self._neutral_delays[key] = delay - self.register_implicit_nodes(self._neutral_delays) - - @property - def neutral_delays(self): - """neutral delays.""" - return self._neutral_delays - - @neutral_delays.setter - def neutral_delays(self, value): - raise ValueError('Cannot set "neutral_delays" by users.') - - def __call__(self, *args, **kwargs): - assert self.integral is not None, 
'Please build the integrator first.' - - # check arguments - for i, arg in enumerate(args): - kwargs[self.arg_names[i]] = arg - - # integral - new_vars = self._call_integral(**kwargs) - if len(self.variables) == 1: - dict_vars = {self.variables[0]: new_vars} - else: - dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} - - dt = kwargs.pop(DT, self.dt) - # update neutral delay variables - if len(self.neutral_delays): - kwargs.update(dict_vars) - new_devs = self.f(**kwargs) - if len(self.variables) == 1: - new_devs = {self.variables[0]: new_devs} - else: - new_devs = {k: new_devs[i] for i, k in enumerate(self.variables)} - for key, delay in self.neutral_delays.items(): - if isinstance(delay, bm.NeuLenDelay): - delay.update(new_devs[key]) - elif isinstance(delay, bm.NeuTimeDelay): - delay.update(kwargs['t'] + dt, new_devs[key]) + """Numerical Integrator for Ordinary Differential Equations (ODEs). + + Parameters:: + + f : callable + The derivative function. + var_type: str + The type for each variable. + dt: float, int + The numerical precision. + name: str + The integrator name. + """ + + def __init__( + self, + f: Callable, + var_type: str = None, + dt: float = None, + name: str = None, + show_code: bool = False, + state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None, + neutral_delays: Dict[str, Union[bm.NeuTimeDelay, bm.NeuLenDelay]] = None + ): + + dt = bm.get_dt() if dt is None else dt + parses = utils.get_args(f) + variables = parses[0] # variable names, (before 't') + parameters = parses[1] # parameter names, (after 't') + arguments = parses[2] # function arguments + + for p in tuple(variables) + tuple(parameters): + if p == DT: + raise CodeError(f'{DT} is a system keyword denotes the ' + f'precision of numerical integration. 
' + f'It cannot be used as a variable or parameter, ' + f'please change an another name.') + + # super initialization + super(ODEIntegrator, self).__init__(name=name, + variables=variables, + parameters=parameters, + arguments=arguments, + dt=dt, + state_delays=state_delays) + + # others + self.show_code = show_code + self.var_type = var_type # variable type + + # derivative function + self.derivative = {constants.F: f} + self.f = f + + # code scope + self.code_scope = {constants.F: f} + + # code lines + self.func_name = f_names(f) + self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):'] + + # neutral delays + self._neutral_delays = dict() + if neutral_delays is not None: + is_dict_data(neutral_delays, key_type=str, val_type=bm.NeuTimeDelay) + for key, delay in neutral_delays.items(): + if key not in self.variables: + raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') + self._neutral_delays[key] = delay + self.register_implicit_nodes(self._neutral_delays) + + @property + def neutral_delays(self): + """neutral delays.""" + return self._neutral_delays + + @neutral_delays.setter + def neutral_delays(self, value): + raise ValueError('Cannot set "neutral_delays" by users.') + + def __call__(self, *args, **kwargs): + assert self.integral is not None, 'Please build the integrator first.' + + # check arguments + for i, arg in enumerate(args): + kwargs[self.arg_names[i]] = arg + + # integral + new_vars = self._call_integral(**kwargs) + if len(self.variables) == 1: + dict_vars = {self.variables[0]: new_vars} else: - raise ValueError('Unknown delay variable. We only supports ' - f'{bm.NeuTimeDelay.__name__} and {bm.NeuLenDelay.__name__}. 
' - f'While we got {delay}') - - # update state delay variables - for key, delay in self.state_delays.items(): - if isinstance(delay, bm.LengthDelay): - delay.update(dict_vars[key]) - elif isinstance(delay, bm.TimeDelay): - delay.update(dict_vars[key]) - else: - raise ValueError('Unknown delay variable. We only supports ' - f'{bm.LengthDelay.__name__} and {bm.TimeDelay.__name__}. ' - f'While we got {delay}') - - return new_vars + dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} + + dt = kwargs.pop(DT, self.dt) + # update neutral delay variables + if len(self.neutral_delays): + kwargs.update(dict_vars) + new_devs = self.f(**kwargs) + if len(self.variables) == 1: + new_devs = {self.variables[0]: new_devs} + else: + new_devs = {k: new_devs[i] for i, k in enumerate(self.variables)} + for key, delay in self.neutral_delays.items(): + if isinstance(delay, bm.NeuLenDelay): + delay.update(new_devs[key]) + elif isinstance(delay, bm.NeuTimeDelay): + delay.update(kwargs['t'] + dt, new_devs[key]) + else: + raise ValueError('Unknown delay variable. We only supports ' + f'{bm.NeuTimeDelay.__name__} and {bm.NeuLenDelay.__name__}. ' + f'While we got {delay}') + + # update state delay variables + for key, delay in self.state_delays.items(): + if isinstance(delay, bm.LengthDelay): + delay.update(dict_vars[key]) + elif isinstance(delay, bm.TimeDelay): + delay.update(dict_vars[key]) + else: + raise ValueError('Unknown delay variable. We only supports ' + f'{bm.LengthDelay.__name__} and {bm.TimeDelay.__name__}. 
' + f'While we got {delay}') + + return new_vars diff --git a/brainpy/_src/integrators/ode/common.py b/brainpy/_src/integrators/ode/common.py index 5fa248fe8..449f23b32 100644 --- a/brainpy/_src/integrators/ode/common.py +++ b/brainpy/_src/integrators/ode/common.py @@ -2,53 +2,52 @@ def step(vars, dt_var, A, C, code_lines, other_args): - # steps - for si, sval in enumerate(A): - # k-step arguments - k_args = [] - for v in vars: - k_arg = f'{v}' - for j, sv in enumerate(sval): - if sv not in [0., '0.0', '0.', '0']: - if sv in ['1.0', '1.', '1', 1.]: - k_arg += f' + {dt_var} * d{v}_k{j + 1}' - else: - k_arg += f' + {dt_var} * d{v}_k{j + 1} * {sv}' - if k_arg != v: - name = f'k{si + 1}_{v}_arg' - code_lines.append(f' {name} = {k_arg}') - k_args.append(name) - else: - k_args.append(v) - - t_arg = 't' - if C[si] not in [0., '0.', '0']: - if C[si] in ['1.', '1', 1.]: - t_arg += f' + {dt_var}' - else: - t_arg += f' + {dt_var} * {C[si]}' - name = f'k{si + 1}_t_arg' - code_lines.append(f' {name} = {t_arg}') - k_args.append(name) - else: - k_args.append(t_arg) - - # k-step derivative names - k_derivatives = [f'd{v}_k{si + 1}' for v in vars] - - # k-step code line - code_lines.append(f' {", ".join(k_derivatives)} = f(' - f'{", ".join(k_args + other_args[1:])})') + # steps + for si, sval in enumerate(A): + # k-step arguments + k_args = [] + for v in vars: + k_arg = f'{v}' + for j, sv in enumerate(sval): + if sv not in [0., '0.0', '0.', '0']: + if sv in ['1.0', '1.', '1', 1.]: + k_arg += f' + {dt_var} * d{v}_k{j + 1}' + else: + k_arg += f' + {dt_var} * d{v}_k{j + 1} * {sv}' + if k_arg != v: + name = f'k{si + 1}_{v}_arg' + code_lines.append(f' {name} = {k_arg}') + k_args.append(name) + else: + k_args.append(v) + + t_arg = 't' + if C[si] not in [0., '0.', '0']: + if C[si] in ['1.', '1', 1.]: + t_arg += f' + {dt_var}' + else: + t_arg += f' + {dt_var} * {C[si]}' + name = f'k{si + 1}_t_arg' + code_lines.append(f' {name} = {t_arg}') + k_args.append(name) + else: + 
k_args.append(t_arg) + + # k-step derivative names + k_derivatives = [f'd{v}_k{si + 1}' for v in vars] + + # k-step code line + code_lines.append(f' {", ".join(k_derivatives)} = f(' + f'{", ".join(k_args + other_args[1:])})') def update(vars, dt_var, B, code_lines): - return_args = [] - for v in vars: - result = v - for i, b1 in enumerate(B): - if b1 not in [0., '0.', '0']: - result += f' + d{v}_k{i + 1} * {dt_var} * {b1}' - code_lines.append(f' {v}_new = {result}') - return_args.append(f'{v}_new') - return return_args - + return_args = [] + for v in vars: + result = v + for i, b1 in enumerate(B): + if b1 not in [0., '0.', '0']: + result += f' + d{v}_k{i + 1} * {dt_var} * {b1}' + code_lines.append(f' {v}_new = {result}') + return_args.append(f'{v}_new') + return return_args diff --git a/brainpy/_src/integrators/ode/explicit_rk.py b/brainpy/_src/integrators/ode/explicit_rk.py index 6088e1f08..44f5aeafa 100644 --- a/brainpy/_src/integrators/ode/explicit_rk.py +++ b/brainpy/_src/integrators/ode/explicit_rk.py @@ -73,798 +73,798 @@ from .generic import register_ode_integrator __all__ = [ - 'ExplicitRKIntegrator', - 'Euler', - 'MidPoint', - 'Heun2', - 'Ralston2', - 'RK2', - 'RK3', - 'Heun3', - 'Ralston3', - 'SSPRK3', - 'RK4', - 'Ralston4', - 'RK4Rule38', + 'ExplicitRKIntegrator', + 'Euler', + 'MidPoint', + 'Heun2', + 'Ralston2', + 'RK2', + 'RK3', + 'Heun3', + 'Ralston3', + 'SSPRK3', + 'RK4', + 'Ralston4', + 'RK4Rule38', ] class ExplicitRKIntegrator(ODEIntegrator): - r"""Explicit Runge–Kutta methods for ordinary differential equation. - - For the system, - - .. math:: - - \frac{d y}{d t}=f(t, y) - - - Explicit Runge-Kutta methods take the form - - .. math:: - - k_{i}=f\left(t_{n}+c_{i}h,y_{n}+h\sum _{j=1}^{s}a_{ij}k_{j}\right) \\ - y_{n+1}=y_{n}+h \sum_{i=1}^{s} b_{i} k_{i} - - Each method listed on this page is defined by its Butcher tableau, - which puts the coefficients of the method in a table as follows: - - .. 
math:: - - \begin{array}{c|cccc} - c_{1} & a_{11} & a_{12} & \ldots & a_{1 s} \\ - c_{2} & a_{21} & a_{22} & \ldots & a_{2 s} \\ - \vdots & \vdots & \vdots & \ddots & \vdots \\ - c_{s} & a_{s 1} & a_{s 2} & \ldots & a_{s s} \\ - \hline & b_{1} & b_{2} & \ldots & b_{s} - \end{array} - - Parameters:: - - f : callable - The derivative function. - show_code : bool - Whether show the formatted code. - dt : float - The numerical precision. - """ - A = [] # The A matrix in the Butcher tableau. - B = [] # The B vector in the Butcher tableau. - C = [] # The C vector in the Butcher tableau. - - def __init__(self, - f, - var_type=None, - dt=None, - name=None, - show_code=False, - state_delays=None, - neutral_delays=None): - super().__init__(f=f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays) - - # integrator keywords - keywords = { - C.F: 'the derivative function', - # C.DT: 'the precision of numerical integration' - } - for v in self.variables: - keywords[f'{v}_new'] = 'the intermediate value' - for i in range(1, len(self.A) + 1): - keywords[f'd{v}_k{i}'] = 'the intermediate value' - for i in range(2, len(self.A) + 1): - keywords[f'k{i}_{v}_arg'] = 'the intermediate value' - keywords[f'k{i}_t_arg'] = 'the intermediate value' - utils.check_kws(self.arg_names, keywords) - self.build() - - def build(self): - # step stage - common.step(self.variables, C.DT, - self.A, self.C, self.code_lines, self.parameters) - # variable update - return_args = common.update(self.variables, C.DT, self.B, self.code_lines) - # returns - self.code_lines.append(f' return {", ".join(return_args)}') - # compile - self.integral = utils.compile_code( - code_scope={k: v for k, v in self.code_scope.items()}, - code_lines=self.code_lines, - show_code=self.show_code, - func_name=self.func_name) + r"""Explicit Runge–Kutta methods for ordinary differential equation. 
+ For the system, -class Euler(ExplicitRKIntegrator): - r"""The Euler method for ODEs. + .. math:: + + \frac{d y}{d t}=f(t, y) - Also named as `Forward Euler method`, or `Explicit Euler` method. - Given an ODE system, + Explicit Runge-Kutta methods take the form .. math:: - y'(t)=f(t,y(t)),\qquad y(t_{0})=y_{0}, + k_{i}=f\left(t_{n}+c_{i}h,y_{n}+h\sum _{j=1}^{s}a_{ij}k_{j}\right) \\ + y_{n+1}=y_{n}+h \sum_{i=1}^{s} b_{i} k_{i} - by using Euler method [1]_, we should choose a value :math:`h` for the - size of every step and set :math:`t_{n}=t_{0}+nh`. Now, one step - of the Euler method from :math:`t_{n}` to :math:`t_{n+1}=t_{n}+h` is: + Each method listed on this page is defined by its Butcher tableau, + which puts the coefficients of the method in a table as follows: - .. math:: + .. math:: - y_{n+1}=y_{n}+hf(t_{n},y_{n}). + \begin{array}{c|cccc} + c_{1} & a_{11} & a_{12} & \ldots & a_{1 s} \\ + c_{2} & a_{21} & a_{22} & \ldots & a_{2 s} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + c_{s} & a_{s 1} & a_{s 2} & \ldots & a_{s s} \\ + \hline & b_{1} & b_{2} & \ldots & b_{s} + \end{array} + + Parameters:: + + f : callable + The derivative function. + show_code : bool + Whether show the formatted code. + dt : float + The numerical precision. + """ + A = [] # The A matrix in the Butcher tableau. + B = [] # The B vector in the Butcher tableau. + C = [] # The C vector in the Butcher tableau. 
+ + def __init__(self, + f, + var_type=None, + dt=None, + name=None, + show_code=False, + state_delays=None, + neutral_delays=None): + super().__init__(f=f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays) + + # integrator keywords + keywords = { + C.F: 'the derivative function', + # C.DT: 'the precision of numerical integration' + } + for v in self.variables: + keywords[f'{v}_new'] = 'the intermediate value' + for i in range(1, len(self.A) + 1): + keywords[f'd{v}_k{i}'] = 'the intermediate value' + for i in range(2, len(self.A) + 1): + keywords[f'k{i}_{v}_arg'] = 'the intermediate value' + keywords[f'k{i}_t_arg'] = 'the intermediate value' + utils.check_kws(self.arg_names, keywords) + self.build() + + def build(self): + # step stage + common.step(self.variables, C.DT, + self.A, self.C, self.code_lines, self.parameters) + # variable update + return_args = common.update(self.variables, C.DT, self.B, self.code_lines) + # returns + self.code_lines.append(f' return {", ".join(return_args)}') + # compile + self.integral = utils.compile_code( + code_scope={k: v for k, v in self.code_scope.items()}, + code_lines=self.code_lines, + show_code=self.show_code, + func_name=self.func_name) - Note that the method increments a solution through an interval :math:`h` - while using derivative information from only the beginning of the interval. - As a result, the step's error is :math:`O(h^2)`. - **Geometric interpretation** +class Euler(ExplicitRKIntegrator): + r"""The Euler method for ODEs. - Illustration of the Euler method. The unknown curve is in blue, - and its polygonal approximation is in red [2]_: + Also named as `Forward Euler method`, or `Explicit Euler` method. - .. image:: ../../../../_static/ode_Euler_method.svg - :align: center + Given an ODE system, - **Derivation** + .. math:: - There are several ways to get Euler method [2]_. 
+ y'(t)=f(t,y(t)),\qquad y(t_{0})=y_{0}, - The first is to consider the Taylor expansion of the function :math:`y` - around :math:`t_{0}`: + by using Euler method [1]_, we should choose a value :math:`h` for the + size of every step and set :math:`t_{n}=t_{0}+nh`. Now, one step + of the Euler method from :math:`t_{n}` to :math:`t_{n+1}=t_{n}+h` is: - .. math:: + .. math:: + + y_{n+1}=y_{n}+hf(t_{n},y_{n}). - y(t_{0}+h)=y(t_{0})+hy'(t_{0})+{\frac {1}{2}}h^{2}y''(t_{0})+O(h^{3}). + Note that the method increments a solution through an interval :math:`h` + while using derivative information from only the beginning of the interval. + As a result, the step's error is :math:`O(h^2)`. - where :math:`y'(t_0)=f(t_0,y)`. We ignore the quadratic and higher-order - terms, then we get Euler method. The Taylor expansion is used below to - analyze the error committed by the Euler method, and it can be extended - to produce Runge–Kutta methods. + **Geometric interpretation** - The second way is to replace the derivative with the forward finite - difference formula: + Illustration of the Euler method. The unknown curve is in blue, + and its polygonal approximation is in red [2]_: - .. math:: + .. image:: ../../../../_static/ode_Euler_method.svg + :align: center - y'(t_{0})\approx {\frac {y(t_{0}+h)-y(t_{0})}{h}}. + **Derivation** - The third method is integrate the differential equation from :math:`t_{0}` - to :math:`t_{0}+h` and apply the fundamental theorem of calculus to get: + There are several ways to get Euler method [2]_. - .. math:: + The first is to consider the Taylor expansion of the function :math:`y` + around :math:`t_{0}`: + + .. math:: - y(t_{0}+h)-y(t_{0})=\int _{t_{0}}^{t_{0}+h}f(t,y(t))\,\mathrm {d} t \approx hf(t_{0},y(t_{0})). + y(t_{0}+h)=y(t_{0})+hy'(t_{0})+{\frac {1}{2}}h^{2}y''(t_{0})+O(h^{3}). + + where :math:`y'(t_0)=f(t_0,y)`. We ignore the quadratic and higher-order + terms, then we get Euler method. 
The Taylor expansion is used below to + analyze the error committed by the Euler method, and it can be extended + to produce Runge–Kutta methods. + + The second way is to replace the derivative with the forward finite + difference formula: + + .. math:: + + y'(t_{0})\approx {\frac {y(t_{0}+h)-y(t_{0})}{h}}. + + The third method is integrate the differential equation from :math:`t_{0}` + to :math:`t_{0}+h` and apply the fundamental theorem of calculus to get: + + .. math:: - **Note** + y(t_{0}+h)-y(t_{0})=\int _{t_{0}}^{t_{0}+h}f(t,y(t))\,\mathrm {d} t \approx hf(t_{0},y(t_{0})). - Euler method is a first order numerical procedure for solving - ODEs with a given initial value. The lack of stability - and accuracy limits its popularity mainly to use as a - simple introductory example of a numeric solution method. + **Note** - References:: + Euler method is a first order numerical procedure for solving + ODEs with a given initial value. The lack of stability + and accuracy limits its popularity mainly to use as a + simple introductory example of a numeric solution method. - .. [1] W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, - W. T. Numerical Recipes in FORTRAN: The Art of Scientific - Computing, 2nd ed. Cambridge, England: Cambridge University - Press, p. 710, 1992. - .. [2] https://en.wikipedia.org/wiki/Euler_method - """ - A = [(), ] - B = [1] - C = [0] + References:: + + .. [1] W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, + W. T. Numerical Recipes in FORTRAN: The Art of Scientific + Computing, 2nd ed. Cambridge, England: Cambridge University + Press, p. 710, 1992. + .. [2] https://en.wikipedia.org/wiki/Euler_method + """ + A = [(), ] + B = [1] + C = [0] register_ode_integrator('euler', Euler) class MidPoint(ExplicitRKIntegrator): - r"""Explicit midpoint method for ODEs. + r"""Explicit midpoint method for ODEs. - Also known as the `modified Euler method` [1]_. + Also known as the `modified Euler method` [1]_. 
- The midpoint method is a one-step method for numerically solving - the differential equation given by: + The midpoint method is a one-step method for numerically solving + the differential equation given by: - .. math:: + .. math:: - y'(t) = f(t, y(t)), \quad y(t_0) = y_0 . + y'(t) = f(t, y(t)), \quad y(t_0) = y_0 . - The formula of the explicit midpoint method is: + The formula of the explicit midpoint method is: - .. math:: + .. math:: - y_{n+1} = y_n + hf\left(t_n+\frac{h}{2},y_n+\frac{h}{2}f(t_n, y_n)\right). + y_{n+1} = y_n + hf\left(t_n+\frac{h}{2},y_n+\frac{h}{2}f(t_n, y_n)\right). - Therefore, the Butcher tableau of the midpoint method is: + Therefore, the Butcher tableau of the midpoint method is: - .. math:: + .. math:: - \begin{array}{c|cc} - 0 & 0 & 0 \\ - 1 / 2 & 1 / 2 & 0 \\ - \hline & 0 & 1 - \end{array} + \begin{array}{c|cc} + 0 & 0 & 0 \\ + 1 / 2 & 1 / 2 & 0 \\ + \hline & 0 & 1 + \end{array} - **Derivation** + **Derivation** - Compared to the slope formula of Euler method :math:`y'(t) \approx \frac{y(t+h) - y(t)}{h}`, - the midpoint method use + Compared to the slope formula of Euler method :math:`y'(t) \approx \frac{y(t+h) - y(t)}{h}`, + the midpoint method use - .. math:: + .. math:: - y'\left(t+\frac{h}{2}\right) \approx \frac{y(t+h) - y(t)}{h}, + y'\left(t+\frac{h}{2}\right) \approx \frac{y(t+h) - y(t)}{h}, - The reason why we use this, please see the following geometric interpretation. - Then, we get + The reason why we use this, please see the following geometric interpretation. + Then, we get - .. math:: + .. math:: - y(t+h) \approx y(t) + hf\left(t+\frac{h}{2},y\left(t+\frac{h}{2}\right)\right). + y(t+h) \approx y(t) + hf\left(t+\frac{h}{2},y\left(t+\frac{h}{2}\right)\right). - However, we do not know :math:`y(t+h/2)`. The solution is then to use a Taylor - series expansion exactly as the Euler method to solve: + However, we do not know :math:`y(t+h/2)`. 
The solution is then to use a Taylor + series expansion exactly as the Euler method to solve: - .. math:: + .. math:: - y\left(t + \frac{h}{2}\right) \approx y(t) + \frac{h}{2}y'(t)=y(t) + \frac{h}{2}f(t, y(t)), + y\left(t + \frac{h}{2}\right) \approx y(t) + \frac{h}{2}y'(t)=y(t) + \frac{h}{2}f(t, y(t)), - Finally, we can get the final step function: + Finally, we can get the final step function: - .. math:: + .. math:: - y(t + h) \approx y(t) + hf\left(t + \frac{h}{2}, y(t) + \frac{h}{2}f(t, y(t))\right). + y(t + h) \approx y(t) + hf\left(t + \frac{h}{2}, y(t) + \frac{h}{2}f(t, y(t))\right). - **Geometric interpretation** + **Geometric interpretation** - In the basic Euler's method, the tangent of the curve at :math:`(t_{n},y_{n})` is computed - using :math:`f(t_{n},y_{n})`. The next value :math:`y_{n+1}` is found where the tangent - intersects the vertical line :math:`t=t_{n+1}`. However, if the second derivative is only - positive between :math:`t_{n}` and :math:`t_{n+1}`, or only negative, the curve will - increasingly veer away from the tangent, leading to larger errors as :math:`h` increases. + In the basic Euler's method, the tangent of the curve at :math:`(t_{n},y_{n})` is computed + using :math:`f(t_{n},y_{n})`. The next value :math:`y_{n+1}` is found where the tangent + intersects the vertical line :math:`t=t_{n+1}`. However, if the second derivative is only + positive between :math:`t_{n}` and :math:`t_{n+1}`, or only negative, the curve will + increasingly veer away from the tangent, leading to larger errors as :math:`h` increases. - Compared with the Euler method, midpoint method use the tangent at the midpoint (upper, green - line segment in the following figure [2]_), which would most likely give a more accurate - approximation of the curve in that interval. 
+ Compared with the Euler method, midpoint method use the tangent at the midpoint (upper, green + line segment in the following figure [2]_), which would most likely give a more accurate + approximation of the curve in that interval. - .. image:: ../../../../_static/ode_Midpoint_method_illustration.png - :align: center + .. image:: ../../../../_static/ode_Midpoint_method_illustration.png + :align: center - Although this midpoint tangent could not be accurately calculated, we can estimate midpoint - value of :math:`y(t)` by using the original Euler's method. Finally, the improved tangent - is used to calculate the value of :math:`y_{n+1}` from :math:`y_{n}`. This last step is - represented by the red chord in the diagram. + Although this midpoint tangent could not be accurately calculated, we can estimate midpoint + value of :math:`y(t)` by using the original Euler's method. Finally, the improved tangent + is used to calculate the value of :math:`y_{n+1}` from :math:`y_{n}`. This last step is + represented by the red chord in the diagram. - **Note** + **Note** - Note that the red chord is not exactly parallel to the green segment (the true tangent), - due to the error in estimating the value of :math:`y(t)` at the midpoint. + Note that the red chord is not exactly parallel to the green segment (the true tangent), + due to the error in estimating the value of :math:`y(t)` at the midpoint. - References:: + References:: - .. [1] Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003. - .. [2] https://en.wikipedia.org/wiki/Midpoint_method - """ - A = [(), (0.5,)] - B = [0, 1] - C = [0, 0.5] + .. [1] Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003. + .. [2] https://en.wikipedia.org/wiki/Midpoint_method + """ + A = [(), (0.5,)] + B = [0, 1] + C = [0, 0.5] register_ode_integrator('midpoint', MidPoint) class Heun2(ExplicitRKIntegrator): - r"""Heun's method for ODEs. + r"""Heun's method for ODEs. 
- This method is named after Karl Heun [1]_. It is also known as - the `explicit trapezoid rule`, `improved Euler's method`, or `modified Euler's method`. + This method is named after Karl Heun [1]_. It is also known as + the `explicit trapezoid rule`, `improved Euler's method`, or `modified Euler's method`. - Given ODEs with a given initial value, + Given ODEs with a given initial value, - .. math:: - y'(t) = f(t,y(t)), \qquad y(t_0)=y_0, + .. math:: + y'(t) = f(t,y(t)), \qquad y(t_0)=y_0, - the two-stage Heun's method is formulated as: + the two-stage Heun's method is formulated as: - .. math:: - \tilde{y}_{n+1} = y_n + h f(t_n,y_n) + .. math:: + \tilde{y}_{n+1} = y_n + h f(t_n,y_n) - .. math:: - y_{n+1} = y_n + \frac{h}{2}[f(t_n, y_n) + f(t_{n+1},\tilde{y}_{n+1})], + .. math:: + y_{n+1} = y_n + \frac{h}{2}[f(t_n, y_n) + f(t_{n+1},\tilde{y}_{n+1})], - where :math:`h` is the step size and :math:`t_{n+1}=t_n+h`. + where :math:`h` is the step size and :math:`t_{n+1}=t_n+h`. - Therefore, the Butcher tableau of the two-stage Heun's method is: + Therefore, the Butcher tableau of the two-stage Heun's method is: - .. math:: - \begin{array}{c|cc} - 0.0 & 0.0 & 0.0 \\ - 1.0 & 1.0 & 0.0 \\ - \hline & 0.5 & 0.5 - \end{array} + .. math:: + \begin{array}{c|cc} + 0.0 & 0.0 & 0.0 \\ + 1.0 & 1.0 & 0.0 \\ + \hline & 0.5 & 0.5 + \end{array} - **Geometric interpretation** + **Geometric interpretation** - In the :py:func:`brainpy.integrators.ode.midpoint`, we have already known Euler - method has big estimation error because it uses the - line tangent to the function at the beginning of the interval :math:`t_n` as an - estimate of the slope of the function over the interval :math:`(t_n, t_{n+1})`. 
+ In the :py:func:`brainpy.integrators.ode.midpoint`, we have already known Euler + method has big estimation error because it uses the + line tangent to the function at the beginning of the interval :math:`t_n` as an + estimate of the slope of the function over the interval :math:`(t_n, t_{n+1})`. - In order to address this problem, Heun's Method considers the tangent lines to - the solution curve at both ends of the interval (:math:`t_n` and :math:`t_{n+1}`), - one (:math:`f(t_n, y_n)`) which *underestimates*, and one - (:math:`f(t_{n+1},\tilde{y}_{n+1})`, approximated using Euler's Method) which - *overestimates* the ideal vertical coordinates. The ideal point lies approximately - halfway between the erroneous overestimation and underestimation, the average of the two slopes. + In order to address this problem, Heun's Method considers the tangent lines to + the solution curve at both ends of the interval (:math:`t_n` and :math:`t_{n+1}`), + one (:math:`f(t_n, y_n)`) which *underestimates*, and one + (:math:`f(t_{n+1},\tilde{y}_{n+1})`, approximated using Euler's Method) which + *overestimates* the ideal vertical coordinates. The ideal point lies approximately + halfway between the erroneous overestimation and underestimation, the average of the two slopes. - .. image:: ../../../../_static/ode_Heun2_Method_Diagram.jpg - :align: center + .. image:: ../../../../_static/ode_Heun2_Method_Diagram.jpg + :align: center - .. math:: - \begin{aligned} - {\text{Slope}}_{\text{left}}=&f(t_{n},y_{n}) \\ - {\text{Slope}}_{\text{right}}=&f(t_{n}+h,y_{n}+hf(t_{n},y_{n})) \\ - {\text{Slope}}_{\text{ideal}}=&{\frac {1}{2}}({\text{Slope}}_{\text{left}}+{\text{Slope}}_{\text{right}}) - \end{aligned} + .. 
math:: + \begin{aligned} + {\text{Slope}}_{\text{left}}=&f(t_{n},y_{n}) \\ + {\text{Slope}}_{\text{right}}=&f(t_{n}+h,y_{n}+hf(t_{n},y_{n})) \\ + {\text{Slope}}_{\text{ideal}}=&{\frac {1}{2}}({\text{Slope}}_{\text{left}}+{\text{Slope}}_{\text{right}}) + \end{aligned} - References:: + References:: - .. [1] Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003. - """ - A = [(), (1,)] - B = [0.5, 0.5] - C = [0, 1] + .. [1] Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003. + """ + A = [(), (1,)] + B = [0.5, 0.5] + C = [0, 1] register_ode_integrator('heun2', Heun2) class Ralston2(ExplicitRKIntegrator): - r"""Ralston's method for ODEs. + r"""Ralston's method for ODEs. - Ralston's method is a second-order method with two stages and - a minimum local error bound. + Ralston's method is a second-order method with two stages and + a minimum local error bound. - Given ODEs with a given initial value, + Given ODEs with a given initial value, - .. math:: - y'(t) = f(t,y(t)), \qquad y(t_0)=y_0, + .. math:: + y'(t) = f(t,y(t)), \qquad y(t_0)=y_0, - the Ralston's second order method is given by + the Ralston's second order method is given by - .. math:: - y_{n+1}=y_{n}+\frac{h}{4} f\left(t_{n}, y_{n}\right)+ - \frac{3 h}{4} f\left(t_{n}+\frac{2 h}{3}, y_{n}+\frac{2 h}{3} f\left(t_{n}, y_{n}\right)\right) + .. math:: + y_{n+1}=y_{n}+\frac{h}{4} f\left(t_{n}, y_{n}\right)+ + \frac{3 h}{4} f\left(t_{n}+\frac{2 h}{3}, y_{n}+\frac{2 h}{3} f\left(t_{n}, y_{n}\right)\right) - Therefore, the corresponding Butcher tableau is: + Therefore, the corresponding Butcher tableau is: - .. math:: - \begin{array}{c|cc} - 0 & 0 & 0 \\ - 2 / 3 & 2 / 3 & 0 \\ - \hline & 1 / 4 & 3 / 4 - \end{array} - """ - A = [(), ('2/3',)] - B = [0.25, 0.75] - C = [0, '2/3'] + .. 
math:: + \begin{array}{c|cc} + 0 & 0 & 0 \\ + 2 / 3 & 2 / 3 & 0 \\ + \hline & 1 / 4 & 3 / 4 + \end{array} + """ + A = [(), ('2/3',)] + B = [0.25, 0.75] + C = [0, '2/3'] register_ode_integrator('ralston2', Ralston2) class RK2(ExplicitRKIntegrator): - r"""Generic second order Runge-Kutta method for ODEs. + r"""Generic second order Runge-Kutta method for ODEs. - **Derivation** + **Derivation** - In the :py:func:`brainpy.integrators.ode.midpoint`, - :py:func:`brainpy.integrators.ode.heun2`, and :py:func:`brainpy.integrators.ode.ralston2`, - we have already known first-order Euler method :py:func:`brainpy.integrators.ode.euler` - has big estimation error. + In the :py:func:`brainpy.integrators.ode.midpoint`, + :py:func:`brainpy.integrators.ode.heun2`, and :py:func:`brainpy.integrators.ode.ralston2`, + we have already known first-order Euler method :py:func:`brainpy.integrators.ode.euler` + has big estimation error. - Here, we seek to derive a generic second order Runge-Kutta method [1]_ for the - given ODE system with a given initial value, + Here, we seek to derive a generic second order Runge-Kutta method [1]_ for the + given ODE system with a given initial value, - .. math:: - y'(t) = f(t,y(t)), \qquad y(t_0)=y_0, - - we want to get a generic solution: - - .. math:: - \begin{align} y_{n+1} &= y_{n} + h \left ( a_1 K_1 + a_2 K_2 \right ) \tag{1} - \end{align} - - where :math:`a_1` and :math:`a_2` are some weights to be determined, - and :math:`K_1` and :math:`K_2` are derivatives on the form: - - .. math:: - \begin{align} - K_1 & = f(t_n,y_n) \qquad \text{and} \qquad K_2 = f(t_n + p_1 h,y_n + p_2 K_1 h ) \tag{2} - \end{align} + .. math:: + y'(t) = f(t,y(t)), \qquad y(t_0)=y_0, - By substitution of (2) in (1) we get: + we want to get a generic solution: - .. math:: - \begin{align} - y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h f(t_n + p_1 h,y_n + p_2 K_1 h) \tag{3} - \end{align} + .. 
math:: + \begin{align} y_{n+1} &= y_{n} + h \left ( a_1 K_1 + a_2 K_2 \right ) \tag{1} + \end{align} - Now, we may find a Taylor-expansion of :math:`f(t_n + p_1 h, y_n + p_2 K_1 h )` + where :math:`a_1` and :math:`a_2` are some weights to be determined, + and :math:`K_1` and :math:`K_2` are derivatives on the form: - .. math:: - \begin{align} - f(t_n + p_1 h, y_n + p_2 K_1 h ) &= f + p_1 h f_t + p_2 K_1 h f_y + \text{h.o.t.} \nonumber \\ - & = f + p_1 h f_t + p_2 h f f_y + \text{h.o.t.} \tag{4} - \end{align} + .. math:: + \begin{align} + K_1 & = f(t_n,y_n) \qquad \text{and} \qquad K_2 = f(t_n + p_1 h,y_n + p_2 K_1 h ) \tag{2} + \end{align} - where :math:`f_t \equiv \frac{\partial f}{\partial t}` and - :math:`f_y \equiv \frac{\partial f}{\partial y}`. + By substitution of (2) in (1) we get: - By substitution of (4) in (3) we eliminate the implicit dependency of :math:`y_{n+1}` + .. math:: + \begin{align} + y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h f(t_n + p_1 h,y_n + p_2 K_1 h) \tag{3} + \end{align} - .. math:: - \begin{align} - y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h \left (f + p_1 h f_t + p_2 h f f_y \right ) \nonumber \\ - &= y_{n} + (a_1 + a_2) h f + \left (a_2 p_1 f_t + a_2 p_2 f f_y \right) h^2 \tag{5} - \end{align} + Now, we may find a Taylor-expansion of :math:`f(t_n + p_1 h, y_n + p_2 K_1 h )` - In the next, we try to get the second order Taylor expansion of the solution: + .. math:: + \begin{align} + f(t_n + p_1 h, y_n + p_2 K_1 h ) &= f + p_1 h f_t + p_2 K_1 h f_y + \text{h.o.t.} \nonumber \\ + & = f + p_1 h f_t + p_2 h f f_y + \text{h.o.t.} \tag{4} + \end{align} - .. math:: - \begin{align} - y(t_n+h) = y_n + h y' + \frac{h^2}{2} y'' + O(h^3) \tag{6} - \end{align} + where :math:`f_t \equiv \frac{\partial f}{\partial t}` and + :math:`f_y \equiv \frac{\partial f}{\partial y}`. - where the second order derivative is given by + By substitution of (4) in (3) we eliminate the implicit dependency of :math:`y_{n+1}` - .. 
math:: - \begin{align} - y'' = \frac{d^2 y}{dt^2} = \frac{df}{dt} = \frac{\partial{f}}{\partial{t}} - \frac{dt}{dt} + \frac{\partial{f}}{\partial{y}} \frac{dy}{dt} = f_t + f f_y \tag{7} - \end{align} + .. math:: + \begin{align} + y_{n+1} &= y_{n} + a_1 h f(t_n,y_n) + a_2 h \left (f + p_1 h f_t + p_2 h f f_y \right ) \nonumber \\ + &= y_{n} + (a_1 + a_2) h f + \left (a_2 p_1 f_t + a_2 p_2 f f_y \right) h^2 \tag{5} + \end{align} - Substitution of (7) into (6) yields: + In the next, we try to get the second order Taylor expansion of the solution: - .. math:: - \begin{align} - y(t_n+h) = y_n + h f + \frac{h^2}{2} \left (f_t + f f_y \right ) + O(h^3) \tag{8} - \end{align} + .. math:: + \begin{align} + y(t_n+h) = y_n + h y' + \frac{h^2}{2} y'' + O(h^3) \tag{6} + \end{align} - Finally, in order to approximate (8) by using (5), we get the generic second order - Runge-Kutta method, where + where the second order derivative is given by - .. math:: - \begin{aligned} - a_1 + a_2 = 1 \\ - a_2 p_1 = \frac{1}{2} \\ - a_2 p_2 = \frac{1}{2}. - \end{aligned} + .. math:: + \begin{align} + y'' = \frac{d^2 y}{dt^2} = \frac{df}{dt} = \frac{\partial{f}}{\partial{t}} + \frac{dt}{dt} + \frac{\partial{f}}{\partial{y}} \frac{dy}{dt} = f_t + f f_y \tag{7} + \end{align} - Furthermore, let :math:`p_1=\beta`, we get + Substitution of (7) into (6) yields: - .. math:: - \begin{aligned} - p_1 = & \beta \\ - p_2 = & \beta \\ - a_2 = &\frac{1}{2\beta} \\ - a_1 = &1 - \frac{1}{2\beta} . - \end{aligned} + .. math:: + \begin{align} + y(t_n+h) = y_n + h f + \frac{h^2}{2} \left (f_t + f f_y \right ) + O(h^3) \tag{8} + \end{align} - Therefore, the corresponding Butcher tableau is: + Finally, in order to approximate (8) by using (5), we get the generic second order + Runge-Kutta method, where - .. math:: + .. math:: + \begin{aligned} + a_1 + a_2 = 1 \\ + a_2 p_1 = \frac{1}{2} \\ + a_2 p_2 = \frac{1}{2}. 
+ \end{aligned} - \begin{array}{c|cc} - 0 & 0 & 0 \\ - \beta & \beta & 0 \\ - \hline & 1 - {1 \over 2 * \beta} & {1 \over 2 * \beta} - \end{array} + Furthermore, let :math:`p_1=\beta`, we get - References:: + .. math:: + \begin{aligned} + p_1 = & \beta \\ + p_2 = & \beta \\ + a_2 = &\frac{1}{2\beta} \\ + a_1 = &1 - \frac{1}{2\beta} . + \end{aligned} - .. [1] Chapra, Steven C., and Raymond P. Canale. Numerical methods - for engineers. Vol. 1221. New York: Mcgraw-hill, 2011. + Therefore, the corresponding Butcher tableau is: - """ + .. math:: - def __init__(self, - f, - beta=2 / 3, - var_type=None, - dt=None, - name=None, - show_code=False, - state_delays=None, - neutral_delays=None): - self.A = [(), (beta,)] - self.B = [1 - 1 / (2 * beta), 1 / (2 * beta)] - self.C = [0, beta] - super(RK2, self).__init__(f=f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays) + \begin{array}{c|cc} + 0 & 0 & 0 \\ + \beta & \beta & 0 \\ + \hline & 1 - {1 \over 2 * \beta} & {1 \over 2 * \beta} + \end{array} + + References:: + + .. [1] Chapra, Steven C., and Raymond P. Canale. Numerical methods + for engineers. Vol. 1221. New York: Mcgraw-hill, 2011. + + """ + + def __init__(self, + f, + beta=2 / 3, + var_type=None, + dt=None, + name=None, + show_code=False, + state_delays=None, + neutral_delays=None): + self.A = [(), (beta,)] + self.B = [1 - 1 / (2 * beta), 1 / (2 * beta)] + self.C = [0, beta] + super(RK2, self).__init__(f=f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays) register_ode_integrator('rk2', RK2) class RK3(ExplicitRKIntegrator): - r"""Classical third-order Runge-Kutta method for ODEs. + r"""Classical third-order Runge-Kutta method for ODEs. 
- For the given initial value problem :math:`y'(x) = f(t,y);\, y(t_0) = y_0`, - the third order Runge-Kutta method is given by: + For the given initial value problem :math:`y'(x) = f(t,y);\, y(t_0) = y_0`, + the third order Runge-Kutta method is given by: - .. math:: - y_{n+1} = y_n + 1/6 ( k_1 + 4 k_2 + k_3), + .. math:: + y_{n+1} = y_n + 1/6 ( k_1 + 4 k_2 + k_3), + + where - where + .. math:: + k_1 = h f(t_n, y_n), \\ + k_2 = h f(t_n + h / 2, y_n + k_1 / 2), \\ + k_3 = h f(t_n + h, y_n - k_1 + 2 k_2 ), - .. math:: - k_1 = h f(t_n, y_n), \\ - k_2 = h f(t_n + h / 2, y_n + k_1 / 2), \\ - k_3 = h f(t_n + h, y_n - k_1 + 2 k_2 ), + where :math:`t_n = t_0 + n h.` - where :math:`t_n = t_0 + n h.` + Error term :math:`O(h^4)`, correct up to the third order term in Taylor series expansion. - Error term :math:`O(h^4)`, correct up to the third order term in Taylor series expansion. + The Taylor series expansion is :math:`y(t+h)=y(t)+\frac{k}{6}+\frac{2 k_{2}}{3}+\frac{k_{3}}{6}+O\left(h^{4}\right)`. - The Taylor series expansion is :math:`y(t+h)=y(t)+\frac{k}{6}+\frac{2 k_{2}}{3}+\frac{k_{3}}{6}+O\left(h^{4}\right)`. + The corresponding Butcher tableau is: - The corresponding Butcher tableau is: - - .. math:: - \begin{array}{c|ccc} - 0 & 0 & 0 & 0 \\ - 1 / 2 & 1 / 2 & 0 & 0 \\ - 1 & -1 & 2 & 0 \\ - \hline & 1 / 6 & 2 / 3 & 1 / 6 - \end{array} + .. math:: + \begin{array}{c|ccc} + 0 & 0 & 0 & 0 \\ + 1 / 2 & 1 / 2 & 0 & 0 \\ + 1 & -1 & 2 & 0 \\ + \hline & 1 / 6 & 2 / 3 & 1 / 6 + \end{array} - """ - A = [(), (0.5,), (-1, 2)] - B = ['1/6', '2/3', '1/6'] - C = [0, 0.5, 1] + """ + A = [(), (0.5,), (-1, 2)] + B = ['1/6', '2/3', '1/6'] + C = [0, 0.5, 1] register_ode_integrator('rk3', RK3) class Heun3(ExplicitRKIntegrator): - r"""Heun's third-order method for ODEs. + r"""Heun's third-order method for ODEs. 
- It has the characteristics of: + It has the characteristics of: - - method stage = 3 - - method order = 3 - - Butcher Tables: + - method stage = 3 + - method order = 3 + - Butcher Tables: - .. math:: + .. math:: - \begin{array}{c|ccc} - 0 & 0 & 0 & 0 \\ - 1 / 3 & 1 / 3 & 0 & 0 \\ - 2 / 3 & 0 & 2 / 3 & 0 \\ - \hline & 1 / 4 & 0 & 3 / 4 - \end{array} + \begin{array}{c|ccc} + 0 & 0 & 0 & 0 \\ + 1 / 3 & 1 / 3 & 0 & 0 \\ + 2 / 3 & 0 & 2 / 3 & 0 \\ + \hline & 1 / 4 & 0 & 3 / 4 + \end{array} - """ - A = [(), ('1/3',), (0, '2/3')] - B = [0.25, 0, 0.75] - C = [0, '1/3', '2/3'] + """ + A = [(), ('1/3',), (0, '2/3')] + B = [0.25, 0, 0.75] + C = [0, '1/3', '2/3'] register_ode_integrator('heun3', Heun3) class Ralston3(ExplicitRKIntegrator): - r"""Ralston's third-order method for ODEs. + r"""Ralston's third-order method for ODEs. - It has the characteristics of: + It has the characteristics of: - - method stage = 3 - - method order = 3 - - Butcher Tables: + - method stage = 3 + - method order = 3 + - Butcher Tables: - .. math:: - \begin{array}{c|ccc} - 0 & 0 & 0 & 0 \\ - 1 / 2 & 1 / 2 & 0 & 0 \\ - 3 / 4 & 0 & 3 / 4 & 0 \\ - \hline & 2 / 9 & 1 / 3 & 4 / 9 - \end{array} + .. math:: + \begin{array}{c|ccc} + 0 & 0 & 0 & 0 \\ + 1 / 2 & 1 / 2 & 0 & 0 \\ + 3 / 4 & 0 & 3 / 4 & 0 \\ + \hline & 2 / 9 & 1 / 3 & 4 / 9 + \end{array} - References:: + References:: - .. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds". - Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0 + .. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds". + Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0 - """ - A = [(), (0.5,), (0, 0.75)] - B = ['2/9', '1/3', '4/9'] - C = [0, 0.5, 0.75] + """ + A = [(), (0.5,), (0, 0.75)] + B = ['2/9', '1/3', '4/9'] + C = [0, 0.5, 0.75] register_ode_integrator('ralston3', Ralston3) class SSPRK3(ExplicitRKIntegrator): - r"""Third-order Strong Stability Preserving Runge-Kutta (SSPRK3). 
+ r"""Third-order Strong Stability Preserving Runge-Kutta (SSPRK3). - It has the characteristics of: + It has the characteristics of: - - method stage = 3 - - method order = 3 - - Butcher Tables: + - method stage = 3 + - method order = 3 + - Butcher Tables: - .. math:: - \begin{array}{c|ccc} - 0 & 0 & 0 & 0 \\ - 1 & 1 & 0 & 0 \\ - 1 / 2 & 1 / 4 & 1 / 4 & 0 \\ - \hline & 1 / 6 & 1 / 6 & 2 / 3 - \end{array} + .. math:: + \begin{array}{c|ccc} + 0 & 0 & 0 & 0 \\ + 1 & 1 & 0 & 0 \\ + 1 / 2 & 1 / 4 & 1 / 4 & 0 \\ + \hline & 1 / 6 & 1 / 6 & 2 / 3 + \end{array} - """ - A = [(), (1,), (0.25, 0.25)] - B = ['1/6', '1/6', '2/3'] - C = [0, 1, 0.5] + """ + A = [(), (1,), (0.25, 0.25)] + B = ['1/6', '1/6', '2/3'] + C = [0, 1, 0.5] register_ode_integrator('ssprk3', SSPRK3) class RK4(ExplicitRKIntegrator): - r"""Classical fourth-order Runge-Kutta method for ODEs. - - For the given initial value problem of + r"""Classical fourth-order Runge-Kutta method for ODEs. - .. math:: - {\frac {dy}{dt}}=f(t,y),\quad y(t_{0})=y_{0}. + For the given initial value problem of - The fourth-order RK method is formulated as: - - .. math:: - \begin{aligned} - y_{n+1}&=y_{n}+{\frac {1}{6}}h\left(k_{1}+2k_{2}+2k_{3}+k_{4}\right),\\ - t_{n+1}&=t_{n}+h\\ - \end{aligned} - - for :math:`n = 0, 1, 2, 3, \cdot`, using + .. math:: + {\frac {dy}{dt}}=f(t,y),\quad y(t_{0})=y_{0}. - .. math:: - \begin{aligned} - k_{1}&=\ f(t_{n},y_{n}),\\ - k_{2}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{1}}{2}}\right),\\ - k_{3}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{2}}{2}}\right),\\ - k_{4}&=\ f\left(t_{n}+h,y_{n}+hk_{3}\right). 
- \end{aligned} + The fourth-order RK method is formulated as: - Here :math:`y_{n+1}` is the RK4 approximation of :math:`y(t_{n+1})`, and the next - value (:math:`y_{n+1}`) is determined by the present value (:math:`y_{n}`) plus - the weighted average of four increments, where each increment is the product - of the size of the interval, :math:`h`, and an estimated slope specified by function - :math:`f` on the right-hand side of the differential equation. + .. math:: + \begin{aligned} + y_{n+1}&=y_{n}+{\frac {1}{6}}h\left(k_{1}+2k_{2}+2k_{3}+k_{4}\right),\\ + t_{n+1}&=t_{n}+h\\ + \end{aligned} - - :math:`k_{1}` is the slope at the beginning of the interval, using :math:`y` (Euler's method); - - :math:`k_{2}` is the slope at the midpoint of the interval, using :math:`y` and :math:`k_{1}`; - - :math:`k_{3}` is again the slope at the midpoint, but now using :math:`y` and :math:`k_{2}`; - - :math:`k_{4}` is the slope at the end of the interval, using :math:`y` and :math:`k_{3}`. + for :math:`n = 0, 1, 2, 3, \cdot`, using - The RK4 method is a fourth-order method, meaning that the local truncation error is on the order - of (:math:`O(h^{5}`), while the total accumulated error is on the order of (:math:`O(h^{4}`). + .. math:: + \begin{aligned} + k_{1}&=\ f(t_{n},y_{n}),\\ + k_{2}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{1}}{2}}\right),\\ + k_{3}&=\ f\left(t_{n}+{\frac {h}{2}},y_{n}+h{\frac {k_{2}}{2}}\right),\\ + k_{4}&=\ f\left(t_{n}+h,y_{n}+hk_{3}\right). + \end{aligned} - The corresponding Butcher tableau is: + Here :math:`y_{n+1}` is the RK4 approximation of :math:`y(t_{n+1})`, and the next + value (:math:`y_{n+1}`) is determined by the present value (:math:`y_{n}`) plus + the weighted average of four increments, where each increment is the product + of the size of the interval, :math:`h`, and an estimated slope specified by function + :math:`f` on the right-hand side of the differential equation. - .. 
math:: - \begin{array}{c|cccc} - 0 & 0 & 0 & 0 & 0 \\ - 1 / 2 & 1 / 2 & 0 & 0 & 0 \\ - 1 / 2 & 0 & 1 / 2 & 0 & 0 \\ - 1 & 0 & 0 & 1 & 0 \\ - \hline & 1 / 6 & 1 / 3 & 1 / 3 & 1 / 6 - \end{array} + - :math:`k_{1}` is the slope at the beginning of the interval, using :math:`y` (Euler's method); + - :math:`k_{2}` is the slope at the midpoint of the interval, using :math:`y` and :math:`k_{1}`; + - :math:`k_{3}` is again the slope at the midpoint, but now using :math:`y` and :math:`k_{2}`; + - :math:`k_{4}` is the slope at the end of the interval, using :math:`y` and :math:`k_{3}`. - References:: + The RK4 method is a fourth-order method, meaning that the local truncation error is on the order + of (:math:`O(h^{5}`), while the total accumulated error is on the order of (:math:`O(h^{4}`). - .. [1] Lambert, J. D. and Lambert, D. Ch. 5 in Numerical Methods for Ordinary - Differential Systems: The Initial Value Problem. New York: Wiley, 1991. - .. [2] Press, W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T. - "Runge-Kutta Method" and "Adaptive Step Size Control for Runge-Kutta." - §16.1 and 16.2 in Numerical Recipes in FORTRAN: The Art of Scientific - Computing, 2nd ed. Cambridge, England: Cambridge University Press, - pp. 704-716, 1992. - """ + The corresponding Butcher tableau is: - A = [(), (0.5,), (0., 0.5), (0., 0., 1)] - B = ['1/6', '1/3', '1/3', '1/6'] - C = [0, 0.5, 0.5, 1] + .. math:: + \begin{array}{c|cccc} + 0 & 0 & 0 & 0 & 0 \\ + 1 / 2 & 1 / 2 & 0 & 0 & 0 \\ + 1 / 2 & 0 & 1 / 2 & 0 & 0 \\ + 1 & 0 & 0 & 1 & 0 \\ + \hline & 1 / 6 & 1 / 3 & 1 / 3 & 1 / 6 + \end{array} + + References:: + + .. [1] Lambert, J. D. and Lambert, D. Ch. 5 in Numerical Methods for Ordinary + Differential Systems: The Initial Value Problem. New York: Wiley, 1991. + .. [2] Press, W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T. + "Runge-Kutta Method" and "Adaptive Step Size Control for Runge-Kutta." 
+ §16.1 and 16.2 in Numerical Recipes in FORTRAN: The Art of Scientific + Computing, 2nd ed. Cambridge, England: Cambridge University Press, + pp. 704-716, 1992. + """ + + A = [(), (0.5,), (0., 0.5), (0., 0., 1)] + B = ['1/6', '1/3', '1/3', '1/6'] + C = [0, 0.5, 0.5, 1] register_ode_integrator('rk4', RK4) class Ralston4(ExplicitRKIntegrator): - r"""Ralston's fourth-order method for ODEs. + r"""Ralston's fourth-order method for ODEs. - It has the characteristics of: + It has the characteristics of: - - method stage = 4 - - method order = 4 - - Butcher Tables: + - method stage = 4 + - method order = 4 + - Butcher Tables: - .. math:: + .. math:: - \begin{array}{c|cccc} - 0 & 0 & 0 & 0 & 0 \\ - .4 & .4 & 0 & 0 & 0 \\ - .45573725 & .29697761 & .15875964 & 0 & 0 \\ - 1 & .21810040 & -3.05096516 & 3.83286476 & 0 \\ - \hline & .17476028 & -.55148066 & 1.20553560 & .17118478 - \end{array} + \begin{array}{c|cccc} + 0 & 0 & 0 & 0 & 0 \\ + .4 & .4 & 0 & 0 & 0 \\ + .45573725 & .29697761 & .15875964 & 0 & 0 \\ + 1 & .21810040 & -3.05096516 & 3.83286476 & 0 \\ + \hline & .17476028 & -.55148066 & 1.20553560 & .17118478 + \end{array} - References:: + References:: - .. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds". - Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0 + .. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds". + Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0 - """ - A = [(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)] - B = [.17476028, -.55148066, 1.20553560, .17118478] - C = [0, .4, .45573725, 1] + """ + A = [(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)] + B = [.17476028, -.55148066, 1.20553560, .17118478] + C = [0, .4, .45573725, 1] register_ode_integrator('ralston4', Ralston4) class RK4Rule38(ExplicitRKIntegrator): - r"""3/8-rule fourth-order method for ODEs. + r"""3/8-rule fourth-order method for ODEs. 
- A slight variation of "the" Runge–Kutta method is also due - to Kutta in 1901 [1]_ and is called the 3/8-rule. The primary - advantage this method has is that almost all of the error - coefficients are smaller than in the popular method, but it - requires slightly more FLOPs (floating-point operations) per - time step. + A slight variation of "the" Runge–Kutta method is also due + to Kutta in 1901 [1]_ and is called the 3/8-rule. The primary + advantage this method has is that almost all of the error + coefficients are smaller than in the popular method, but it + requires slightly more FLOPs (floating-point operations) per + time step. - It has the characteristics of: + It has the characteristics of: - - method stage = 4 - - method order = 4 - - Butcher Tables: + - method stage = 4 + - method order = 4 + - Butcher Tables: - .. math:: + .. math:: - \begin{array}{c|cccc} - 0 & 0 & 0 & 0 & 0 \\ - 1 / 3 & 1 / 3 & 0 & 0 & 0 \\ - 2 / 3 & -1 / 3 & 1 & 0 & 0 \\ - 1 & 1 & -1 & 1 & 0 \\ - \hline & 1 / 8 & 3 / 8 & 3 / 8 & 1 / 8 - \end{array} + \begin{array}{c|cccc} + 0 & 0 & 0 & 0 & 0 \\ + 1 / 3 & 1 / 3 & 0 & 0 & 0 \\ + 2 / 3 & -1 / 3 & 1 & 0 & 0 \\ + 1 & 1 & -1 & 1 & 0 \\ + \hline & 1 / 8 & 3 / 8 & 3 / 8 & 1 / 8 + \end{array} - References:: + References:: - .. [1] Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), - Solving ordinary differential equations I: Nonstiff problems, - Berlin, New York: Springer-Verlag, ISBN 978-3-540-56670-0. + .. [1] Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), + Solving ordinary differential equations I: Nonstiff problems, + Berlin, New York: Springer-Verlag, ISBN 978-3-540-56670-0. 
- """ - A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)] - B = [0.125, 0.375, 0.375, 0.125] - C = [0, '1/3', '2/3', 1] + """ + A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)] + B = [0.125, 0.375, 0.375, 0.125] + C = [0, '1/3', '2/3', 1] register_ode_integrator('rk4_38rule', RK4Rule38) diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index 9176bbe2d..eb8d9472f 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -115,255 +115,255 @@ from brainpy._src.integrators.ode.base import ODEIntegrator from .generic import register_ode_integrator - __all__ = [ - 'ExponentialEuler', + 'ExponentialEuler', ] class ExponentialEuler(ODEIntegrator): - """Exponential Euler method using automatic differentiation. - - This method uses `brainpy.math.vector_grad <../../math/generated/brainpy.math.autograd.vector_grad.html>`_ - to automatically infer the linear part of the given function. Therefore, it has minimal constraints - on your derivative function. Arbitrary complex functions can be numerically integrated with this method. - - Examples:: - - Here is an example uses ``ExponentialEuler`` to implement HH neuron model. - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> - >>> class HH(bp.dyn.NeuDyn): - >>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., - >>> gL=0.1, V_th=20., phi=5.0, name=None): - >>> super(HH, self).__init__(size=size, name=name) - >>> - >>> # parameters - >>> self.ENa = ENa - >>> self.EK = EK - >>> self.EL = EL - >>> self.C = C - >>> self.gNa = gNa - >>> self.gK = gK - >>> self.gL = gL - >>> self.V_th = V_th - >>> self.phi = phi - >>> - >>> # variables - >>> self.V = bm.Variable(bm.ones(size) * -65.) 
- >>> self.h = bm.Variable(bm.ones(size) * 0.6) - >>> self.n = bm.Variable(bm.ones(size) * 0.32) - >>> self.spike = bm.Variable(bm.zeros(size, dtype=bool)) - >>> self.input = bm.Variable(bm.zeros(size)) - >>> - >>> # functions - >>> self.int_h = bp.ode.ExponentialEuler(self.dh) - >>> self.int_n = bp.ode.ExponentialEuler(self.dn) - >>> self.int_V = bp.ode.ExponentialEuler(self.dV) - >>> - >>> def dh(self, h, t, V): - >>> alpha = 0.07 * bm.exp(-(V + 58) / 20) - >>> beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) - >>> dhdt = self.phi * (alpha * (1 - h) - beta * h) - >>> return dhdt - >>> - >>> def dn(self, n, t, V): - >>> alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) - >>> beta = 0.125 * bm.exp(-(V + 44) / 80) - >>> dndt = self.phi * (alpha * (1 - n) - beta * n) - >>> return dndt - >>> - >>> def dV(self, V, t, h, n, Iext): - >>> m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) - >>> m_beta = 4 * bm.exp(-(V + 60) / 18) - >>> m = m_alpha / (m_alpha + m_beta) - >>> INa = self.gNa * m ** 3 * h * (V - self.ENa) - >>> IK = self.gK * n ** 4 * (V - self.EK) - >>> IL = self.gL * (V - self.EL) - >>> dVdt = (- INa - IK - IL + Iext) / self.C - >>> - >>> return dVdt - >>> - >>> def update(self, tdi): - >>> h = self.int_h(self.h, tdi.t, self.V, dt=tdi.dt) - >>> n = self.int_n(self.n, tdi.t, self.V, dt=tdi.dt) - >>> V = self.int_V(self.V, tdi.t, self.h, self.n, self.input, dt=tdi.dt) - >>> self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) - >>> self.V.value = V - >>> self.h.value = h - >>> self.n.value = n - >>> self.input[:] = 0. - >>> - >>> run = bp.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) - >>> run(100) - >>> bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) - - The above example can also be defined with ``brainpy.JointEq``. - - .. 
plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> - >>> class HH(bp.dyn.NeuDyn): - >>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., - >>> gL=0.1, V_th=20., phi=5.0, name=None): - >>> super(HH, self).__init__(size=size, name=name) - >>> - >>> # parameters - >>> self.ENa = ENa - >>> self.EK = EK - >>> self.EL = EL - >>> self.C = C - >>> self.gNa = gNa - >>> self.gK = gK - >>> self.gL = gL - >>> self.V_th = V_th - >>> self.phi = phi - >>> - >>> # variables - >>> self.V = bm.Variable(bm.ones(size) * -65.) - >>> self.h = bm.Variable(bm.ones(size) * 0.6) - >>> self.n = bm.Variable(bm.ones(size) * 0.32) - >>> self.spike = bm.Variable(bm.zeros(size, dtype=bool)) - >>> self.input = bm.Variable(bm.zeros(size)) - >>> - >>> # functions - >>> derivative = bp.JointEq([self.dh, self.dn, self.dV]) - >>> self.integral = bp.ode.ExponentialEuler(derivative) - >>> - >>> def dh(self, h, t, V): - >>> alpha = 0.07 * bm.exp(-(V + 58) / 20) - >>> beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) - >>> dhdt = self.phi * (alpha * (1 - h) - beta * h) - >>> return dhdt - >>> - >>> def dn(self, n, t, V): - >>> alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) - >>> beta = 0.125 * bm.exp(-(V + 44) / 80) - >>> dndt = self.phi * (alpha * (1 - n) - beta * n) - >>> return dndt - >>> - >>> def dV(self, V, t, h, n, Iext): - >>> m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) - >>> m_beta = 4 * bm.exp(-(V + 60) / 18) - >>> m = m_alpha / (m_alpha + m_beta) - >>> INa = self.gNa * m ** 3 * h * (V - self.ENa) - >>> IK = self.gK * n ** 4 * (V - self.EK) - >>> IL = self.gL * (V - self.EL) - >>> dVdt = (- INa - IK - IL + Iext) / self.C - >>> - >>> return dVdt - >>> - >>> def update(self, tdi): - >>> h, n, V = self.integral(self.h, self.n, self.V, tdi.t, self.input, dt=tdi.dt) - >>> self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) - >>> self.V.value = V - >>> self.h.value = h - >>> self.n.value = n - >>> 
self.input[:] = 0. - >>> - >>> run = bp.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) - >>> run(100) - >>> bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) - - Parameters:: - - f : function, joint_eq.JointEq - The derivative function. - var_type : optional, str - The variable type. - dt : optional, float - The default numerical integration step. - name : optional, str - The integrator name. - """ - - def __init__( - self, - f, - var_type=None, - dt=None, - name=None, - show_code=False, - state_delays=None, - neutral_delays=None - ): - super(ExponentialEuler, self).__init__(f=f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays) - - if var_type == C.SYSTEM_VAR: - raise NotImplementedError(f'{self.__class__.__name__} does not support {C.SYSTEM_VAR}, ' - f'because the auto-differentiation ') - - # build the integrator - self.code_lines = [] - self.code_scope = {} - self.integral = self.build() - - def build(self): - parses = self._build_integrator(self.f) - all_vps = self.variables + self.parameters - - @wraps(self.f) - def integral_func(*args, **kwargs): - # format arguments - params_in = bm.Collector() - for i, arg in enumerate(args): - params_in[all_vps[i]] = arg - params_in.update(kwargs) - if C.DT not in params_in: - params_in[C.DT] = self.dt - - # call integrals - results = [] - for i, parse in enumerate(parses): - f_integral, vars_, pars_ = parse - vps = vars_ + pars_ + [C.DT] - r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in}) - results.append(r) - return results if len(self.variables) > 1 else results[0] - - return integral_func - - def _build_integrator(self, eq): - if isinstance(eq, joint_eq.JointEq): - results = [] - for sub_eq in eq.eqs: - results.extend(self._build_integrator(sub_eq)) - return results - else: - vars, pars, _ = utils.get_args(eq) - - # checking - if len(vars) != 1: - raise 
errors.DiffEqError(C.multi_vars_msg.format(cls=self.__class__.__name__, - vars=str(vars), - eq=str(eq))) - - # integration function - def integral(*args, **kwargs): - assert len(args) > 0 - if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]: - raise ValueError('The input data type should be float32, float64, float16, or bfloat16 when using Exponential Euler method.' - f'But we got {args[0].dtype}.') - dt = kwargs.pop(C.DT, self.dt) - linear, derivative = bm.vector_grad(eq, argnums=0, return_value=True)(*args, **kwargs) - phi = bm.exprel(dt * linear) - return args[0] + dt * phi * derivative - - return [(integral, vars, pars), ] + """Exponential Euler method using automatic differentiation. + + This method uses `brainpy.math.vector_grad <../../math/generated/brainpy.math.autograd.vector_grad.html>`_ + to automatically infer the linear part of the given function. Therefore, it has minimal constraints + on your derivative function. Arbitrary complex functions can be numerically integrated with this method. + + Examples:: + + Here is an example uses ``ExponentialEuler`` to implement HH neuron model. + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> + >>> class HH(bp.dyn.NeuDyn): + >>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., + >>> gL=0.1, V_th=20., phi=5.0, name=None): + >>> super(HH, self).__init__(size=size, name=name) + >>> + >>> # parameters + >>> self.ENa = ENa + >>> self.EK = EK + >>> self.EL = EL + >>> self.C = C + >>> self.gNa = gNa + >>> self.gK = gK + >>> self.gL = gL + >>> self.V_th = V_th + >>> self.phi = phi + >>> + >>> # variables + >>> self.V = bm.Variable(bm.ones(size) * -65.) 
+ >>> self.h = bm.Variable(bm.ones(size) * 0.6) + >>> self.n = bm.Variable(bm.ones(size) * 0.32) + >>> self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + >>> self.input = bm.Variable(bm.zeros(size)) + >>> + >>> # functions + >>> self.int_h = bp.ode.ExponentialEuler(self.dh) + >>> self.int_n = bp.ode.ExponentialEuler(self.dn) + >>> self.int_V = bp.ode.ExponentialEuler(self.dV) + >>> + >>> def dh(self, h, t, V): + >>> alpha = 0.07 * bm.exp(-(V + 58) / 20) + >>> beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) + >>> dhdt = self.phi * (alpha * (1 - h) - beta * h) + >>> return dhdt + >>> + >>> def dn(self, n, t, V): + >>> alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + >>> beta = 0.125 * bm.exp(-(V + 44) / 80) + >>> dndt = self.phi * (alpha * (1 - n) - beta * n) + >>> return dndt + >>> + >>> def dV(self, V, t, h, n, Iext): + >>> m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) + >>> m_beta = 4 * bm.exp(-(V + 60) / 18) + >>> m = m_alpha / (m_alpha + m_beta) + >>> INa = self.gNa * m ** 3 * h * (V - self.ENa) + >>> IK = self.gK * n ** 4 * (V - self.EK) + >>> IL = self.gL * (V - self.EL) + >>> dVdt = (- INa - IK - IL + Iext) / self.C + >>> + >>> return dVdt + >>> + >>> def update(self, tdi): + >>> h = self.int_h(self.h, tdi.t, self.V, dt=tdi.dt) + >>> n = self.int_n(self.n, tdi.t, self.V, dt=tdi.dt) + >>> V = self.int_V(self.V, tdi.t, self.h, self.n, self.input, dt=tdi.dt) + >>> self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + >>> self.V.value = V + >>> self.h.value = h + >>> self.n.value = n + >>> self.input[:] = 0. + >>> + >>> run = bp.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) + >>> run(100) + >>> bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) + + The above example can also be defined with ``brainpy.JointEq``. + + .. 
plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> + >>> class HH(bp.dyn.NeuDyn): + >>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., + >>> gL=0.1, V_th=20., phi=5.0, name=None): + >>> super(HH, self).__init__(size=size, name=name) + >>> + >>> # parameters + >>> self.ENa = ENa + >>> self.EK = EK + >>> self.EL = EL + >>> self.C = C + >>> self.gNa = gNa + >>> self.gK = gK + >>> self.gL = gL + >>> self.V_th = V_th + >>> self.phi = phi + >>> + >>> # variables + >>> self.V = bm.Variable(bm.ones(size) * -65.) + >>> self.h = bm.Variable(bm.ones(size) * 0.6) + >>> self.n = bm.Variable(bm.ones(size) * 0.32) + >>> self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + >>> self.input = bm.Variable(bm.zeros(size)) + >>> + >>> # functions + >>> derivative = bp.JointEq([self.dh, self.dn, self.dV]) + >>> self.integral = bp.ode.ExponentialEuler(derivative) + >>> + >>> def dh(self, h, t, V): + >>> alpha = 0.07 * bm.exp(-(V + 58) / 20) + >>> beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) + >>> dhdt = self.phi * (alpha * (1 - h) - beta * h) + >>> return dhdt + >>> + >>> def dn(self, n, t, V): + >>> alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + >>> beta = 0.125 * bm.exp(-(V + 44) / 80) + >>> dndt = self.phi * (alpha * (1 - n) - beta * n) + >>> return dndt + >>> + >>> def dV(self, V, t, h, n, Iext): + >>> m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) + >>> m_beta = 4 * bm.exp(-(V + 60) / 18) + >>> m = m_alpha / (m_alpha + m_beta) + >>> INa = self.gNa * m ** 3 * h * (V - self.ENa) + >>> IK = self.gK * n ** 4 * (V - self.EK) + >>> IL = self.gL * (V - self.EL) + >>> dVdt = (- INa - IK - IL + Iext) / self.C + >>> + >>> return dVdt + >>> + >>> def update(self, tdi): + >>> h, n, V = self.integral(self.h, self.n, self.V, tdi.t, self.input, dt=tdi.dt) + >>> self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + >>> self.V.value = V + >>> self.h.value = h + >>> self.n.value = n + >>> 
self.input[:] = 0. + >>> + >>> run = bp.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) + >>> run(100) + >>> bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) + + Parameters:: + + f : function, joint_eq.JointEq + The derivative function. + var_type : optional, str + The variable type. + dt : optional, float + The default numerical integration step. + name : optional, str + The integrator name. + """ + + def __init__( + self, + f, + var_type=None, + dt=None, + name=None, + show_code=False, + state_delays=None, + neutral_delays=None + ): + super(ExponentialEuler, self).__init__(f=f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays) + + if var_type == C.SYSTEM_VAR: + raise NotImplementedError(f'{self.__class__.__name__} does not support {C.SYSTEM_VAR}, ' + f'because the auto-differentiation ') + + # build the integrator + self.code_lines = [] + self.code_scope = {} + self.integral = self.build() + + def build(self): + parses = self._build_integrator(self.f) + all_vps = self.variables + self.parameters + + @wraps(self.f) + def integral_func(*args, **kwargs): + # format arguments + params_in = bm.Collector() + for i, arg in enumerate(args): + params_in[all_vps[i]] = arg + params_in.update(kwargs) + if C.DT not in params_in: + params_in[C.DT] = self.dt + + # call integrals + results = [] + for i, parse in enumerate(parses): + f_integral, vars_, pars_ = parse + vps = vars_ + pars_ + [C.DT] + r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in}) + results.append(r) + return results if len(self.variables) > 1 else results[0] + + return integral_func + + def _build_integrator(self, eq): + if isinstance(eq, joint_eq.JointEq): + results = [] + for sub_eq in eq.eqs: + results.extend(self._build_integrator(sub_eq)) + return results + else: + vars, pars, _ = utils.get_args(eq) + + # checking + if len(vars) != 1: + raise 
errors.DiffEqError(C.multi_vars_msg.format(cls=self.__class__.__name__, + vars=str(vars), + eq=str(eq))) + + # integration function + def integral(*args, **kwargs): + assert len(args) > 0 + if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]: + raise ValueError( + 'The input data type should be float32, float64, float16, or bfloat16 when using Exponential Euler method.' + f'But we got {args[0].dtype}.') + dt = kwargs.pop(C.DT, self.dt) + linear, derivative = bm.vector_grad(eq, argnums=0, return_value=True)(*args, **kwargs) + phi = bm.exprel(dt * linear) + return args[0] + dt * phi * derivative + + return [(integral, vars, pars), ] register_ode_integrator('exponential_euler', ExponentialEuler) diff --git a/brainpy/_src/integrators/ode/generic.py b/brainpy/_src/integrators/ode/generic.py index 63e03605c..90a323197 100644 --- a/brainpy/_src/integrators/ode/generic.py +++ b/brainpy/_src/integrators/ode/generic.py @@ -6,10 +6,10 @@ from .base import ODEIntegrator __all__ = [ - 'set_default_odeint', - 'get_default_odeint', - 'register_ode_integrator', - 'get_supported_methods', + 'set_default_odeint', + 'get_default_odeint', + 'register_ode_integrator', + 'get_supported_methods', ] name2method = { @@ -29,127 +29,127 @@ def odeint( neutral_delays: Dict[str, NeuTimeDelay] = None, **kwargs ): - """Numerical integration for ODEs. - - Examples:: - - .. 
plot:: - :include-source: True - - >>> import brainpy as bp - >>> import matplotlib.pyplot as plt - >>> - >>> a=0.7; b=0.8; tau=12.5; Vth=1.9 - >>> V = 0; w = 0 # initial values - >>> - >>> @bp.odeint(method='rk4', dt=0.04) - >>> def integral(V, w, t, Iext): - >>> dw = (V + a - b * w) / tau - >>> dV = V - V * V * V / 3 - w + Iext - >>> return dV, dw - >>> - >>> hist_V = [] - >>> for t in bp.math.arange(0, 100, integral.dt): - >>> V, w = integral(V, w, t, 0.5) - >>> hist_V.append(V) - >>> plt.plot(bp.math.arange(0, 100, integral.dt), hist_V) - >>> plt.show() - - - Parameters:: - - f : callable, function - The derivative function. - method : str - The shortcut name of the numerical integrator. - var_type: str - The type of the variable defined in the equation. - dt: float - The numerical integration precision. - name: str - The integrator node. - state_delays: dict - The state delay variable. - show_code: bool - Show the formated code. - adaptive: bool - The use adaptive mode. - tol: float - The tolerence to adapt new step size. - - Returns:: - - integral : ODEIntegrator - The numerical solver of `f`. - """ - method = _DEFAULT_DDE_METHOD if method is None else method - if method not in name2method: - raise ValueError(f'Unknown ODE numerical method "{method}". Currently ' - f'BrainPy only support: {list(name2method.keys())}') - - if f is None: - return lambda f: name2method[method](f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays, - **kwargs) - else: - return name2method[method](f, - var_type=var_type, - dt=dt, - name=name, - show_code=show_code, - state_delays=state_delays, - neutral_delays=neutral_delays, - **kwargs) + """Numerical integration for ODEs. + + Examples:: + + .. 
plot:: + :include-source: True + + >>> import brainpy as bp + >>> import matplotlib.pyplot as plt + >>> + >>> a=0.7; b=0.8; tau=12.5; Vth=1.9 + >>> V = 0; w = 0 # initial values + >>> + >>> @bp.odeint(method='rk4', dt=0.04) + >>> def integral(V, w, t, Iext): + >>> dw = (V + a - b * w) / tau + >>> dV = V - V * V * V / 3 - w + Iext + >>> return dV, dw + >>> + >>> hist_V = [] + >>> for t in bp.math.arange(0, 100, integral.dt): + >>> V, w = integral(V, w, t, 0.5) + >>> hist_V.append(V) + >>> plt.plot(bp.math.arange(0, 100, integral.dt), hist_V) + >>> plt.show() + + + Parameters:: + + f : callable, function + The derivative function. + method : str + The shortcut name of the numerical integrator. + var_type: str + The type of the variable defined in the equation. + dt: float + The numerical integration precision. + name: str + The integrator node. + state_delays: dict + The state delay variable. + show_code: bool + Show the formated code. + adaptive: bool + The use adaptive mode. + tol: float + The tolerence to adapt new step size. + + Returns:: + + integral : ODEIntegrator + The numerical solver of `f`. + """ + method = _DEFAULT_DDE_METHOD if method is None else method + if method not in name2method: + raise ValueError(f'Unknown ODE numerical method "{method}". Currently ' + f'BrainPy only support: {list(name2method.keys())}') + + if f is None: + return lambda f: name2method[method](f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays, + **kwargs) + else: + return name2method[method](f, + var_type=var_type, + dt=dt, + name=name, + show_code=show_code, + state_delays=state_delays, + neutral_delays=neutral_delays, + **kwargs) def set_default_odeint(method): - """Set the default ODE numerical integrator method for differential equations. + """Set the default ODE numerical integrator method for differential equations. - Parameters:: - - method : str, callable - Numerical integrator method. 
- """ - if not isinstance(method, str): - raise ValueError(f'Only support string, not {type(method)}.') - if method not in name2method: - raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') + Parameters:: - global _DEFAULT_DDE_METHOD - _DEFAULT_ODE_METHOD = method + method : str, callable + Numerical integrator method. + """ + if not isinstance(method, str): + raise ValueError(f'Only support string, not {type(method)}.') + if method not in name2method: + raise ValueError(f'Unsupported ODE_INT numerical method: {method}.') + + global _DEFAULT_DDE_METHOD + _DEFAULT_ODE_METHOD = method def get_default_odeint(): - """Get the default ODE numerical integrator method. + """Get the default ODE numerical integrator method. + + Returns:: - Returns:: - - method : str - The default numerical integrator method. - """ - return _DEFAULT_DDE_METHOD + method : str + The default numerical integrator method. + """ + return _DEFAULT_DDE_METHOD def register_ode_integrator(name, integrator): - """Register a new ODE integrator. + """Register a new ODE integrator. 
+ + Parameters:: - Parameters:: - - name: ste - integrator: type - """ - if name in name2method: - raise ValueError(f'"{name}" has been registered in ODE integrators.') - if not issubclass(integrator, ODEIntegrator): - raise ValueError(f'"integrator" must be an instance of {ODEIntegrator.__name__}') - name2method[name] = integrator + name: ste + integrator: type + """ + if name in name2method: + raise ValueError(f'"{name}" has been registered in ODE integrators.') + if not issubclass(integrator, ODEIntegrator): + raise ValueError(f'"integrator" must be an instance of {ODEIntegrator.__name__}') + name2method[name] = integrator def get_supported_methods(): - """Get all supported numerical methods for DDEs.""" - return list(name2method.keys()) + """Get all supported numerical methods for DDEs.""" + return list(name2method.keys()) diff --git a/brainpy/_src/integrators/ode/tests/test_delay_ode.py b/brainpy/_src/integrators/ode/tests/test_delay_ode.py index 4759924bb..5b07fd6b8 100644 --- a/brainpy/_src/integrators/ode/tests/test_delay_ode.py +++ b/brainpy/_src/integrators/ode/tests/test_delay_ode.py @@ -13,25 +13,27 @@ def delay_odeint(duration, eq, args=None, inits=None, state_delays=None, neutral_delays=None, monitors=('x',), method='euler', dt=0.1): - # define integrators of ODEs based on `brainpy.odeint` - dde = bp.odeint(eq, - state_delays=state_delays, - neutral_delays=neutral_delays, - method=method) - # define IntegratorRunner - runner = bp.IntegratorRunner(dde, - args=args, - monitors=monitors, - dt=dt, - inits=inits, - progress_bar=False) - runner.run(duration) - return runner.mon + # define integrators of ODEs based on `brainpy.odeint` + dde = bp.odeint(eq, + state_delays=state_delays, + neutral_delays=neutral_delays, + method=method) + # define IntegratorRunner + runner = bp.IntegratorRunner(dde, + args=args, + monitors=monitors, + dt=dt, + inits=inits, + progress_bar=False) + runner.run(duration) + return runner.mon + def get_eq1(xdelay): - def eq1(x, t): - 
return -xdelay(t - 1) - return eq1 + def eq1(x, t): + return -xdelay(t - 1) + + return eq1 case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') @@ -41,9 +43,10 @@ def eq1(x, t): def get_eq2(xdelay): - def eq2(x, t): - return -xdelay(t - 2) - return eq2 + def eq2(x, t): + return -xdelay(t - 2) + + return eq2 delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: jnp.exp(-t) - 1, dt=0.01, interp_method='round') @@ -53,70 +56,68 @@ def eq2(x, t): class TestFirstOrderConstantDelay(parameterized.TestCase): - def __init__(self, *args, **kwargs): - super(TestFirstOrderConstantDelay, self).__init__(*args, **kwargs) - - @parameterized.named_parameters( - {'testcase_name': f'constant_delay_{name}', - 'method': name} - for name in get_supported_methods() - ) - def test1(self, method): - bm.random.seed() - case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') - case2_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='linear_interp') - - case1 = delay_odeint(20., get_eq1(case1_delay), state_delays={'x': case1_delay}, method=method) - case2 = delay_odeint(20., get_eq1(case2_delay), state_delays={'x': case2_delay}, method=method) - - print(method) - print("case1.keys()", case1.keys()) - print("case2.keys()", case2.keys()) - print("self.ref1.keys()", ref1.keys()) - print("self.ref2.keys()", ref2.keys()) - - # self.assertTrue((case1['x'] - self.ref1['x']).mean() < 1e-3) - # self.assertTrue((case2['x'] - self.ref2['x']).mean() < 1e-3) - - # fig, axs = plt.subplots(2, 1) - # fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0) - # fig.suptitle("$y'(t)=-y(t-1)$") - # axs[0].plot(case1.ts, case1.x, color='red', linewidth=1) - # axs[0].set_title('$ihf(t)=-1$') - # axs[1].plot(case2.ts, case2.x, color='red', linewidth=1) - # axs[1].set_title('$ihf(t)=0$') - # plt.show(block=block) - # plt.close() - - + def __init__(self, *args, **kwargs): + super(TestFirstOrderConstantDelay, self).__init__(*args, **kwargs) 
+ + @parameterized.named_parameters( + {'testcase_name': f'constant_delay_{name}', + 'method': name} + for name in get_supported_methods() + ) + def test1(self, method): + bm.random.seed() + case1_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='round') + case2_delay = bm.TimeDelay(bm.zeros((1,)), 1., before_t0=-1., interp_method='linear_interp') + + case1 = delay_odeint(20., get_eq1(case1_delay), state_delays={'x': case1_delay}, method=method) + case2 = delay_odeint(20., get_eq1(case2_delay), state_delays={'x': case2_delay}, method=method) + + print(method) + print("case1.keys()", case1.keys()) + print("case2.keys()", case2.keys()) + print("self.ref1.keys()", ref1.keys()) + print("self.ref2.keys()", ref2.keys()) + + # self.assertTrue((case1['x'] - self.ref1['x']).mean() < 1e-3) + # self.assertTrue((case2['x'] - self.ref2['x']).mean() < 1e-3) + + # fig, axs = plt.subplots(2, 1) + # fig.tight_layout(rect=[0, 0, 1, 0.95], pad=3.0) + # fig.suptitle("$y'(t)=-y(t-1)$") + # axs[0].plot(case1.ts, case1.x, color='red', linewidth=1) + # axs[0].set_title('$ihf(t)=-1$') + # axs[1].plot(case2.ts, case2.x, color='red', linewidth=1) + # axs[1].set_title('$ihf(t)=0$') + # plt.show(block=block) + # plt.close() class TestNonConstantHist(parameterized.TestCase): - def get_eq(self, xdelay): - def eq(x, t): - return -xdelay(t - 2) - return eq - - def __init__(self, *args, **kwargs): - super(TestNonConstantHist, self).__init__(*args, **kwargs) - - @parameterized.named_parameters( - {'testcase_name': f'constant_delay_{name}', 'method': name} - for name in get_supported_methods() - ) - def test1(self, method): - bm.random.seed() - - delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: jnp.exp(-t) - 1, dt=0.01, interp_method='round') - delay2 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: jnp.exp(-t) - 1, dt=0.01) - case1 = delay_odeint(4., self.get_eq(delay1), state_delays={'x': delay1}, dt=0.01, method=method) - case2 = delay_odeint(4., 
self.get_eq(delay2), state_delays={'x': delay2}, dt=0.01, method=method) - - print("case1.keys()", case1.keys()) - print("case2.keys()", case2.keys()) - print("ref3.keys()", ref3.keys()) - print("ref4.keys()", ref4.keys()) - - # self.assertTrue((case1['x'] - self.ref1['x']).mean() < 1e-1) - # self.assertTrue((case2['x'] - self.ref2['x']).mean() < 1e-1) - + def get_eq(self, xdelay): + def eq(x, t): + return -xdelay(t - 2) + + return eq + + def __init__(self, *args, **kwargs): + super(TestNonConstantHist, self).__init__(*args, **kwargs) + + @parameterized.named_parameters( + {'testcase_name': f'constant_delay_{name}', 'method': name} + for name in get_supported_methods() + ) + def test1(self, method): + bm.random.seed() + + delay1 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: jnp.exp(-t) - 1, dt=0.01, interp_method='round') + delay2 = bm.TimeDelay(bm.zeros(1), 2., before_t0=lambda t: jnp.exp(-t) - 1, dt=0.01) + case1 = delay_odeint(4., self.get_eq(delay1), state_delays={'x': delay1}, dt=0.01, method=method) + case2 = delay_odeint(4., self.get_eq(delay2), state_delays={'x': delay2}, dt=0.01, method=method) + + print("case1.keys()", case1.keys()) + print("case2.keys()", case2.keys()) + print("ref3.keys()", ref3.keys()) + print("ref4.keys()", ref4.keys()) + + # self.assertTrue((case1['x'] - self.ref1['x']).mean() < 1e-1) + # self.assertTrue((case2['x'] - self.ref2['x']).mean() < 1e-1) diff --git a/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_adaptive_rk.py b/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_adaptive_rk.py index 9579593dd..8e3116e96 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_adaptive_rk.py +++ b/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_adaptive_rk.py @@ -5,337 +5,336 @@ import pytest from brainpy import errors -from brainpy.integrators import ode from brainpy import odeint class TestAdaptiveRKKeywords(unittest.TestCase): - def test_bs(self): - method = 'bs' - for adaptive in [True, 
False]: + def test_bs(self): + method = 'bs' + for adaptive in [True, False]: - print(f'Test {"adaptive" if adaptive else ""} {method} method:') - print() - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) + print(f'Test {"adaptive" if adaptive else ""} {method} method:') + print() + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda 
v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) - if adaptive: - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) + if adaptive: + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) - print('-' * 40) + print('-' * 40) - def test_rkf45(self): - 
method = 'rkf45' - for adaptive in [True, False]: + def test_rkf45(self): + method = 'rkf45' + for adaptive in [True, False]: - print(f'Test {"adaptive" if adaptive else ""} {method} method:') - print() - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) + print(f'Test {"adaptive" if adaptive else ""} {method} method:') + print() + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - 
odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) - if adaptive: - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) + if adaptive: + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: 
t) - print('-' * 40) + print('-' * 40) - def test_heun_euler(self): - method = 'heun_euler' - for adaptive in [True, False]: + def test_heun_euler(self): + method = 'heun_euler' + for adaptive in [True, False]: - print(f'Test {"adaptive" if adaptive else ""} {method} method:') - print() - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) + print(f'Test {"adaptive" if adaptive else ""} {method} method:') + print() + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, 
f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) - if adaptive: - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) + if adaptive: + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) - print('-' * 40) + print('-' * 40) - def test_rkf12(self): - method = 'rkf12' - for adaptive in [True, False]: + def test_rkf12(self): + method = 'rkf12' + for adaptive in [True, False]: - print(f'Test {"adaptive" if adaptive else ""} {method} method:') - print() - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) + print(f'Test {"adaptive" if adaptive else ""} {method} method:') + print() + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) + # with pytest.raises(errors.CodeError): + # 
odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) - if adaptive: - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) + if adaptive: + with 
pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) - print('-' * 40) + print('-' * 40) - def test_ck(self): - method = 'ck' - for adaptive in [True, False]: + def test_ck(self): + method = 'ck' + for adaptive in [True, False]: - print(f'Test {"adaptive" if adaptive else ""} {method} method:') - print() - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) + print(f'Test {"adaptive" if adaptive else ""} {method} method:') + print() + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) + with 
pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, 
t, k5_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k5_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k5_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k5_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_t_arg: t) - if adaptive: - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) + if adaptive: + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) - print('-' * 40) + print('-' * 40) - def test_rkdp(self): - method = 'rkdp' - for adaptive in [True, False]: + def test_rkdp(self): + method = 'rkdp' + for adaptive in [True, False]: - print(f'Test {"adaptive" if adaptive else ""} {method} method:') - print() - odeint(method=method, show_code=True, adaptive=adaptive, 
f=lambda v, t, p: t) + print(f'Test {"adaptive" if adaptive else ""} {method} method:') + print() + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, p: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda f, t, dt: t) - # with pytest.raises(errors.CodeError): - # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) + # with pytest.raises(errors.CodeError): + # odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k2_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, 
show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k3_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k4_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k5_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k5_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k5_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k5_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k6_t_arg: t) - with 
pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k7_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k7_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k7_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, k7_t_arg: t) - if adaptive: - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) + if adaptive: + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, error: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, v_te: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, adaptive=adaptive, f=lambda v, t, dt_new: t) - print('-' * 40) + print('-' * 40) diff --git a/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_exp_euler.py b/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_exp_euler.py index 0d10a29da..6ae4887d0 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_exp_euler.py +++ b/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_exp_euler.py @@ -10,21 +10,21 @@ class TestExponentialEuler(unittest.TestCase): - def test1(self): - def func(m, t, V): - alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10)) - beta = 4.0 * np.exp(-(V + 65) / 18) - dmdt = alpha * (1 - m) - beta * m - return dmdt - - odeint(method='exponential_euler', show_code=True, f=func) - - def test3(self): - with 
pytest.raises(errors.CodeError): - def func(m, t, dt): - alpha = 0.1 * (dt + 40) / (1 - np.exp(-(dt + 40) / 10)) - beta = 4.0 * np.exp(-(dt + 65) / 18) - dmdt = alpha * (1 - m) - beta * m - return dmdt - - odeint(method='exponential_euler', show_code=True, f=func) + def test1(self): + def func(m, t, V): + alpha = 0.1 * (V + 40) / (1 - np.exp(-(V + 40) / 10)) + beta = 4.0 * np.exp(-(V + 65) / 18) + dmdt = alpha * (1 - m) - beta * m + return dmdt + + odeint(method='exponential_euler', show_code=True, f=func) + + def test3(self): + with pytest.raises(errors.CodeError): + def func(m, t, dt): + alpha = 0.1 * (dt + 40) / (1 - np.exp(-(dt + 40) / 10)) + beta = 4.0 * np.exp(-(dt + 65) / 18) + dmdt = alpha * (1 - m) - beta * m + return dmdt + + odeint(method='exponential_euler', show_code=True, f=func) diff --git a/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_general_rk.py b/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_general_rk.py index 323377905..05a7b974c 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_general_rk.py +++ b/brainpy/_src/integrators/ode/tests/test_ode_keywords_for_general_rk.py @@ -9,165 +9,165 @@ class TestExplicitRKKeywords(unittest.TestCase): - def test_euler(self): - print('Test Euler method:') - print() - odeint(method='euler', show_code=True, f=lambda v, t, p: t) + def test_euler(self): + print('Test Euler method:') + print() + odeint(method='euler', show_code=True, f=lambda v, t, p: t) - with pytest.raises(errors.CodeError): - odeint(method='euler', show_code=True, f=lambda f, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method='euler', show_code=True, f=lambda f, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method='euler', show_code=True, f=lambda v, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method='euler', show_code=True, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method='euler', show_code=True, f=lambda v, t, v_new: t) + with 
pytest.raises(errors.CodeError): + odeint(method='euler', show_code=True, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method='euler', show_code=True, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method='euler', show_code=True, f=lambda v, t, dv_k1: t) - print('-' * 40) + print('-' * 40) - def test_order2_rk(self): - for method in ['heun2', 'midpoint', 'ralston2']: - print(f'Test {method} method:') - print() - odeint(method=method, show_code=True, f=lambda v, t, p: t) + def test_order2_rk(self): + for method in ['heun2', 'midpoint', 'ralston2']: + print(f'Test {method} method:') + print() + odeint(method=method, show_code=True, f=lambda v, t, p: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda f, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda f, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) - with 
pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: t) - print('-' * 40) + print('-' * 40) - def test_rk2(self): - method = 'rk2' + def test_rk2(self): + method = 'rk2' - print(f'Test {method} method:') - print() - odeint(method=method, show_code=True, f=lambda v, t, p: t) + print(f'Test {method} method:') + print() + odeint(method=method, show_code=True, f=lambda v, t, p: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda f, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda f, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: 
t) - print('-' * 40) + print('-' * 40) - def test_order3_rk(self): - for method in ['rk3', 'heun3', 'ralston3', 'ssprk3']: - print(f'Test {method} method:') - print() - odeint(method=method, show_code=True, f=lambda v, t, p: t) + def test_order3_rk(self): + for method in ['rk3', 'heun3', 'ralston3', 'ssprk3']: + print(f'Test {method} method:') + print() + odeint(method=method, show_code=True, f=lambda v, t, p: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda f, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda f, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k3: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k3: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: t) + with 
pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k3_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k3_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k3_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k3_t_arg: t) - print('-' * 40) + print('-' * 40) - def test_order4_rk(self): - for method in ['rk4', 'ralston4', 'rk4_38rule']: - print(f'Test {method} method:') - print() - odeint(method=method, show_code=True, f=lambda v, t, p: t) + def test_order4_rk(self): + for method in ['rk4', 'ralston4', 'rk4_38rule']: + print(f'Test {method} method:') + print() + odeint(method=method, show_code=True, f=lambda v, t, p: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda f, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda f, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dt: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dt: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, v_new: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, v_new: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k1: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k2: t) - with 
pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k3: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k3: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, dv_k4: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, dv_k4: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k2_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k3_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k3_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k3_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k3_t_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k4_v_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k4_v_arg: t) - with pytest.raises(errors.CodeError): - odeint(method=method, show_code=True, f=lambda v, t, k4_t_arg: t) + with pytest.raises(errors.CodeError): + odeint(method=method, show_code=True, f=lambda v, t, k4_t_arg: t) - print('-' * 40) + print('-' * 40) diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py b/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py index 859478e39..1f9de54ee 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py 
+++ b/brainpy/_src/integrators/ode/tests/test_ode_method_adaptive_rk.py @@ -2,8 +2,8 @@ import unittest -import numpy as np import matplotlib.pyplot as plt +import numpy as np import brainpy.math as bm from brainpy._src.integrators.ode import adaptive_rk diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py index 6064aa3e2..57bd45835 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py +++ b/brainpy/_src/integrators/ode/tests/test_ode_method_exp_euler.py @@ -12,118 +12,117 @@ class TestExpnentialEuler(unittest.TestCase): - def test_hh_model(self): - def drivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C): - alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) - beta = 4.0 * bm.exp(-(V + 65) / 18) - dmdt = alpha * (1 - m) - beta * m + def test_hh_model(self): + def drivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C): + alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + beta = 4.0 * bm.exp(-(V + 65) / 18) + dmdt = alpha * (1 - m) - beta * m - alpha = 0.07 * bm.exp(-(V + 65) / 20.) - beta = 1 / (1 + bm.exp(-(V + 35) / 10)) - dhdt = alpha * (1 - h) - beta * h + alpha = 0.07 * bm.exp(-(V + 65) / 20.) 
+ beta = 1 / (1 + bm.exp(-(V + 35) / 10)) + dhdt = alpha * (1 - h) - beta * h - alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) - beta = 0.125 * bm.exp(-(V + 65) / 80) - dndt = alpha * (1 - n) - beta * n + alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + beta = 0.125 * bm.exp(-(V + 65) / 80) + dndt = alpha * (1 - n) - beta * n - I_Na = (gNa * m ** 3.0 * h) * (V - ENa) - I_K = (gK * n ** 4.0) * (V - EK) - I_leak = gL * (V - EL) - dVdt = (- I_Na - I_K - I_leak + Iext) / C + I_Na = (gNa * m ** 3.0 * h) * (V - ENa) + I_K = (gK * n ** 4.0) * (V - EK) + I_leak = gL * (V - EL) + dVdt = (- I_Na - I_K - I_leak + Iext) / C - return dVdt, dmdt, dhdt, dndt + return dVdt, dmdt, dhdt, dndt - with self.assertRaises(bp.errors.DiffEqError): - ExponentialEuler(f=drivative, show_code=True, dt=0.01, var_type='SCALAR') + with self.assertRaises(bp.errors.DiffEqError): + ExponentialEuler(f=drivative, show_code=True, dt=0.01, var_type='SCALAR') - def test1(self): - def dev(x, t): - dx = bm.power(x, 3) - return dx + def test1(self): + def dev(x, t): + dx = bm.power(x, 3) + return dx - ExponentialEuler(f=dev, show_code=True, dt=0.01) + ExponentialEuler(f=dev, show_code=True, dt=0.01) class TestExpEulerAuto(unittest.TestCase): - def test_hh_model(self): - class HH(bp.dyn.NeuDyn): - def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., - gL=0.1, V_th=20., phi=5.0, name=None, method='exponential_euler'): - super(HH, self).__init__(size=size, name=name) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - self.phi = phi - - # variables - self.V = bm.Variable(bm.ones(size) * -65.) 
- self.h = bm.Variable(bm.ones(size) * 0.6) - self.n = bm.Variable(bm.ones(size) * 0.32) - self.spike = bm.Variable(bm.zeros(size, dtype=bool)) - self.input = bm.Variable(bm.zeros(size)) - - self.integral = bp.odeint(bp.JointEq(self.dV, self.dh, self.dn), method=method, show_code=True) - - def dh(self, h, t, V): - alpha = 0.07 * bm.exp(-(V + 58) / 20) - beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) - dhdt = self.phi * (alpha * (1 - h) - beta * h) - return dhdt - - def dn(self, n, t, V): - alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) - beta = 0.125 * bm.exp(-(V + 44) / 80) - dndt = self.phi * (alpha * (1 - n) - beta * n) - return dndt - - def dV(self, V, t, h, n, Iext): - m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) - m_beta = 4 * bm.exp(-(V + 60) / 18) - m = m_alpha / (m_alpha + m_beta) - INa = self.gNa * m ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + Iext) / self.C - - return dVdt - - def update(self): - t, dt = bp.share['t'], bp.share['dt'] - V, h, n = self.integral(self.V.value, self.h.value, self.n.value, t, self.input.value, dt=dt) - self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) - self.V.value = V - self.h.value = h - self.n.value = n - self.input[:] = 0. 
- - bm.random.seed() - hh1 = HH(1, method='exp_euler') - runner1 = bp.DSRunner(hh1, inputs=('input', 2.), monitors=['V', 'h', 'n']) - runner1.run(100) - plt.figure() - plt.plot(runner1.mon.ts, runner1.mon.V, label='V') - plt.plot(runner1.mon.ts, runner1.mon.h, label='h') - plt.plot(runner1.mon.ts, runner1.mon.n, label='n') - plt.show(block=block) - - hh2 = HH(1, method='exp_euler_auto') - runner2 = bp.DSRunner(hh2, inputs=('input', 2.), monitors=['V', 'h', 'n']) - runner2.run(100) - plt.figure() - plt.plot(runner2.mon.ts, runner2.mon.V, label='V') - plt.plot(runner2.mon.ts, runner2.mon.h, label='h') - plt.plot(runner2.mon.ts, runner2.mon.n, label='n') - plt.show(block=block) - - diff = (runner2.mon.V - runner1.mon.V).mean() - self.assertTrue(diff < 1e0) - - plt.close() - + def test_hh_model(self): + class HH(bp.dyn.NeuDyn): + def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., + gL=0.1, V_th=20., phi=5.0, name=None, method='exponential_euler'): + super(HH, self).__init__(size=size, name=name) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + self.phi = phi + + # variables + self.V = bm.Variable(bm.ones(size) * -65.) 
+ self.h = bm.Variable(bm.ones(size) * 0.6) + self.n = bm.Variable(bm.ones(size) * 0.32) + self.spike = bm.Variable(bm.zeros(size, dtype=bool)) + self.input = bm.Variable(bm.zeros(size)) + + self.integral = bp.odeint(bp.JointEq(self.dV, self.dh, self.dn), method=method, show_code=True) + + def dh(self, h, t, V): + alpha = 0.07 * bm.exp(-(V + 58) / 20) + beta = 1 / (bm.exp(-0.1 * (V + 28)) + 1) + dhdt = self.phi * (alpha * (1 - h) - beta * h) + return dhdt + + def dn(self, n, t, V): + alpha = -0.01 * (V + 34) / (bm.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * bm.exp(-(V + 44) / 80) + dndt = self.phi * (alpha * (1 - n) - beta * n) + return dndt + + def dV(self, V, t, h, n, Iext): + m_alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * bm.exp(-(V + 60) / 18) + m = m_alpha / (m_alpha + m_beta) + INa = self.gNa * m ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + Iext) / self.C + + return dVdt + + def update(self): + t, dt = bp.share['t'], bp.share['dt'] + V, h, n = self.integral(self.V.value, self.h.value, self.n.value, t, self.input.value, dt=dt) + self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) + self.V.value = V + self.h.value = h + self.n.value = n + self.input[:] = 0. 
+ + bm.random.seed() + hh1 = HH(1, method='exp_euler') + runner1 = bp.DSRunner(hh1, inputs=('input', 2.), monitors=['V', 'h', 'n']) + runner1.run(100) + plt.figure() + plt.plot(runner1.mon.ts, runner1.mon.V, label='V') + plt.plot(runner1.mon.ts, runner1.mon.h, label='h') + plt.plot(runner1.mon.ts, runner1.mon.n, label='n') + plt.show(block=block) + + hh2 = HH(1, method='exp_euler_auto') + runner2 = bp.DSRunner(hh2, inputs=('input', 2.), monitors=['V', 'h', 'n']) + runner2.run(100) + plt.figure() + plt.plot(runner2.mon.ts, runner2.mon.V, label='V') + plt.plot(runner2.mon.ts, runner2.mon.h, label='h') + plt.plot(runner2.mon.ts, runner2.mon.n, label='n') + plt.show(block=block) + + diff = (runner2.mon.V - runner1.mon.V).mean() + self.assertTrue(diff < 1e0) + + plt.close() diff --git a/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py b/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py index 9a7e2bf8a..ba8b1c5d7 100644 --- a/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py +++ b/brainpy/_src/integrators/ode/tests/test_ode_method_rk.py @@ -2,9 +2,9 @@ import unittest +import jax import numpy as np -import jax import brainpy.math as bm from brainpy._src.integrators.ode import explicit_rk @@ -46,7 +46,6 @@ def step(t): jit=True ) - mon_x, mon_y, mon_z = results mon_x = np.array(mon_x).flatten() mon_y = np.array(mon_y).flatten() @@ -87,4 +86,3 @@ def test_all_methods(self): assert np.linalg.norm(mon_x - _baseline_x) / (duration / dt) < 0.1 assert np.linalg.norm(mon_y - _baseline_y) / (duration / dt) < 0.1 assert np.linalg.norm(mon_z - _baseline_z) / (duration / dt) < 0.1 - diff --git a/brainpy/_src/integrators/pde/base.py b/brainpy/_src/integrators/pde/base.py index 492a6c7de..72ae4c602 100644 --- a/brainpy/_src/integrators/pde/base.py +++ b/brainpy/_src/integrators/pde/base.py @@ -4,5 +4,4 @@ class PDEIntegrator(Integrator): - pass - + pass diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index 8939b20b1..cbe791ce9 
100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -18,318 +18,318 @@ from .base import Integrator __all__ = [ - 'IntegratorRunner', + 'IntegratorRunner', ] class IntegratorRunner(Runner): - """Structural runner for numerical integrators in brainpy. - - Examples:: - - Example to run an ODE integrator, - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> a=0.7; b=0.8; tau=12.5 - >>> dV = lambda V, t, w, I: V - V * V * V / 3 - w + I - >>> dw = lambda w, t, V, a, b: (V + a - b * w) / tau - >>> integral = bp.odeint(bp.JointEq([dV, dw]), method='exp_auto') - >>> - >>> runner = bp.IntegratorRunner( - >>> integral, # the simulation target - >>> monitors=['V', 'w'], # the variables to monitor - >>> inits={'V': bm.random.rand(10), - >>> 'w': bm.random.normal(size=10)}, # the initial values - >>> ) - >>> runner.run(100., - >>> args={'a': 1., 'b': 1.}, # update arguments - >>> dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}, # each time each current input - >>> ) - >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=[0, 1, 4], show=True) - - Example to run an SDE intragetor, - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> # stochastic Lorenz system - >>> sigma=10; beta=8 / 3; rho=28 - >>> g = lambda x, y, z, t, p: (p * x, p * y, p * z) - >>> f = lambda x, y, z, t, p: [sigma * (y - x), x * (rho - z) - y, x * y - beta * z] - >>> lorenz = bp.sdeint(f, g, method='milstein2') - >>> - >>> runner = bp.IntegratorRunner( - >>> lorenz, - >>> monitors=['x', 'y', 'z'], - >>> inits=[1., 1., 1.], # initialize all variable to 1. 
- >>> dt=0.01 - >>> ) - >>> runner.run(100., args={'p': 0.1},) - >>> - >>> import matplotlib.pyplot as plt - >>> fig = plt.figure() - >>> ax = fig.gca(projection='3d') - >>> plt.plot(runner.mon.x.squeeze(), runner.mon.y.squeeze(), runner.mon.z.squeeze()) - >>> ax.set_xlabel('x') - >>> ax.set_xlabel('y') - >>> ax.set_xlabel('z') - >>> plt.show() - - """ - - def __init__( - self, - target: Integrator, - - # IntegratorRunner specific arguments - inits: Union[Sequence, Dict] = None, - - # regular/common arguments - dt: Union[float, int] = None, - monitors: Sequence[str] = None, - dyn_vars: Dict[str, bm.Variable] = None, - jit: Union[bool, Dict[str, bool]] = True, - numpy_mon_after_run: bool = True, - progress_bar: bool = True, - - # deprecated - args: Dict = None, - dyn_args: Dict[str, Union[bm.ndarray, jnp.ndarray]] = None, - fun_monitors: Dict[str, Callable] = None, - ): - """Initialization of structural runner for integrators. - - Parameters:: - - target: Integrator - The target to run. - monitors: sequence of str - The variables to monitor. - fun_monitors: dict - The monitors with callable functions. - .. deprecated:: 2.3.1 - inits: sequence, dict - The initial value of variables. With this parameter, - you can easily control the number of variables to simulate. - For example, if one of the variable has the shape of 10, - then all variables will be an instance of :py:class:`brainpy.math.Variable` - with the shape of :math:`(10,)`. - args: dict - The equation arguments to update. - Note that if one of the arguments are heterogeneous (i.e., a tensor), - it means we should run multiple trials. However, you can set the number - of the elements in the variables so that each pair of variables can - correspond to one set of arguments. - - .. deprecated:: 2.3.1 - Will be removed after version 2.4.0. - - dyn_args: dict - The dynamically changed arguments. This means this argument can control - the argument dynamically changed. 
For example, if you want to inject a - time varied currents into the HH neuron model, you can pack the currents - into this ``dyn_args`` argument. - - .. deprecated:: 2.3.1 - Will be removed after version 2.4.0. - - dt: float, int - dyn_vars: dict - jit: bool - progress_bar: bool - numpy_mon_after_run: bool - """ - - if not isinstance(target, Integrator): - raise TypeError(f'Target must be instance of {Integrator.__name__}, ' - f'but we got {type(target)}') - # get maximum size and initial variables - if inits is not None: - if isinstance(inits, (list, tuple, bm.BaseArray, jnp.ndarray)): - assert len(target.variables) == len(inits) - inits = {k: inits[i] for i, k in enumerate(target.variables)} - assert isinstance(inits, dict), f'"inits" must be a dict, but we got {type(inits)}' - sizes = np.unique([np.size(v) for v in list(inits.values())]) - max_size = np.max(sizes) - else: - max_size = 1 - inits = dict() - - # initialize variables - self.variables = {v: bm.Variable(bm.zeros(max_size)) for v in target.variables} - for k in inits.keys(): - self.variables[k][:] = inits[k] - - # format string monitors - if isinstance(monitors, (tuple, list)): - monitors = self._format_seq_monitors(monitors) - monitors = {k: (self.variables[k], i) for k, i in monitors} - elif isinstance(monitors, dict): - monitors = self._format_dict_monitors(monitors) - monitors = {k: ((self.variables[i], i) if isinstance(i, str) else i) for k, i in monitors.items()} - else: - raise ValueError - - # initialize super class - super(IntegratorRunner, self).__init__(target=target, - monitors=monitors, - fun_monitors=fun_monitors, - jit=jit, - progress_bar=progress_bar, - dyn_vars=dyn_vars, - numpy_mon_after_run=numpy_mon_after_run) - - self.register_implicit_vars(self.variables) - - # parameters - dt = bm.get_dt() if dt is None else dt - if not isinstance(dt, (int, float)): - raise RunningError(f'"dt" must be scalar, but got {dt}') - self.dt = dt - - # target - if not isinstance(self.target, 
Integrator): - raise RunningError(f'"target" must be an instance of {Integrator.__name__}, ' - f'but we got {type(target)}: {target}') - - # arguments of the integral function - if args is not None: - warnings.warn('Set "args" in `IntegratorRunner.run()` function, instead of __init__ function. ' - 'Will be removed since 2.4.0', - UserWarning) - assert isinstance(args, dict), (f'"args" must be a dict, but ' - f'we got {type(args)}: {args}') - self._static_args = args - else: - self._static_args = dict() - if dyn_args is not None: - warnings.warn('Set "dyn_args" in `IntegratorRunner.run()` function, instead of __init__ function. ' - 'Will be removed since 2.4.0', - UserWarning) - assert isinstance(dyn_args, dict), (f'"dyn_args" must be a dict, but we get ' - f'{type(dyn_args)}: {dyn_args}') - sizes = np.unique([len(v) for v in dyn_args.values()]) - num_size = len(sizes) - if num_size != 1: - raise RunningError(f'All values in "dyn_args" should have the same length. ' - f'But we got {num_size}: {sizes}') - self._dyn_args = dyn_args - else: - self._dyn_args = dict() - - # start simulation time and index - self.start_t = bm.Variable(bm.zeros(1)) - self.idx = bm.Variable(bm.zeros(1, dtype=bm.int_)) - - def _run_fun_integration(self, static_args, dyn_args, times, indices): - return bm.for_loop(partial(self._step_fun_integrator, static_args), - (dyn_args, times, indices), - jit=self.jit['predict']) - - def _step_fun_integrator(self, static_args, dyn_args, t, i): - # arguments - kwargs = Collector(dt=self.dt, t=t) - kwargs.update(static_args) - kwargs.update(dyn_args) - kwargs.update({k: v.value for k, v in self.variables.items()}) - - # call integrator function - update_values = self.target(**kwargs) - if len(self.target.variables) == 1: - self.variables[self.target.variables[0]].update(update_values) - else: - for i, v in enumerate(self.target.variables): - self.variables[v].update(update_values[i]) - - # progress bar - if self.progress_bar: - jax.debug.callback(lambda 
*args: self._pbar.update(), ()) - - # return of function monitors - shared = dict(t=t + self.dt, dt=self.dt, i=i) - returns = dict() - for k, v in self._monitors.items(): - if callable(v): - returns[k] = bm.as_jax(v(shared)) - else: - returns[k] = self.variables[k].value - return returns - - def run( - self, - duration: float, - start_t: float = None, - eval_time: bool = False, - args: Dict = None, - dyn_args: Dict = None, - ): - """The running function. - - Parameters:: - - duration : float, int, tuple, list - The running duration. - start_t : float, optional - The start time to simulate. - eval_time: bool - Evaluate the running time or not? - args: dict - The equation arguments to update. - .. versionadded:: 2.3.1 - - dyn_args: dict - The dynamically changed arguments over time. The size of first dimension should be - equal to the running ``duration``. - - .. versionadded:: 2.3.1 + """Structural runner for numerical integrators in brainpy. + + Examples:: + + Example to run an ODE integrator, + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> a=0.7; b=0.8; tau=12.5 + >>> dV = lambda V, t, w, I: V - V * V * V / 3 - w + I + >>> dw = lambda w, t, V, a, b: (V + a - b * w) / tau + >>> integral = bp.odeint(bp.JointEq([dV, dw]), method='exp_auto') + >>> + >>> runner = bp.IntegratorRunner( + >>> integral, # the simulation target + >>> monitors=['V', 'w'], # the variables to monitor + >>> inits={'V': bm.random.rand(10), + >>> 'w': bm.random.normal(size=10)}, # the initial values + >>> ) + >>> runner.run(100., + >>> args={'a': 1., 'b': 1.}, # update arguments + >>> dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}, # each time each current input + >>> ) + >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=[0, 1, 4], show=True) + + Example to run an SDE intragetor, + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> # stochastic Lorenz system + >>> sigma=10; beta=8 / 3; rho=28 + >>> g = lambda x, y, z, t, p: (p * x, p * y, p * z) + 
>>> f = lambda x, y, z, t, p: [sigma * (y - x), x * (rho - z) - y, x * y - beta * z] + >>> lorenz = bp.sdeint(f, g, method='milstein2') + >>> + >>> runner = bp.IntegratorRunner( + >>> lorenz, + >>> monitors=['x', 'y', 'z'], + >>> inits=[1., 1., 1.], # initialize all variable to 1. + >>> dt=0.01 + >>> ) + >>> runner.run(100., args={'p': 0.1},) + >>> + >>> import matplotlib.pyplot as plt + >>> fig = plt.figure() + >>> ax = fig.gca(projection='3d') + >>> plt.plot(runner.mon.x.squeeze(), runner.mon.y.squeeze(), runner.mon.z.squeeze()) + >>> ax.set_xlabel('x') + >>> ax.set_xlabel('y') + >>> ax.set_xlabel('z') + >>> plt.show() """ - args = dict() if args is None else args - dyn_args = dict() if dyn_args is None else dyn_args - assert isinstance(args, dict), f'"args" must be a dict, but we got {type(args)}: {args}' - assert isinstance(dyn_args, dict), f'"dyn_args" must be a dict, but we got {type(dyn_args)}: {dyn_args}' - args.update(self._static_args) - dyn_args.update(self._dyn_args) - - # time step - if start_t is None: - start_t = self.start_t[0] - end_t = start_t + duration - # times - times = bm.arange(start_t, end_t, self.dt).value - indices = bm.arange(times.size).value + self.idx.value - - _dyn_args, _ = tree_flatten(dyn_args) - for _d in _dyn_args: - if jnp.shape(_d)[0] != times.size: - raise ValueError(f'The shape of `dyn_args` does not match the given duration. 
' - f'{jnp.shape(_d)[0]} != {times.size} (duration={duration}, dt={self.dt}).') - del _d - del _dyn_args - - # running - if self.progress_bar: - self._pbar = tqdm.auto.tqdm(total=times.size) - self._pbar.set_description(f"Running a duration of {round(float(duration), 3)} ({times.size} steps)", - refresh=True) - if eval_time: - t0 = time.time() - hists = self._run_fun_integration(args, dyn_args, times, indices) - if eval_time: - running_time = time.time() - t0 - if self.progress_bar: - self._pbar.close() - - # post-running - times += self.dt - if self.numpy_mon_after_run: - times = np.asarray(times) - for key in list(hists.keys()): - hists[key] = np.asarray(hists[key]) - self.mon.ts = times - for key in hists.keys(): - self.mon[key] = hists[key] - self.start_t[0] = end_t - self.idx[0] += times.size - if eval_time: - return running_time + + def __init__( + self, + target: Integrator, + + # IntegratorRunner specific arguments + inits: Union[Sequence, Dict] = None, + + # regular/common arguments + dt: Union[float, int] = None, + monitors: Sequence[str] = None, + dyn_vars: Dict[str, bm.Variable] = None, + jit: Union[bool, Dict[str, bool]] = True, + numpy_mon_after_run: bool = True, + progress_bar: bool = True, + + # deprecated + args: Dict = None, + dyn_args: Dict[str, Union[bm.ndarray, jnp.ndarray]] = None, + fun_monitors: Dict[str, Callable] = None, + ): + """Initialization of structural runner for integrators. + + Parameters:: + + target: Integrator + The target to run. + monitors: sequence of str + The variables to monitor. + fun_monitors: dict + The monitors with callable functions. + .. deprecated:: 2.3.1 + inits: sequence, dict + The initial value of variables. With this parameter, + you can easily control the number of variables to simulate. + For example, if one of the variable has the shape of 10, + then all variables will be an instance of :py:class:`brainpy.math.Variable` + with the shape of :math:`(10,)`. + args: dict + The equation arguments to update. 
+ Note that if one of the arguments are heterogeneous (i.e., a tensor), + it means we should run multiple trials. However, you can set the number + of the elements in the variables so that each pair of variables can + correspond to one set of arguments. + + .. deprecated:: 2.3.1 + Will be removed after version 2.4.0. + + dyn_args: dict + The dynamically changed arguments. This means this argument can control + the argument dynamically changed. For example, if you want to inject a + time varied currents into the HH neuron model, you can pack the currents + into this ``dyn_args`` argument. + + .. deprecated:: 2.3.1 + Will be removed after version 2.4.0. + + dt: float, int + dyn_vars: dict + jit: bool + progress_bar: bool + numpy_mon_after_run: bool + """ + + if not isinstance(target, Integrator): + raise TypeError(f'Target must be instance of {Integrator.__name__}, ' + f'but we got {type(target)}') + # get maximum size and initial variables + if inits is not None: + if isinstance(inits, (list, tuple, bm.BaseArray, jnp.ndarray)): + assert len(target.variables) == len(inits) + inits = {k: inits[i] for i, k in enumerate(target.variables)} + assert isinstance(inits, dict), f'"inits" must be a dict, but we got {type(inits)}' + sizes = np.unique([np.size(v) for v in list(inits.values())]) + max_size = np.max(sizes) + else: + max_size = 1 + inits = dict() + + # initialize variables + self.variables = {v: bm.Variable(bm.zeros(max_size)) for v in target.variables} + for k in inits.keys(): + self.variables[k][:] = inits[k] + + # format string monitors + if isinstance(monitors, (tuple, list)): + monitors = self._format_seq_monitors(monitors) + monitors = {k: (self.variables[k], i) for k, i in monitors} + elif isinstance(monitors, dict): + monitors = self._format_dict_monitors(monitors) + monitors = {k: ((self.variables[i], i) if isinstance(i, str) else i) for k, i in monitors.items()} + else: + raise ValueError + + # initialize super class + super(IntegratorRunner, 
self).__init__(target=target, + monitors=monitors, + fun_monitors=fun_monitors, + jit=jit, + progress_bar=progress_bar, + dyn_vars=dyn_vars, + numpy_mon_after_run=numpy_mon_after_run) + + self.register_implicit_vars(self.variables) + + # parameters + dt = bm.get_dt() if dt is None else dt + if not isinstance(dt, (int, float)): + raise RunningError(f'"dt" must be scalar, but got {dt}') + self.dt = dt + + # target + if not isinstance(self.target, Integrator): + raise RunningError(f'"target" must be an instance of {Integrator.__name__}, ' + f'but we got {type(target)}: {target}') + + # arguments of the integral function + if args is not None: + warnings.warn('Set "args" in `IntegratorRunner.run()` function, instead of __init__ function. ' + 'Will be removed since 2.4.0', + UserWarning) + assert isinstance(args, dict), (f'"args" must be a dict, but ' + f'we got {type(args)}: {args}') + self._static_args = args + else: + self._static_args = dict() + if dyn_args is not None: + warnings.warn('Set "dyn_args" in `IntegratorRunner.run()` function, instead of __init__ function. ' + 'Will be removed since 2.4.0', + UserWarning) + assert isinstance(dyn_args, dict), (f'"dyn_args" must be a dict, but we get ' + f'{type(dyn_args)}: {dyn_args}') + sizes = np.unique([len(v) for v in dyn_args.values()]) + num_size = len(sizes) + if num_size != 1: + raise RunningError(f'All values in "dyn_args" should have the same length. 
' + f'But we got {num_size}: {sizes}') + self._dyn_args = dyn_args + else: + self._dyn_args = dict() + + # start simulation time and index + self.start_t = bm.Variable(bm.zeros(1)) + self.idx = bm.Variable(bm.zeros(1, dtype=bm.int_)) + + def _run_fun_integration(self, static_args, dyn_args, times, indices): + return bm.for_loop(partial(self._step_fun_integrator, static_args), + (dyn_args, times, indices), + jit=self.jit['predict']) + + def _step_fun_integrator(self, static_args, dyn_args, t, i): + # arguments + kwargs = Collector(dt=self.dt, t=t) + kwargs.update(static_args) + kwargs.update(dyn_args) + kwargs.update({k: v.value for k, v in self.variables.items()}) + + # call integrator function + update_values = self.target(**kwargs) + if len(self.target.variables) == 1: + self.variables[self.target.variables[0]].update(update_values) + else: + for i, v in enumerate(self.target.variables): + self.variables[v].update(update_values[i]) + + # progress bar + if self.progress_bar: + jax.debug.callback(lambda *args: self._pbar.update(), ()) + + # return of function monitors + shared = dict(t=t + self.dt, dt=self.dt, i=i) + returns = dict() + for k, v in self._monitors.items(): + if callable(v): + returns[k] = bm.as_jax(v(shared)) + else: + returns[k] = self.variables[k].value + return returns + + def run( + self, + duration: float, + start_t: float = None, + eval_time: bool = False, + args: Dict = None, + dyn_args: Dict = None, + ): + """The running function. + + Parameters:: + + duration : float, int, tuple, list + The running duration. + start_t : float, optional + The start time to simulate. + eval_time: bool + Evaluate the running time or not? + args: dict + The equation arguments to update. + .. versionadded:: 2.3.1 + + dyn_args: dict + The dynamically changed arguments over time. The size of first dimension should be + equal to the running ``duration``. + + .. 
versionadded:: 2.3.1 + + """ + args = dict() if args is None else args + dyn_args = dict() if dyn_args is None else dyn_args + assert isinstance(args, dict), f'"args" must be a dict, but we got {type(args)}: {args}' + assert isinstance(dyn_args, dict), f'"dyn_args" must be a dict, but we got {type(dyn_args)}: {dyn_args}' + args.update(self._static_args) + dyn_args.update(self._dyn_args) + + # time step + if start_t is None: + start_t = self.start_t[0] + end_t = start_t + duration + # times + times = bm.arange(start_t, end_t, self.dt).value + indices = bm.arange(times.size).value + self.idx.value + + _dyn_args, _ = tree_flatten(dyn_args) + for _d in _dyn_args: + if jnp.shape(_d)[0] != times.size: + raise ValueError(f'The shape of `dyn_args` does not match the given duration. ' + f'{jnp.shape(_d)[0]} != {times.size} (duration={duration}, dt={self.dt}).') + del _d + del _dyn_args + + # running + if self.progress_bar: + self._pbar = tqdm.auto.tqdm(total=times.size) + self._pbar.set_description(f"Running a duration of {round(float(duration), 3)} ({times.size} steps)", + refresh=True) + if eval_time: + t0 = time.time() + hists = self._run_fun_integration(args, dyn_args, times, indices) + if eval_time: + running_time = time.time() - t0 + if self.progress_bar: + self._pbar.close() + + # post-running + times += self.dt + if self.numpy_mon_after_run: + times = np.asarray(times) + for key in list(hists.keys()): + hists[key] = np.asarray(hists[key]) + self.mon.ts = times + for key in hists.keys(): + self.mon[key] = hists[key] + self.start_t[0] = end_t + self.idx[0] += times.size + if eval_time: + return running_time diff --git a/brainpy/_src/integrators/sde/__init__.py b/brainpy/_src/integrators/sde/__init__.py index 00b3d4641..98c9f7600 100644 --- a/brainpy/_src/integrators/sde/__init__.py +++ b/brainpy/_src/integrators/sde/__init__.py @@ -5,8 +5,6 @@ """ from .base import * +from .generic import * from .normal import * from .srk_scalar import * -from .generic import * - - 
diff --git a/brainpy/_src/integrators/sde/base.py b/brainpy/_src/integrators/sde/base.py index 1a0193e1f..e145c57f2 100644 --- a/brainpy/_src/integrators/sde/base.py +++ b/brainpy/_src/integrators/sde/base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Dict, Callable, Union, Sequence +from typing import Dict, Callable import jax.numpy as jnp @@ -11,77 +11,77 @@ from brainpy._src.math.delayvars import AbstractDelay __all__ = [ - 'SDEIntegrator', + 'SDEIntegrator', ] def f_names(f): - func_name = constants.unique_name('sde') - if f.__name__.isidentifier(): - func_name += '_' + f.__name__ - return func_name + func_name = constants.unique_name('sde') + if f.__name__.isidentifier(): + func_name += '_' + f.__name__ + return func_name class SDEIntegrator(Integrator): - """SDE Integrator.""" + """SDE Integrator.""" - def __init__( - self, - f: Callable, - g: Callable, - dt: float = None, - name: str = None, - show_code: bool = False, - var_type: str = None, - intg_type: str = None, - wiener_type: str = None, - state_delays: Dict[str, AbstractDelay] = None, - ): - dt = bm.get_dt() if dt is None else dt - parses = utils.get_args(f) - variables = parses[0] # variable names, (before 't') - parameters = parses[1] # parameter names, (after 't') - arguments = parses[2] # function arguments + def __init__( + self, + f: Callable, + g: Callable, + dt: float = None, + name: str = None, + show_code: bool = False, + var_type: str = None, + intg_type: str = None, + wiener_type: str = None, + state_delays: Dict[str, AbstractDelay] = None, + ): + dt = bm.get_dt() if dt is None else dt + parses = utils.get_args(f) + variables = parses[0] # variable names, (before 't') + parameters = parses[1] # parameter names, (after 't') + arguments = parses[2] # function arguments - # super initialization - super(SDEIntegrator, self).__init__(name=name, - variables=variables, - parameters=parameters, - arguments=arguments, - dt=dt, - state_delays=state_delays) + # super initialization + 
super(SDEIntegrator, self).__init__(name=name, + variables=variables, + parameters=parameters, + arguments=arguments, + dt=dt, + state_delays=state_delays) - # derivative functions - self.derivative = {constants.F: f, constants.G: g} - self.f = f - self.g = g + # derivative functions + self.derivative = {constants.F: f, constants.G: g} + self.f = f + self.g = g - # essential parameters - intg_type = constants.ITO_SDE if intg_type is None else intg_type - var_type = constants.SCALAR_VAR if var_type is None else var_type - wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type - if intg_type not in constants.SUPPORTED_INTG_TYPE: - raise errors.IntegratorError(f'Currently, BrainPy only support SDE_INT types: ' - f'{constants.SUPPORTED_INTG_TYPE}. But we got {intg_type}.') - if var_type not in constants.SUPPORTED_VAR_TYPE: - raise errors.IntegratorError(f'Currently, BrainPy only supports variable types: ' - f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.') - if wiener_type not in constants.SUPPORTED_WIENER_TYPE: - raise errors.IntegratorError(f'Currently, BrainPy only supports Wiener ' - f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' - f'But we got {wiener_type}.') - self.var_type = var_type # variable type - self.intg_type = intg_type # integral type - self.wiener_type = wiener_type # wiener process type + # essential parameters + intg_type = constants.ITO_SDE if intg_type is None else intg_type + var_type = constants.SCALAR_VAR if var_type is None else var_type + wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type + if intg_type not in constants.SUPPORTED_INTG_TYPE: + raise errors.IntegratorError(f'Currently, BrainPy only support SDE_INT types: ' + f'{constants.SUPPORTED_INTG_TYPE}. But we got {intg_type}.') + if var_type not in constants.SUPPORTED_VAR_TYPE: + raise errors.IntegratorError(f'Currently, BrainPy only supports variable types: ' + f'{constants.SUPPORTED_VAR_TYPE}. 
But we got {var_type}.') + if wiener_type not in constants.SUPPORTED_WIENER_TYPE: + raise errors.IntegratorError(f'Currently, BrainPy only supports Wiener ' + f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' + f'But we got {wiener_type}.') + self.var_type = var_type # variable type + self.intg_type = intg_type # integral type + self.wiener_type = wiener_type # wiener process type - # code scope - self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': bm.random.DEFAULT} - # code lines - self.func_name = f_names(f) - self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):'] - # others - self.show_code = show_code + # code scope + self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': bm.random.DEFAULT} + # code lines + self.func_name = f_names(f) + self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):'] + # others + self.show_code = show_code - def _check_vector_wiener_dim(self, noise_size, var_size): - if noise_size[:-1] > var_size[-len(noise_size) +1:]: - raise ValueError(f"Incompatible shapes for shapes of noise {noise_size} and variable {var_size}") + def _check_vector_wiener_dim(self, noise_size, var_size): + if noise_size[:-1] > var_size[-len(noise_size) + 1:]: + raise ValueError(f"Incompatible shapes for shapes of noise {noise_size} and variable {var_size}") diff --git a/brainpy/_src/integrators/sde/generic.py b/brainpy/_src/integrators/sde/generic.py index a6e25b0c3..61214f555 100644 --- a/brainpy/_src/integrators/sde/generic.py +++ b/brainpy/_src/integrators/sde/generic.py @@ -6,10 +6,10 @@ from .base import SDEIntegrator __all__ = [ - 'set_default_sdeint', - 'get_default_sdeint', - 'register_sde_integrator', - 'get_supported_methods', + 'set_default_sdeint', + 'get_default_sdeint', + 'register_sde_integrator', + 'get_supported_methods', ] name2method = { @@ -30,105 +30,105 @@ def sdeint( wiener_type: str = None, state_delays: Dict[str, Union[bm.LengthDelay, bm.TimeDelay]] = None ): 
- """Numerical integration for SDEs. - - Parameters:: - - f : callable, function - The derivative function. - method : str - The shortcut name of the numerical integrator. - - Returns:: - - integral : SDEIntegrator - The numerical solver of `f`. - """ - method = _DEFAULT_SDE_METHOD if method is None else method - if method not in name2method: - raise ValueError(f'Unknown SDE numerical method "{method}". Currently ' - f'BrainPy only support: {list(name2method.keys())}') - - if f is not None and g is not None: - return name2method[method](f=f, - g=g, - dt=dt, - name=name, - show_code=show_code, - var_type=var_type, - intg_type=intg_type, - wiener_type=wiener_type, - state_delays=state_delays) - - elif f is not None: - return lambda g: name2method[method](f=f, - g=g, - dt=dt, - name=name, - show_code=show_code, - var_type=var_type, - intg_type=intg_type, - wiener_type=wiener_type, - state_delays=state_delays) - - elif g is not None: - return lambda f: name2method[method](f=f, - g=g, - dt=dt, - name=name, - show_code=show_code, - var_type=var_type, - intg_type=intg_type, - wiener_type=wiener_type, - state_delays=state_delays) - - else: - raise ValueError('Must provide "f" or "g".') + """Numerical integration for SDEs. + + Parameters:: + + f : callable, function + The derivative function. + method : str + The shortcut name of the numerical integrator. + + Returns:: + + integral : SDEIntegrator + The numerical solver of `f`. + """ + method = _DEFAULT_SDE_METHOD if method is None else method + if method not in name2method: + raise ValueError(f'Unknown SDE numerical method "{method}". 
Currently ' + f'BrainPy only support: {list(name2method.keys())}') + + if f is not None and g is not None: + return name2method[method](f=f, + g=g, + dt=dt, + name=name, + show_code=show_code, + var_type=var_type, + intg_type=intg_type, + wiener_type=wiener_type, + state_delays=state_delays) + + elif f is not None: + return lambda g: name2method[method](f=f, + g=g, + dt=dt, + name=name, + show_code=show_code, + var_type=var_type, + intg_type=intg_type, + wiener_type=wiener_type, + state_delays=state_delays) + + elif g is not None: + return lambda f: name2method[method](f=f, + g=g, + dt=dt, + name=name, + show_code=show_code, + var_type=var_type, + intg_type=intg_type, + wiener_type=wiener_type, + state_delays=state_delays) + + else: + raise ValueError('Must provide "f" or "g".') def set_default_sdeint(method): - """Set the default SDE numerical integrator method for differential equations. + """Set the default SDE numerical integrator method for differential equations. - Parameters:: + Parameters:: - method : str, callable - Numerical integrator method. - """ - if not isinstance(method, str): - raise ValueError(f'Only support string, not {type(method)}.') - if method not in name2method: - raise ValueError(f'Unsupported SDE_INT numerical method: {method}.') + method : str, callable + Numerical integrator method. + """ + if not isinstance(method, str): + raise ValueError(f'Only support string, not {type(method)}.') + if method not in name2method: + raise ValueError(f'Unsupported SDE_INT numerical method: {method}.') - global _DEFAULT_SDE_METHOD - _DEFAULT_SDE_METHOD = method + global _DEFAULT_SDE_METHOD + _DEFAULT_SDE_METHOD = method def get_default_sdeint(): - """Get the default SDE numerical integrator method. + """Get the default SDE numerical integrator method. - Returns:: + Returns:: - method : str - The default numerical integrator method. - """ - return _DEFAULT_SDE_METHOD + method : str + The default numerical integrator method. 
+ """ + return _DEFAULT_SDE_METHOD def register_sde_integrator(name, integrator): - """Register a new SDE integrator. + """Register a new SDE integrator. - Parameters:: + Parameters:: - name: ste - integrator: type - """ - if name in name2method: - raise ValueError(f'"{name}" has been registered in SDE integrators.') - if not issubclass(integrator, SDEIntegrator): - raise ValueError(f'"integrator" must be an instance of {SDEIntegrator.__name__}') - name2method[name] = integrator + name: ste + integrator: type + """ + if name in name2method: + raise ValueError(f'"{name}" has been registered in SDE integrators.') + if not issubclass(integrator, SDEIntegrator): + raise ValueError(f'"integrator" must be an instance of {SDEIntegrator.__name__}') + name2method[name] = integrator def get_supported_methods(): - """Get all supported numerical methods for DDEs.""" - return list(name2method.keys()) + """Get all supported numerical methods for DDEs.""" + return list(name2method.keys()) diff --git a/brainpy/_src/integrators/sde/normal.py b/brainpy/_src/integrators/sde/normal.py index 62bf4b0f4..e861a8fe0 100644 --- a/brainpy/_src/integrators/sde/normal.py +++ b/brainpy/_src/integrators/sde/normal.py @@ -7,495 +7,495 @@ from brainpy import errors from brainpy._src import math as bm from brainpy._src.integrators import constants, utils, joint_eq +from brainpy._src.integrators.constants import DT from brainpy._src.integrators.sde.base import SDEIntegrator from brainpy._src.integrators.sde.generic import register_sde_integrator from brainpy._src.integrators.utils import format_args -from brainpy._src.integrators.constants import DT __all__ = [ - 'Euler', - 'Heun', - 'Milstein', - 'MilsteinGradFree', - 'ExponentialEuler', + 'Euler', + 'Heun', + 'Milstein', + 'MilsteinGradFree', + 'ExponentialEuler', ] def df_and_dg(code_lines, variables, parameters): - # 1. 
df - # df = f(x, t, *args) - all_df = [f'{var}_df' for var in variables] - code_lines.append(f' {", ".join(all_df)} = f({", ".join(variables + parameters)})') + # 1. df + # df = f(x, t, *args) + all_df = [f'{var}_df' for var in variables] + code_lines.append(f' {", ".join(all_df)} = f({", ".join(variables + parameters)})') - # 2. dg - # dg = g(x, t, *args) - all_dg = [f'{var}_dg' for var in variables] - code_lines.append(f' {", ".join(all_dg)} = g({", ".join(variables + parameters)})') - code_lines.append(' ') + # 2. dg + # dg = g(x, t, *args) + all_dg = [f'{var}_dg' for var in variables] + code_lines.append(f' {", ".join(all_dg)} = g({", ".join(variables + parameters)})') + code_lines.append(' ') def dfdt(code_lines, variables): - for var in variables: - code_lines.append(f' {var}_dfdt = {var}_df * {constants.DT}') - code_lines.append(' ') + for var in variables: + code_lines.append(f' {var}_dfdt = {var}_df * {constants.DT}') + code_lines.append(' ') def noise_terms(code_lines, variables): - for var in variables: - code_lines.append(f' if {var}_dg is not None:') - code_lines.append(f' {var}_dW = random.normal(0.000, dt_sqrt, math.shape({var})).value') - code_lines.append(' ') + for var in variables: + code_lines.append(f' if {var}_dg is not None:') + code_lines.append(f' {var}_dW = random.normal(0.000, dt_sqrt, math.shape({var})).value') + code_lines.append(' ') class Euler(SDEIntegrator): - r"""Euler method for the Ito and Stratonovich integrals. + r"""Euler method for the Ito and Stratonovich integrals. - For Ito schema, the Euler method (also called as Euler-Maruyama method) is given by: - - .. 
math:: - - \begin{aligned} - Y_{n+1} &=Y_{n}+f\left(Y_{n}\right) h_{n}+g\left(Y_{n}\right) \Delta W_{n} \\ - \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) - \end{aligned} + For Ito schema, the Euler method (also called as Euler-Maruyama method) is given by: - As the order of convergence for the Euler-Maruyama method is low (strong - order of convergence 0.5, weak order of convergence 1), the numerical results - are inaccurate unless a small step size is used. In fact, Euler-Maruyama - represents the order 0.5 strong Taylor scheme. + .. math:: - For Stratonovich scheme, the Euler-Heun method has to be used instead of the Euler-Maruyama method + \begin{aligned} + Y_{n+1} &=Y_{n}+f\left(Y_{n}\right) h_{n}+g\left(Y_{n}\right) \Delta W_{n} \\ + \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) + \end{aligned} - .. math:: + As the order of convergence for the Euler-Maruyama method is low (strong + order of convergence 0.5, weak order of convergence 1), the numerical results + are inaccurate unless a small step size is used. In fact, Euler-Maruyama + represents the order 0.5 strong Taylor scheme. - \begin{aligned} - Y_{n+1} &=Y_{n}+f_{n} h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} \\ - \bar{Y}_{n} &=Y_{n}+g_{n} \Delta W_{n} \\ - \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) - \end{aligned} + For Stratonovich scheme, the Euler-Heun method has to be used instead of the Euler-Maruyama method + .. 
math:: - See Also:: - - Heun - - """ - - def __init__( - self, f, g, dt=None, name=None, show_code=False, - var_type=None, intg_type=None, wiener_type=None, - state_delays=None, - ): - super(Euler, self).__init__(f=f, g=g, dt=dt, name=name, - var_type=var_type, intg_type=intg_type, - wiener_type=wiener_type, - state_delays=state_delays) - - self.set_integral(self.step) - - def step(self, *args, **kwargs): - all_args = format_args(args, kwargs, self.arg_names) - dt = all_args.pop(DT, self.dt) - - # drift values - drifts = self.f(**all_args) - if len(self.variables) == 1: - if not isinstance(drifts, (bm.ndarray, jnp.ndarray)): - raise ValueError('Drift values must be a tensor when there ' - 'is only one variable in the equation.') - drifts = {self.variables[0]: drifts} - else: - if not isinstance(drifts, (tuple, list)): - raise ValueError('Drift values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - drifts = {var: drifts[i] for i, var in enumerate(self.variables)} - - # diffusion values - diffusions = self.g(**all_args) - if len(self.variables) == 1: - # if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)): - # raise ValueError('Diffusion values must be a tensor when there ' - # 'is only one variable in the equation.') - diffusions = {self.variables[0]: diffusions} - else: - if not isinstance(diffusions, (tuple, list)): - raise ValueError('Diffusion values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)} - if self.wiener_type == constants.VECTOR_WIENER: - for key, val in diffusions.items(): - if val is not None and jnp.ndim(val) == 0: - raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple " - f"dimensional diffusion value. 
But we got a scale value for " - f"variable {key}.") - - # integral results - integrals = [] - if self.intg_type == constants.ITO_SDE: - for key in self.variables: - integral = all_args[key] + drifts[key] * dt - if diffusions[key] is not None: - shape = jnp.shape(all_args[key]) - if self.wiener_type == constants.SCALAR_WIENER: - integral += diffusions[key] * bm.random.randn(*shape) * jnp.sqrt(dt) - else: - shape += jnp.shape(diffusions[key])[-1:] - integral += jnp.sum(diffusions[key] * bm.random.randn(*shape), axis=-1) * jnp.sqrt(dt) - integrals.append(integral) - - else: - # \bar{Y}_{n}=Y_{n}+g_{n} \Delta W_{n} - all_args_bar = {key: val for key, val in all_args.items()} - all_noises = {} - for key in self.variables: - if diffusions[key] is None: - all_args_bar[key] = all_args[key] - else: - shape = jnp.shape(all_args[key]) - if self.wiener_type == constants.VECTOR_WIENER: - noise_shape = jnp.shape(diffusions[key]) - self._check_vector_wiener_dim(noise_shape, shape) - shape += noise_shape[-1:] - noise = bm.random.randn(*shape) - all_noises[key] = noise * jnp.sqrt(dt) - if self.wiener_type == constants.VECTOR_WIENER: - y_bar = all_args[key] + jnp.sum(diffusions[key] * noise, axis=-1) - else: - y_bar = all_args[key] + diffusions[key] * noise - all_args_bar[key] = y_bar - # g(\bar{Y}_{n}) - diffusion_bars = self.g(**all_args_bar) - if len(self.variables) == 1: - diffusion_bars = {self.variables[0]: diffusion_bars} - else: - diffusion_bars = {var: diffusion_bars[i] for i, var in enumerate(self.variables)} - # Y_{n+1}=Y_{n}+f_{n} h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} - for key in self.variables: - integral = all_args[key] + drifts[key] * dt - if diffusion_bars[key] is not None: - integral += (diffusions[key] + diffusion_bars[key]) / 2 * all_noises[key] - integrals.append(integral) - - # return integrals - if len(self.variables) == 1: - return integrals[0] - else: - return integrals + \begin{aligned} + Y_{n+1} &=Y_{n}+f_{n} 
h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} \\ + \bar{Y}_{n} &=Y_{n}+g_{n} \Delta W_{n} \\ + \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) + \end{aligned} -register_sde_integrator('euler', Euler) + See Also:: + Heun -class Heun(Euler): - r"""The Euler-Heun method for Stratonovich integral scheme. + """ - Its mathematical expression is given by + def __init__( + self, f, g, dt=None, name=None, show_code=False, + var_type=None, intg_type=None, wiener_type=None, + state_delays=None, + ): + super(Euler, self).__init__(f=f, g=g, dt=dt, name=name, + var_type=var_type, intg_type=intg_type, + wiener_type=wiener_type, + state_delays=state_delays) - .. math:: + self.set_integral(self.step) - \begin{aligned} - Y_{n+1} &=Y_{n}+f_{n} h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} \\ - \bar{Y}_{n} &=Y_{n}+g_{n} \Delta W_{n} \\ - \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) - \end{aligned} + def step(self, *args, **kwargs): + all_args = format_args(args, kwargs, self.arg_names) + dt = all_args.pop(DT, self.dt) + # drift values + drifts = self.f(**all_args) + if len(self.variables) == 1: + if not isinstance(drifts, (bm.ndarray, jnp.ndarray)): + raise ValueError('Drift values must be a tensor when there ' + 'is only one variable in the equation.') + drifts = {self.variables[0]: drifts} + else: + if not isinstance(drifts, (tuple, list)): + raise ValueError('Drift values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + drifts = {var: drifts[i] for i, var in enumerate(self.variables)} + + # diffusion values + diffusions = self.g(**all_args) + if len(self.variables) == 1: + # if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)): + # raise ValueError('Diffusion values must be a tensor when there ' + # 'is only one variable in the equation.') + diffusions = {self.variables[0]: diffusions} + else: + if not isinstance(diffusions, (tuple, 
list)): + raise ValueError('Diffusion values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)} + if self.wiener_type == constants.VECTOR_WIENER: + for key, val in diffusions.items(): + if val is not None and jnp.ndim(val) == 0: + raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple " + f"dimensional diffusion value. But we got a scale value for " + f"variable {key}.") + + # integral results + integrals = [] + if self.intg_type == constants.ITO_SDE: + for key in self.variables: + integral = all_args[key] + drifts[key] * dt + if diffusions[key] is not None: + shape = jnp.shape(all_args[key]) + if self.wiener_type == constants.SCALAR_WIENER: + integral += diffusions[key] * bm.random.randn(*shape) * jnp.sqrt(dt) + else: + shape += jnp.shape(diffusions[key])[-1:] + integral += jnp.sum(diffusions[key] * bm.random.randn(*shape), axis=-1) * jnp.sqrt(dt) + integrals.append(integral) - See Also:: + else: + # \bar{Y}_{n}=Y_{n}+g_{n} \Delta W_{n} + all_args_bar = {key: val for key, val in all_args.items()} + all_noises = {} + for key in self.variables: + if diffusions[key] is None: + all_args_bar[key] = all_args[key] + else: + shape = jnp.shape(all_args[key]) + if self.wiener_type == constants.VECTOR_WIENER: + noise_shape = jnp.shape(diffusions[key]) + self._check_vector_wiener_dim(noise_shape, shape) + shape += noise_shape[-1:] + noise = bm.random.randn(*shape) + all_noises[key] = noise * jnp.sqrt(dt) + if self.wiener_type == constants.VECTOR_WIENER: + y_bar = all_args[key] + jnp.sum(diffusions[key] * noise, axis=-1) + else: + y_bar = all_args[key] + diffusions[key] * noise + all_args_bar[key] = y_bar + # g(\bar{Y}_{n}) + diffusion_bars = self.g(**all_args_bar) + if len(self.variables) == 1: + diffusion_bars = {self.variables[0]: diffusion_bars} + else: + diffusion_bars = {var: diffusion_bars[i] for i, var in enumerate(self.variables)} + # 
Y_{n+1}=Y_{n}+f_{n} h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} + for key in self.variables: + integral = all_args[key] + drifts[key] * dt + if diffusion_bars[key] is not None: + integral += (diffusions[key] + diffusion_bars[key]) / 2 * all_noises[key] + integrals.append(integral) + + # return integrals + if len(self.variables) == 1: + return integrals[0] + else: + return integrals - Euler - """ +register_sde_integrator('euler', Euler) - def __init__(self, f, g, dt=None, name=None, show_code=False, - var_type=None, intg_type=None, wiener_type=None, - state_delays=None, ): - if intg_type != constants.STRA_SDE: - raise errors.IntegratorError(f'Heun method only supports Stranovich ' - f'integral of SDEs, but we got {intg_type} integral.') - super(Heun, self).__init__(f=f, g=g, dt=dt, name=name, - var_type=var_type, intg_type=intg_type, - wiener_type=wiener_type, state_delays=state_delays) +class Heun(Euler): + r"""The Euler-Heun method for Stratonovich integral scheme. -register_sde_integrator('heun', Heun) + Its mathematical expression is given by + .. math:: -class Milstein(SDEIntegrator): - r"""Milstein method for Ito or Stratonovich integrals. + \begin{aligned} + Y_{n+1} &=Y_{n}+f_{n} h+\frac{1}{2}\left[g_{n}+g\left(\bar{Y}_{n}\right)\right] \Delta W_{n} \\ + \bar{Y}_{n} &=Y_{n}+g_{n} \Delta W_{n} \\ + \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) + \end{aligned} - The Milstein scheme represents the order 1.0 strong Taylor scheme. For the Ito integral, - .. math:: + See Also:: - \begin{aligned} - &Y_{n+1}=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2} g_{n} g_{n}^{\prime}\left[\left(\Delta W_{n}\right)^{2}-h\right] \\ - &\Delta W_{n}=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) - \end{aligned} + Euler - where :math:`g_{n}^{\prime}=\frac{d g\left(Y_{n}\right)}{d Y_{n}}` is the first derivative of :math:`g_n`. 
+ """ + def __init__(self, f, g, dt=None, name=None, show_code=False, + var_type=None, intg_type=None, wiener_type=None, + state_delays=None, ): + if intg_type != constants.STRA_SDE: + raise errors.IntegratorError(f'Heun method only supports Stranovich ' + f'integral of SDEs, but we got {intg_type} integral.') + super(Heun, self).__init__(f=f, g=g, dt=dt, name=name, + var_type=var_type, intg_type=intg_type, + wiener_type=wiener_type, state_delays=state_delays) - For the Stratonovich integral, the Milstein method is given by - .. math:: +register_sde_integrator('heun', Heun) - \begin{aligned} - &Y_{n+1}=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2} g_{n} g_{n}^{\prime}\left(\Delta W_{n}\right)^{2} \\ - &\Delta W_{n}=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) - \end{aligned} - - """ - - def __init__( - self, - f: Callable, - g: Callable, - dt: float = None, - name: str = None, - show_code=False, - var_type: str = None, - intg_type: str = None, - wiener_type: str = None, - state_delays: Dict[str, bm.AbstractDelay] = None, - ): - super(Milstein, self).__init__(f=f, - g=g, - dt=dt, - name=name, - var_type=var_type, - intg_type=intg_type, - wiener_type=wiener_type, - state_delays=state_delays) - self.set_integral(self.step) - - def _get_g_grad(self, f, allow_raise=False, need_grad=True): - if isinstance(f, joint_eq.JointEq): - results = [] - state = True - for sub_eq in f.eqs: - r, r_state = self._get_g_grad(sub_eq, allow_raise, need_grad) - results.extend(r) - state &= r_state - return results, state - else: - res = [None, None, None] - state = True - try: - vars, pars, _ = utils.get_args(f) - if len(vars) != 1: - raise errors.DiffEqError(constants.multi_vars_msg.format(cls=self.__class__.__name__, - vars=str(vars), eq=str(f))) - res[1] = vars - res[2] = pars - except errors.DiffEqError as e: - state = False - if not allow_raise: - raise e - if need_grad: - res[0] = bm.vector_grad(f, argnums=0) - return [tuple(res)], state - - def step(self, *args, 
**kwargs): - # parse grad function and individual arguments - parses, state = self._get_g_grad(self.g, allow_raise=False, need_grad=True) - if not state: - parses2 = self._get_g_grad(self.f, allow_raise=True, need_grad=False) - if len(parses2) != len(parses): - raise ValueError(f'"f" and "g" should defined with JointEq both, and should ' - f'keep the same structure.') - parses = [a[:1] + b[1:] for a, b in zip(parses, parses2)] - - # input arguments - all_args = format_args(args, kwargs, self.arg_names) - dt = all_args.pop(DT, self.dt) - - # drift values - drifts = self.f(**all_args) - if len(self.variables) == 1: - if not isinstance(drifts, (bm.ndarray, jnp.ndarray)): - raise ValueError('Drift values must be a tensor when there ' - 'is only one variable in the equation.') - drifts = {self.variables[0]: drifts} - else: - if not isinstance(drifts, (tuple, list)): - raise ValueError('Drift values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - drifts = {var: drifts[i] for i, var in enumerate(self.variables)} - - # diffusion values - diffusions = self.g(**all_args) - if len(self.variables) == 1: - if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)): - raise ValueError('Diffusion values must be a tensor when there ' - 'is only one variable in the equation.') - diffusions = {self.variables[0]: diffusions} - else: - if not isinstance(diffusions, (tuple, list)): - raise ValueError('Diffusion values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)} - if self.wiener_type == constants.VECTOR_WIENER: - for key, val in diffusions.items(): - if val is not None and jnp.ndim(val) == 0: - raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple " - f"dimensional diffusion value. 
But we got a scale value for " - f"variable {key}.") - - # derivative of diffusion parts - all_dg = {} - for i, key in enumerate(self.variables): - f_dg, vars_, pars_ = parses[i] - vps = vars_ + pars_ - all_dg[key] = f_dg(all_args[vps[0]], **{arg: all_args[arg] for arg in vps[1:] if arg in all_args}) - - # integral results - integrals = [] - for i, key in enumerate(self.variables): - integral = all_args[key] + drifts[key] * dt - if diffusions[key] is not None: - shape = jnp.shape(all_args[key]) - if self.wiener_type == constants.VECTOR_WIENER: - noise_shape = jnp.shape(diffusions[key]) - self._check_vector_wiener_dim(noise_shape, shape) - shape += noise_shape[-1:] - noise = bm.random.randn(*shape) * jnp.sqrt(dt) - if self.wiener_type == constants.VECTOR_WIENER: - integral += jnp.sum(diffusions[key] * noise, axis=-1) + +class Milstein(SDEIntegrator): + r"""Milstein method for Ito or Stratonovich integrals. + + The Milstein scheme represents the order 1.0 strong Taylor scheme. For the Ito integral, + + .. math:: + + \begin{aligned} + &Y_{n+1}=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2} g_{n} g_{n}^{\prime}\left[\left(\Delta W_{n}\right)^{2}-h\right] \\ + &\Delta W_{n}=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) + \end{aligned} + + where :math:`g_{n}^{\prime}=\frac{d g\left(Y_{n}\right)}{d Y_{n}}` is the first derivative of :math:`g_n`. + + + For the Stratonovich integral, the Milstein method is given by + + .. 
math:: + + \begin{aligned} + &Y_{n+1}=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2} g_{n} g_{n}^{\prime}\left(\Delta W_{n}\right)^{2} \\ + &\Delta W_{n}=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) + \end{aligned} + + """ + + def __init__( + self, + f: Callable, + g: Callable, + dt: float = None, + name: str = None, + show_code=False, + var_type: str = None, + intg_type: str = None, + wiener_type: str = None, + state_delays: Dict[str, bm.AbstractDelay] = None, + ): + super(Milstein, self).__init__(f=f, + g=g, + dt=dt, + name=name, + var_type=var_type, + intg_type=intg_type, + wiener_type=wiener_type, + state_delays=state_delays) + self.set_integral(self.step) + + def _get_g_grad(self, f, allow_raise=False, need_grad=True): + if isinstance(f, joint_eq.JointEq): + results = [] + state = True + for sub_eq in f.eqs: + r, r_state = self._get_g_grad(sub_eq, allow_raise, need_grad) + results.extend(r) + state &= r_state + return results, state + else: + res = [None, None, None] + state = True + try: + vars, pars, _ = utils.get_args(f) + if len(vars) != 1: + raise errors.DiffEqError(constants.multi_vars_msg.format(cls=self.__class__.__name__, + vars=str(vars), eq=str(f))) + res[1] = vars + res[2] = pars + except errors.DiffEqError as e: + state = False + if not allow_raise: + raise e + if need_grad: + res[0] = bm.vector_grad(f, argnums=0) + return [tuple(res)], state + + def step(self, *args, **kwargs): + # parse grad function and individual arguments + parses, state = self._get_g_grad(self.g, allow_raise=False, need_grad=True) + if not state: + parses2 = self._get_g_grad(self.f, allow_raise=True, need_grad=False) + if len(parses2) != len(parses): + raise ValueError(f'"f" and "g" should defined with JointEq both, and should ' + f'keep the same structure.') + parses = [a[:1] + b[1:] for a, b in zip(parses, parses2)] + + # input arguments + all_args = format_args(args, kwargs, self.arg_names) + dt = all_args.pop(DT, self.dt) + + # drift values + drifts = 
self.f(**all_args) + if len(self.variables) == 1: + if not isinstance(drifts, (bm.ndarray, jnp.ndarray)): + raise ValueError('Drift values must be a tensor when there ' + 'is only one variable in the equation.') + drifts = {self.variables[0]: drifts} + else: + if not isinstance(drifts, (tuple, list)): + raise ValueError('Drift values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + drifts = {var: drifts[i] for i, var in enumerate(self.variables)} + + # diffusion values + diffusions = self.g(**all_args) + if len(self.variables) == 1: + if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)): + raise ValueError('Diffusion values must be a tensor when there ' + 'is only one variable in the equation.') + diffusions = {self.variables[0]: diffusions} else: - integral += diffusions[key] * noise - noise_p2 = (noise ** 2 - dt) if self.intg_type == constants.ITO_SDE else noise ** 2 - diffusion = diffusions[key] * all_dg[key] / 2 * noise_p2 - diffusion = jnp.sum(diffusion, axis=-1) if self.wiener_type == constants.VECTOR_WIENER else diffusion - integral += diffusion - integrals.append(integral) - return integrals if len(self.variables) > 1 else integrals[0] + if not isinstance(diffusions, (tuple, list)): + raise ValueError('Diffusion values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)} + if self.wiener_type == constants.VECTOR_WIENER: + for key, val in diffusions.items(): + if val is not None and jnp.ndim(val) == 0: + raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple " + f"dimensional diffusion value. 
But we got a scale value for " + f"variable {key}.") + + # derivative of diffusion parts + all_dg = {} + for i, key in enumerate(self.variables): + f_dg, vars_, pars_ = parses[i] + vps = vars_ + pars_ + all_dg[key] = f_dg(all_args[vps[0]], **{arg: all_args[arg] for arg in vps[1:] if arg in all_args}) + + # integral results + integrals = [] + for i, key in enumerate(self.variables): + integral = all_args[key] + drifts[key] * dt + if diffusions[key] is not None: + shape = jnp.shape(all_args[key]) + if self.wiener_type == constants.VECTOR_WIENER: + noise_shape = jnp.shape(diffusions[key]) + self._check_vector_wiener_dim(noise_shape, shape) + shape += noise_shape[-1:] + noise = bm.random.randn(*shape) * jnp.sqrt(dt) + if self.wiener_type == constants.VECTOR_WIENER: + integral += jnp.sum(diffusions[key] * noise, axis=-1) + else: + integral += diffusions[key] * noise + noise_p2 = (noise ** 2 - dt) if self.intg_type == constants.ITO_SDE else noise ** 2 + diffusion = diffusions[key] * all_dg[key] / 2 * noise_p2 + diffusion = jnp.sum(diffusion, axis=-1) if self.wiener_type == constants.VECTOR_WIENER else diffusion + integral += diffusion + integrals.append(integral) + return integrals if len(self.variables) > 1 else integrals[0] register_sde_integrator('milstein', Milstein) class MilsteinGradFree(SDEIntegrator): - r"""Derivative-free Milstein method for Ito or Stratonovich integrals. + r"""Derivative-free Milstein method for Ito or Stratonovich integrals. - The following implementation approximates the frist derivative of :math:`g` thanks to a Runge-Kutta approach. - For the Ito integral, the derivative-free Milstein method is given by + The following implementation approximates the frist derivative of :math:`g` thanks to a Runge-Kutta approach. + For the Ito integral, the derivative-free Milstein method is given by - .. math:: + .. 
math:: - \begin{aligned} - Y_{n+1} &=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2 \sqrt{h}}\left[g\left(\bar{Y}_{n}\right)-g_{n}\right]\left[\left(\Delta W_{n}\right)^{2}-h\right] \\ - \bar{Y}_{n} &=Y_{n}+f_{n} h+g_{n} \sqrt{h} \\ - \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) - \end{aligned} + \begin{aligned} + Y_{n+1} &=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2 \sqrt{h}}\left[g\left(\bar{Y}_{n}\right)-g_{n}\right]\left[\left(\Delta W_{n}\right)^{2}-h\right] \\ + \bar{Y}_{n} &=Y_{n}+f_{n} h+g_{n} \sqrt{h} \\ + \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) + \end{aligned} - For the Stratonovich integral, the derivative-free Milstein method is given by + For the Stratonovich integral, the derivative-free Milstein method is given by - .. math:: + .. math:: - \begin{aligned} - Y_{n+1} &=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2 \sqrt{h}}\left[g\left(\bar{Y}_{n}\right)-g_{n}\right]\left(\Delta W_{n}\right)^{2} \\ - \bar{Y}_{n} &=Y_{n}+f_{n} h+g_{n} \sqrt{h} \\ - \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) - \end{aligned} - - """ - - def __init__( - self, - f: Callable, - g: Callable, - dt: float = None, - name: str = None, - show_code=False, - var_type: str = None, - intg_type: str = None, - wiener_type: str = None, - state_delays: Dict[str, bm.AbstractDelay] = None, - ): - super(MilsteinGradFree, self).__init__(f=f, - g=g, - dt=dt, - name=name, - var_type=var_type, - intg_type=intg_type, - wiener_type=wiener_type, - state_delays=state_delays) - self.set_integral(self.step) - - def step(self, *args, **kwargs): - # input arguments - all_args = format_args(args, kwargs, self.arg_names) - dt = all_args.pop(DT, self.dt) - - # drift values - drifts = self.f(**all_args) - if len(self.variables) == 1: - if not isinstance(drifts, (bm.ndarray, jnp.ndarray)): - raise ValueError('Drift values must be a tensor when there ' - 'is only one variable in the equation.') - drifts = {self.variables[0]: 
drifts} - else: - if not isinstance(drifts, (tuple, list)): - raise ValueError('Drift values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - drifts = {var: drifts[i] for i, var in enumerate(self.variables)} - - # diffusion values - diffusions = self.g(**all_args) - if len(self.variables) == 1: - if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)): - raise ValueError('Diffusion values must be a tensor when there ' - 'is only one variable in the equation.') - diffusions = {self.variables[0]: diffusions} - else: - if not isinstance(diffusions, (tuple, list)): - raise ValueError('Diffusion values must be a list/tuple of tensors ' - 'when there are multiple variables in the equation.') - diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)} - if self.wiener_type == constants.VECTOR_WIENER: - for key, val in diffusions.items(): - if val is not None and jnp.ndim(val) == 0: - raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple " - f"dimensional diffusion value. 
But we got a scale value for " - f"variable {key}.") - - # intermediate results - y_bars = {k: v for k, v in all_args.items()} - for key in self.variables: - bar = all_args[key] + drifts[key] * dt - if diffusions[key] is not None: - bar += diffusions[key] * jnp.sqrt(dt) - y_bars[key] = bar - diffusion_bars = self.g(**y_bars) - if len(self.variables) == 1: - diffusion_bars = {self.variables[0]: diffusion_bars} - else: - diffusion_bars = {var: diffusion_bars[i] for i, var in enumerate(self.variables)} - - # integral results - integrals = [] - for i, key in enumerate(self.variables): - integral = all_args[key] + drifts[key] * dt - if diffusions[key] is not None: - shape = jnp.shape(all_args[key]) - if self.wiener_type == constants.VECTOR_WIENER: - noise_shape = jnp.shape(diffusions[key]) - self._check_vector_wiener_dim(noise_shape, shape) - shape += noise_shape[-1:] - noise = bm.random.randn(*shape) * jnp.sqrt(dt) - if self.wiener_type == constants.VECTOR_WIENER: - integral += jnp.sum(diffusions[key] * noise, axis=-1) + \begin{aligned} + Y_{n+1} &=Y_{n}+f_{n} h+g_{n} \Delta W_{n}+\frac{1}{2 \sqrt{h}}\left[g\left(\bar{Y}_{n}\right)-g_{n}\right]\left(\Delta W_{n}\right)^{2} \\ + \bar{Y}_{n} &=Y_{n}+f_{n} h+g_{n} \sqrt{h} \\ + \Delta W_{n} &=\left[W_{t+h}-W_{t}\right] \sim \sqrt{h} \mathcal{N}(0,1) + \end{aligned} + + """ + + def __init__( + self, + f: Callable, + g: Callable, + dt: float = None, + name: str = None, + show_code=False, + var_type: str = None, + intg_type: str = None, + wiener_type: str = None, + state_delays: Dict[str, bm.AbstractDelay] = None, + ): + super(MilsteinGradFree, self).__init__(f=f, + g=g, + dt=dt, + name=name, + var_type=var_type, + intg_type=intg_type, + wiener_type=wiener_type, + state_delays=state_delays) + self.set_integral(self.step) + + def step(self, *args, **kwargs): + # input arguments + all_args = format_args(args, kwargs, self.arg_names) + dt = all_args.pop(DT, self.dt) + + # drift values + drifts = self.f(**all_args) + if 
len(self.variables) == 1: + if not isinstance(drifts, (bm.ndarray, jnp.ndarray)): + raise ValueError('Drift values must be a tensor when there ' + 'is only one variable in the equation.') + drifts = {self.variables[0]: drifts} + else: + if not isinstance(drifts, (tuple, list)): + raise ValueError('Drift values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + drifts = {var: drifts[i] for i, var in enumerate(self.variables)} + + # diffusion values + diffusions = self.g(**all_args) + if len(self.variables) == 1: + if not isinstance(diffusions, (bm.ndarray, jnp.ndarray)): + raise ValueError('Diffusion values must be a tensor when there ' + 'is only one variable in the equation.') + diffusions = {self.variables[0]: diffusions} else: - integral += diffusions[key] * noise - noise_p2 = (noise ** 2 - dt) if self.intg_type == constants.ITO_SDE else noise ** 2 - minus = (diffusion_bars[key] - diffusions[key]) / 2 / jnp.sqrt(dt) + if not isinstance(diffusions, (tuple, list)): + raise ValueError('Diffusion values must be a list/tuple of tensors ' + 'when there are multiple variables in the equation.') + diffusions = {var: diffusions[i] for i, var in enumerate(self.variables)} if self.wiener_type == constants.VECTOR_WIENER: - integral += minus * jnp.sum(noise_p2, axis=-1) + for key, val in diffusions.items(): + if val is not None and jnp.ndim(val) == 0: + raise ValueError(f"{constants.VECTOR_WIENER} wiener process needs multiple " + f"dimensional diffusion value. 
But we got a scale value for " + f"variable {key}.") + + # intermediate results + y_bars = {k: v for k, v in all_args.items()} + for key in self.variables: + bar = all_args[key] + drifts[key] * dt + if diffusions[key] is not None: + bar += diffusions[key] * jnp.sqrt(dt) + y_bars[key] = bar + diffusion_bars = self.g(**y_bars) + if len(self.variables) == 1: + diffusion_bars = {self.variables[0]: diffusion_bars} else: - integral += minus * noise_p2 - integrals.append(integral) - return integrals if len(self.variables) > 1 else integrals[0] + diffusion_bars = {var: diffusion_bars[i] for i, var in enumerate(self.variables)} + + # integral results + integrals = [] + for i, key in enumerate(self.variables): + integral = all_args[key] + drifts[key] * dt + if diffusions[key] is not None: + shape = jnp.shape(all_args[key]) + if self.wiener_type == constants.VECTOR_WIENER: + noise_shape = jnp.shape(diffusions[key]) + self._check_vector_wiener_dim(noise_shape, shape) + shape += noise_shape[-1:] + noise = bm.random.randn(*shape) * jnp.sqrt(dt) + if self.wiener_type == constants.VECTOR_WIENER: + integral += jnp.sum(diffusions[key] * noise, axis=-1) + else: + integral += diffusions[key] * noise + noise_p2 = (noise ** 2 - dt) if self.intg_type == constants.ITO_SDE else noise ** 2 + minus = (diffusion_bars[key] - diffusions[key]) / 2 / jnp.sqrt(dt) + if self.wiener_type == constants.VECTOR_WIENER: + integral += minus * jnp.sum(noise_p2, axis=-1) + else: + integral += minus * noise_p2 + integrals.append(integral) + return integrals if len(self.variables) > 1 else integrals[0] register_sde_integrator('milstein2', MilsteinGradFree) @@ -503,133 +503,134 @@ def step(self, *args, **kwargs): class ExponentialEuler(SDEIntegrator): - r"""First order, explicit exponential Euler method. - - For a SDE equation of the form - - .. math:: - - d y=(Ay+ F(y))dt + g(y)dW(t) = f(y)dt + g(y)dW(t), \quad y(0)=y_{0} - - its schema is given by [1]_ - - .. 
math:: - - y_{n+1} & =e^{\Delta t A}(y_{n}+ g(y_n)\Delta W_{n})+\varphi(\Delta t A) F(y_{n}) \Delta t \\ - &= y_n + \Delta t \varphi(\Delta t A) f(y) + e^{\Delta t A}g(y_n)\Delta W_{n} - - where :math:`\varphi(z)=\frac{e^{z}-1}{z}`. - - References:: - - .. [1] Erdoğan, Utku, and Gabriel J. Lord. "A new class of exponential integrators for stochastic - differential equations with multiplicative noise." arXiv preprint arXiv:1608.07096 (2016). - - - See Also:: - - Euler, Heun, Milstein - """ - - def __init__( - self, - f: Callable, - g: Callable, - dt: float = None, - name: str = None, - show_code: bool = False, - var_type: str = None, - intg_type: str = None, - wiener_type: str = None, - dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, - state_delays: Dict[str, bm.AbstractDelay] = None - ): - super(ExponentialEuler, self).__init__(f=f, - g=g, - dt=dt, - show_code=show_code, - name=name, - var_type=var_type, - intg_type=intg_type, - wiener_type=wiener_type, - state_delays=state_delays) - - if self.intg_type == constants.STRA_SDE: - raise NotImplementedError(f'{self.__class__.__name__} does not support integral type of {constants.STRA_SDE}. ' - f'It only supports {constants.ITO_SDE} now. 
') - - # build the integrator - self.integral = self.build() - - def build(self): - parses = self._build_integrator(self.f) - all_vps = self.variables + self.parameters - - def integral_func(*args, **kwargs): - # format arguments - params_in = bm.Collector() - for i, arg in enumerate(args): - params_in[all_vps[i]] = arg - params_in.update(kwargs) - dt = params_in.pop(constants.DT, self.dt) - - # diffusion part - diffusions = self.g(**params_in) - - # call integrals - results = [] - params_in[constants.DT] = dt - for i, parse in enumerate(parses): - f_integral, vars_, pars_ = parse - vps = vars_ + pars_ + [constants.DT] - # integral of the drift part - r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in}) - if isinstance(diffusions, (tuple, list)): - diffusion = diffusions[i] + r"""First order, explicit exponential Euler method. + + For a SDE equation of the form + + .. math:: + + d y=(Ay+ F(y))dt + g(y)dW(t) = f(y)dt + g(y)dW(t), \quad y(0)=y_{0} + + its schema is given by [1]_ + + .. math:: + + y_{n+1} & =e^{\Delta t A}(y_{n}+ g(y_n)\Delta W_{n})+\varphi(\Delta t A) F(y_{n}) \Delta t \\ + &= y_n + \Delta t \varphi(\Delta t A) f(y) + e^{\Delta t A}g(y_n)\Delta W_{n} + + where :math:`\varphi(z)=\frac{e^{z}-1}{z}`. + + References:: + + .. [1] Erdoğan, Utku, and Gabriel J. Lord. "A new class of exponential integrators for stochastic + differential equations with multiplicative noise." arXiv preprint arXiv:1608.07096 (2016). 
+ + + See Also:: + + Euler, Heun, Milstein + """ + + def __init__( + self, + f: Callable, + g: Callable, + dt: float = None, + name: str = None, + show_code: bool = False, + var_type: str = None, + intg_type: str = None, + wiener_type: str = None, + dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, + state_delays: Dict[str, bm.AbstractDelay] = None + ): + super(ExponentialEuler, self).__init__(f=f, + g=g, + dt=dt, + show_code=show_code, + name=name, + var_type=var_type, + intg_type=intg_type, + wiener_type=wiener_type, + state_delays=state_delays) + + if self.intg_type == constants.STRA_SDE: + raise NotImplementedError( + f'{self.__class__.__name__} does not support integral type of {constants.STRA_SDE}. ' + f'It only supports {constants.ITO_SDE} now. ') + + # build the integrator + self.integral = self.build() + + def build(self): + parses = self._build_integrator(self.f) + all_vps = self.variables + self.parameters + + def integral_func(*args, **kwargs): + # format arguments + params_in = bm.Collector() + for i, arg in enumerate(args): + params_in[all_vps[i]] = arg + params_in.update(kwargs) + dt = params_in.pop(constants.DT, self.dt) + + # diffusion part + diffusions = self.g(**params_in) + + # call integrals + results = [] + params_in[constants.DT] = dt + for i, parse in enumerate(parses): + f_integral, vars_, pars_ = parse + vps = vars_ + pars_ + [constants.DT] + # integral of the drift part + r = f_integral(params_in[vps[0]], **{arg: params_in[arg] for arg in vps[1:] if arg in params_in}) + if isinstance(diffusions, (tuple, list)): + diffusion = diffusions[i] + else: + assert len(parses) == 1 + diffusion = diffusions + # diffusion part + shape = jnp.shape(params_in[vps[0]]) + if diffusion is not None: + diffusion = bm.as_jax(diffusion) + if self.wiener_type == constants.VECTOR_WIENER: + noise_shape = jnp.shape(diffusion) + self._check_vector_wiener_dim(noise_shape, shape) + shape += noise_shape[-1:] + diffusion = 
jnp.sum(diffusion * bm.random.randn(*shape), axis=-1) + else: + diffusion = diffusion * bm.random.randn(*shape) + r += diffusion * jnp.sqrt(params_in[constants.DT]) + # final result + results.append(r) + return results if len(self.variables) > 1 else results[0] + + return integral_func + + def _build_integrator(self, f): + if isinstance(f, joint_eq.JointEq): + results = [] + for sub_eq in f.eqs: + results.extend(self._build_integrator(sub_eq)) + return results + else: - assert len(parses) == 1 - diffusion = diffusions - # diffusion part - shape = jnp.shape(params_in[vps[0]]) - if diffusion is not None: - diffusion = bm.as_jax(diffusion) - if self.wiener_type == constants.VECTOR_WIENER: - noise_shape = jnp.shape(diffusion) - self._check_vector_wiener_dim(noise_shape, shape) - shape += noise_shape[-1:] - diffusion = jnp.sum(diffusion * bm.random.randn(*shape), axis=-1) - else: - diffusion = diffusion * bm.random.randn(*shape) - r += diffusion * jnp.sqrt(params_in[constants.DT]) - # final result - results.append(r) - return results if len(self.variables) > 1 else results[0] - - return integral_func - - def _build_integrator(self, f): - if isinstance(f, joint_eq.JointEq): - results = [] - for sub_eq in f.eqs: - results.extend(self._build_integrator(sub_eq)) - return results - - else: - vars, pars, _ = utils.get_args(f) - if len(vars) != 1: - raise errors.DiffEqError(constants.multi_vars_msg.format(cls=self.__class__.__name__, - vars=str(vars), eq=str(f))) - value_and_grad = bm.vector_grad(f, argnums=0, return_value=True) - - # integration function - def integral(*args, **kwargs): - assert len(args) > 0 - dt = kwargs.pop('dt', self.dt) - linear, derivative = value_and_grad(*args, **kwargs) - phi = bm.as_jax(bm.exprel(dt * linear)) - return args[0] + dt * phi * derivative - - return [(integral, vars, pars), ] + vars, pars, _ = utils.get_args(f) + if len(vars) != 1: + raise errors.DiffEqError(constants.multi_vars_msg.format(cls=self.__class__.__name__, + vars=str(vars), 
eq=str(f))) + value_and_grad = bm.vector_grad(f, argnums=0, return_value=True) + + # integration function + def integral(*args, **kwargs): + assert len(args) > 0 + dt = kwargs.pop('dt', self.dt) + linear, derivative = value_and_grad(*args, **kwargs) + phi = bm.as_jax(bm.exprel(dt * linear)) + return args[0] + dt * phi * derivative + + return [(integral, vars, pars), ] register_sde_integrator('exponential_euler', ExponentialEuler) diff --git a/brainpy/_src/integrators/sde/srk_scalar.py b/brainpy/_src/integrators/sde/srk_scalar.py index 2336c325e..ab0640489 100644 --- a/brainpy/_src/integrators/sde/srk_scalar.py +++ b/brainpy/_src/integrators/sde/srk_scalar.py @@ -5,373 +5,374 @@ from .generic import register_sde_integrator __all__ = [ - 'SRK1W1', - 'SRK2W1', - 'KlPl', + 'SRK1W1', + 'SRK2W1', + 'KlPl', ] def _noise_terms(code_lines, variables, triple_integral=True): - # num_vars = len(variables) - # if num_vars > 1: - # code_lines.append(f' all_I1 = math.normal(0.0, dt_sqrt, ({num_vars},)+math.shape({variables[0]}))') - # code_lines.append(f' all_I0 = math.normal(0.0, dt_sqrt, ({num_vars},)+math.shape({variables[0]}))') - # code_lines.append(f' all_I10 = 0.5 * {constants.DT} * (all_I1 + all_I0 / 3.0 ** 0.5)') - # code_lines.append(f' all_I11 = 0.5 * (all_I1 ** 2 - {constants.DT})') - # if triple_integral: - # code_lines.append(f' all_I111 = (all_I1 ** 3 - 3 * {constants.DT} * all_I1) / 6') - # code_lines.append(f' ') - # for i, var in enumerate(variables): - # code_lines.append(f' {var}_I1 = all_I1[{i}]') - # code_lines.append(f' {var}_I0 = all_I0[{i}]') - # code_lines.append(f' {var}_I10 = all_I10[{i}]') - # code_lines.append(f' {var}_I11 = all_I11[{i}]') - # if triple_integral: - # code_lines.append(f' {var}_I111 = all_I111[{i}]') - # code_lines.append(f' ') - # else: - # var = variables[0] - # code_lines.append(f' {var}_I1 = math.normal(0.0, dt_sqrt, math.shape({var}))') - # code_lines.append(f' {var}_I0 = math.normal(0.0, dt_sqrt, math.shape({var}))') - # 
code_lines.append(f' {var}_I10 = 0.5 * {constants.DT} * ({var}_I1 + {var}_I0 / 3.0 ** 0.5)') - # code_lines.append(f' {var}_I11 = 0.5 * ({var}_I1 ** 2 - {constants.DT})') - # if triple_integral: - # code_lines.append(f' {var}_I111 = ({var}_I1 ** 3 - 3 * {constants.DT} * {var}_I1) / 6') - # code_lines.append(' ') - - for var in variables: - code_lines.append(f' {var}_I1 = dt_sqrt * random.randn(*math.shape({var}))') - code_lines.append(f' {var}_I0 = dt_sqrt * random.randn(*math.shape({var}))') - code_lines.append(f' {var}_I10 = 0.5 * {constants.DT} * ({var}_I1 + {var}_I0 / 3.0 ** 0.5)') - code_lines.append(f' {var}_I11 = 0.5 * ({var}_I1 ** 2 - {constants.DT})') - if triple_integral: - code_lines.append(f' {var}_I111 = ({var}_I1 ** 3 - 3 * {constants.DT} * {var}_I1) / 6') - code_lines.append(' ') + # num_vars = len(variables) + # if num_vars > 1: + # code_lines.append(f' all_I1 = math.normal(0.0, dt_sqrt, ({num_vars},)+math.shape({variables[0]}))') + # code_lines.append(f' all_I0 = math.normal(0.0, dt_sqrt, ({num_vars},)+math.shape({variables[0]}))') + # code_lines.append(f' all_I10 = 0.5 * {constants.DT} * (all_I1 + all_I0 / 3.0 ** 0.5)') + # code_lines.append(f' all_I11 = 0.5 * (all_I1 ** 2 - {constants.DT})') + # if triple_integral: + # code_lines.append(f' all_I111 = (all_I1 ** 3 - 3 * {constants.DT} * all_I1) / 6') + # code_lines.append(f' ') + # for i, var in enumerate(variables): + # code_lines.append(f' {var}_I1 = all_I1[{i}]') + # code_lines.append(f' {var}_I0 = all_I0[{i}]') + # code_lines.append(f' {var}_I10 = all_I10[{i}]') + # code_lines.append(f' {var}_I11 = all_I11[{i}]') + # if triple_integral: + # code_lines.append(f' {var}_I111 = all_I111[{i}]') + # code_lines.append(f' ') + # else: + # var = variables[0] + # code_lines.append(f' {var}_I1 = math.normal(0.0, dt_sqrt, math.shape({var}))') + # code_lines.append(f' {var}_I0 = math.normal(0.0, dt_sqrt, math.shape({var}))') + # code_lines.append(f' {var}_I10 = 0.5 * {constants.DT} * ({var}_I1 + {var}_I0 / 
3.0 ** 0.5)') + # code_lines.append(f' {var}_I11 = 0.5 * ({var}_I1 ** 2 - {constants.DT})') + # if triple_integral: + # code_lines.append(f' {var}_I111 = ({var}_I1 ** 3 - 3 * {constants.DT} * {var}_I1) / 6') + # code_lines.append(' ') + + for var in variables: + code_lines.append(f' {var}_I1 = dt_sqrt * random.randn(*math.shape({var}))') + code_lines.append(f' {var}_I0 = dt_sqrt * random.randn(*math.shape({var}))') + code_lines.append(f' {var}_I10 = 0.5 * {constants.DT} * ({var}_I1 + {var}_I0 / 3.0 ** 0.5)') + code_lines.append(f' {var}_I11 = 0.5 * ({var}_I1 ** 2 - {constants.DT})') + if triple_integral: + code_lines.append(f' {var}_I111 = ({var}_I1 ** 3 - 3 * {constants.DT} * {var}_I1) / 6') + code_lines.append(' ') def _state1(code_lines, variables, parameters): - f_names = [f'{var}_f_H0s1' for var in variables] - g_names = [f'{var}_g_H1s1' for var in variables] - code_lines.append(f' {", ".join(f_names)} = f({", ".join(variables + parameters)})') - code_lines.append(f' {", ".join(g_names)} = g({", ".join(variables + parameters)})') - code_lines.append(' ') + f_names = [f'{var}_f_H0s1' for var in variables] + g_names = [f'{var}_g_H1s1' for var in variables] + code_lines.append(f' {", ".join(f_names)} = f({", ".join(variables + parameters)})') + code_lines.append(f' {", ".join(g_names)} = g({", ".join(variables + parameters)})') + code_lines.append(' ') class SRK1W1(SDEIntegrator): - r"""Order 2.0 weak SRK methods for SDEs with scalar Wiener process. - - This method has have strong orders :math:`(p_d, p_s) = (2.0,1.5)`. - - The Butcher table is: - - .. math:: - - \begin{array}{l|llll|llll|llll} - 0 &&&&& &&&& &&&& \\ - 3/4 &3/4&&&& 3/2&&& &&&& \\ - 0 &0&0&0&& 0&0&0&& &&&&\\ - \hline - 0 \\ - 1/4 & 1/4&&& & 1/2&&&\\ - 1 & 1&0&&& -1&0&\\ - 1/4& 0&0&1/4&& -5&3&1/2\\ - \hline - & 1/3& 2/3& 0 & 0 & -1 & 4/3 & 2/3&0 & -1 &4/3 &-1/3 &0 \\ - \hline - & &&&& 2 &-4/3 & -2/3 & 0 & -2 & 5/3 & -2/3 & 1 - \end{array} - - - References:: - - .. [1] Rößler, Andreas. 
"Strong and weak approximation methods for stochastic differential - equations—some recent developments." Recent developments in applied probability and - statistics. Physica-Verlag HD, 2010. 127-153. - .. [2] Rößler, Andreas. "Runge–Kutta methods for the strong approximation of solutions of - stochastic differential equations." SIAM Journal on Numerical Analysis 48.3 - (2010): 922-952. - - """ - - def __init__(self, f, g, dt=None, name=None, show_code=False, - var_type=None, intg_type=None, wiener_type=None, state_delays=None): - super(SRK1W1, self).__init__(f=f, g=g, dt=dt, show_code=show_code, name=name, - var_type=var_type, intg_type=intg_type, - wiener_type=wiener_type, state_delays=state_delays) - assert self.wiener_type == constants.SCALAR_WIENER - self.build() - - def build(self): - # 2. code lines - self.code_lines.append(f' {constants.DT}_sqrt = {constants.DT} ** 0.5') - - # 2.1 noise - _noise_terms(self.code_lines, self.variables, triple_integral=True) - - # 2.2 stage 1 - _state1(self.code_lines, self.variables, self.parameters) - - # 2.3 stage 2 - all_H0s2, all_H1s2 = [], [] - for var in self.variables: - self.code_lines.append(f' {var}_H0s2 = {var} + {constants.DT} * 0.75 * {var}_f_H0s1 + ' - f'1.5 * {var}_g_H1s1 * {var}_I10 / {constants.DT}') - all_H0s2.append(f'{var}_H0s2') - self.code_lines.append(f' {var}_H1s2 = {var} + {constants.DT} * 0.25 * {var}_f_H0s1 + ' - f'dt_sqrt * 0.5 * {var}_g_H1s1') - all_H1s2.append(f'{var}_H1s2') - all_H0s2.append(f't + 0.75 * {constants.DT}') # t - all_H1s2.append(f't + 0.25 * {constants.DT}') # t - f_names = [f'{var}_f_H0s2' for var in self.variables] - self.code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s2 + self.parameters[1:])})') - g_names = [f'{var}_g_H1s2' for var in self.variables] - self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + self.parameters[1:])})') - self.code_lines.append(' ') - - # 2.4 state 3 - all_H1s3 = [] - for var in self.variables: - 
self.code_lines.append(f' {var}_H1s3 = {var} + {constants.DT} * {var}_f_H0s1 - dt_sqrt * {var}_g_H1s1') - all_H1s3.append(f'{var}_H1s3') - all_H1s3.append(f't + {constants.DT}') # t - g_names = [f'{var}_g_H1s3' for var in self.variables] - self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s3 + self.parameters[1:])})') - self.code_lines.append(' ') - - # 2.5 state 4 - all_H1s4 = [] - for var in self.variables: - self.code_lines.append(f' {var}_H1s4 = {var} + 0.25 * {constants.DT} * {var}_f_H0s1 + dt_sqrt * ' - f'(-5 * {var}_g_H1s1 + 3 * {var}_g_H1s2 + 0.5 * {var}_g_H1s3)') - all_H1s4.append(f'{var}_H1s4') - all_H1s4.append(f't + 0.25 * {constants.DT}') # t - g_names = [f'{var}_g_H1s4' for var in self.variables] - self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s4 + self.parameters[1:])})') - self.code_lines.append(' ') - - # 2.6 final stage - for var in self.variables: - self.code_lines.append(f' {var}_f1 = {var}_f_H0s1/3 + {var}_f_H0s2 * 2/3') - self.code_lines.append( - f' {var}_g1 = -{var}_I1 - {var}_I11/dt_sqrt + 2 * {var}_I10/{constants.DT} - 2 * {var}_I111/{constants.DT}') - self.code_lines.append(f' {var}_g2 = {var}_I1 * 4/3 + {var}_I11 / dt_sqrt * 4/3 - ' - f'{var}_I10 / {constants.DT} * 4/3 + {var}_I111 / {constants.DT} * 5/3') - self.code_lines.append(f' {var}_g3 = {var}_I1 * 2/3 - {var}_I11/dt_sqrt/3 - ' - f'{var}_I10 / {constants.DT} * 2/3 - {var}_I111 / {constants.DT} * 2/3') - self.code_lines.append(f' {var}_g4 = {var}_I111 / {constants.DT}') - self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f1 + {var}_g1 * {var}_g_H1s1 + ' - f'{var}_g2 * {var}_g_H1s2 + {var}_g3 * {var}_g_H1s3 + {var}_g4 * {var}_g_H1s4') - self.code_lines.append(' ') - - # returns - new_vars = [f'{var}_new' for var in self.variables] - self.code_lines.append(f' return {", ".join(new_vars)}') - - # return and compile - self.integral = utils.compile_code( - code_scope={k: v for k, v in self.code_scope.items()}, - 
code_lines=self.code_lines, - show_code=self.show_code, - func_name=self.func_name) + r"""Order 2.0 weak SRK methods for SDEs with scalar Wiener process. + + This method has have strong orders :math:`(p_d, p_s) = (2.0,1.5)`. + + The Butcher table is: + + .. math:: + + \begin{array}{l|llll|llll|llll} + 0 &&&&& &&&& &&&& \\ + 3/4 &3/4&&&& 3/2&&& &&&& \\ + 0 &0&0&0&& 0&0&0&& &&&&\\ + \hline + 0 \\ + 1/4 & 1/4&&& & 1/2&&&\\ + 1 & 1&0&&& -1&0&\\ + 1/4& 0&0&1/4&& -5&3&1/2\\ + \hline + & 1/3& 2/3& 0 & 0 & -1 & 4/3 & 2/3&0 & -1 &4/3 &-1/3 &0 \\ + \hline + & &&&& 2 &-4/3 & -2/3 & 0 & -2 & 5/3 & -2/3 & 1 + \end{array} + + + References:: + + .. [1] Rößler, Andreas. "Strong and weak approximation methods for stochastic differential + equations—some recent developments." Recent developments in applied probability and + statistics. Physica-Verlag HD, 2010. 127-153. + .. [2] Rößler, Andreas. "Runge–Kutta methods for the strong approximation of solutions of + stochastic differential equations." SIAM Journal on Numerical Analysis 48.3 + (2010): 922-952. + + """ + + def __init__(self, f, g, dt=None, name=None, show_code=False, + var_type=None, intg_type=None, wiener_type=None, state_delays=None): + super(SRK1W1, self).__init__(f=f, g=g, dt=dt, show_code=show_code, name=name, + var_type=var_type, intg_type=intg_type, + wiener_type=wiener_type, state_delays=state_delays) + assert self.wiener_type == constants.SCALAR_WIENER + self.build() + + def build(self): + # 2. 
code lines + self.code_lines.append(f' {constants.DT}_sqrt = {constants.DT} ** 0.5') + + # 2.1 noise + _noise_terms(self.code_lines, self.variables, triple_integral=True) + + # 2.2 stage 1 + _state1(self.code_lines, self.variables, self.parameters) + + # 2.3 stage 2 + all_H0s2, all_H1s2 = [], [] + for var in self.variables: + self.code_lines.append(f' {var}_H0s2 = {var} + {constants.DT} * 0.75 * {var}_f_H0s1 + ' + f'1.5 * {var}_g_H1s1 * {var}_I10 / {constants.DT}') + all_H0s2.append(f'{var}_H0s2') + self.code_lines.append(f' {var}_H1s2 = {var} + {constants.DT} * 0.25 * {var}_f_H0s1 + ' + f'dt_sqrt * 0.5 * {var}_g_H1s1') + all_H1s2.append(f'{var}_H1s2') + all_H0s2.append(f't + 0.75 * {constants.DT}') # t + all_H1s2.append(f't + 0.25 * {constants.DT}') # t + f_names = [f'{var}_f_H0s2' for var in self.variables] + self.code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s2 + self.parameters[1:])})') + g_names = [f'{var}_g_H1s2' for var in self.variables] + self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + self.parameters[1:])})') + self.code_lines.append(' ') + + # 2.4 state 3 + all_H1s3 = [] + for var in self.variables: + self.code_lines.append(f' {var}_H1s3 = {var} + {constants.DT} * {var}_f_H0s1 - dt_sqrt * {var}_g_H1s1') + all_H1s3.append(f'{var}_H1s3') + all_H1s3.append(f't + {constants.DT}') # t + g_names = [f'{var}_g_H1s3' for var in self.variables] + self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s3 + self.parameters[1:])})') + self.code_lines.append(' ') + + # 2.5 state 4 + all_H1s4 = [] + for var in self.variables: + self.code_lines.append(f' {var}_H1s4 = {var} + 0.25 * {constants.DT} * {var}_f_H0s1 + dt_sqrt * ' + f'(-5 * {var}_g_H1s1 + 3 * {var}_g_H1s2 + 0.5 * {var}_g_H1s3)') + all_H1s4.append(f'{var}_H1s4') + all_H1s4.append(f't + 0.25 * {constants.DT}') # t + g_names = [f'{var}_g_H1s4' for var in self.variables] + self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s4 + 
self.parameters[1:])})') + self.code_lines.append(' ') + + # 2.6 final stage + for var in self.variables: + self.code_lines.append(f' {var}_f1 = {var}_f_H0s1/3 + {var}_f_H0s2 * 2/3') + self.code_lines.append( + f' {var}_g1 = -{var}_I1 - {var}_I11/dt_sqrt + 2 * {var}_I10/{constants.DT} - 2 * {var}_I111/{constants.DT}') + self.code_lines.append(f' {var}_g2 = {var}_I1 * 4/3 + {var}_I11 / dt_sqrt * 4/3 - ' + f'{var}_I10 / {constants.DT} * 4/3 + {var}_I111 / {constants.DT} * 5/3') + self.code_lines.append(f' {var}_g3 = {var}_I1 * 2/3 - {var}_I11/dt_sqrt/3 - ' + f'{var}_I10 / {constants.DT} * 2/3 - {var}_I111 / {constants.DT} * 2/3') + self.code_lines.append(f' {var}_g4 = {var}_I111 / {constants.DT}') + self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f1 + {var}_g1 * {var}_g_H1s1 + ' + f'{var}_g2 * {var}_g_H1s2 + {var}_g3 * {var}_g_H1s3 + {var}_g4 * {var}_g_H1s4') + self.code_lines.append(' ') + + # returns + new_vars = [f'{var}_new' for var in self.variables] + self.code_lines.append(f' return {", ".join(new_vars)}') + + # return and compile + self.integral = utils.compile_code( + code_scope={k: v for k, v in self.code_scope.items()}, + code_lines=self.code_lines, + show_code=self.show_code, + func_name=self.func_name) register_sde_integrator('srk1w1', SRK1W1) class SRK2W1(SDEIntegrator): - r"""Order 1.5 Strong SRK Methods for SDEs with Scalar Noise. - - This method has have strong orders :math:`(p_d, p_s) = (3.0,1.5)`. - - The Butcher table is: - - .. 
math:: - - \begin{array}{c|cccc|cccc|ccc|} - 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ - 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ - 1 / 2 & 1 / 4 & 1 / 4 & 0 & 0 & 1 & 1 / 2 & 0 & 0 & & & & \\ - 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ - \hline 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ - 1 / 4 & 1 / 4 & 0 & 0 & 0 & -1 / 2 & 0 & 0 & 0 & & & & \\ - 1 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & & & & \\ - 1 / 4 & 0 & 0 & 1 / 4 & 0 & 2 & -1 & 1 / 2 & 0 & & & & \\ - \hline & 1 / 6 & 1 / 6 & 2 / 3 & 0 & -1 & 4 / 3 & 2 / 3 & 0 & -1 & -4 / 3 & 1 / 3 & 0 \\ - \hline & & & & &2 & -4 / 3 & -2 / 3 & 0 & -2 & 5 / 3 & -2 / 3 & 1 - \end{array} - - - References:: - - .. [1] Rößler, Andreas. "Strong and weak approximation methods for stochastic differential - equations—some recent developments." Recent developments in applied probability and - statistics. Physica-Verlag HD, 2010. 127-153. - .. [2] Rößler, Andreas. "Runge–Kutta methods for the strong approximation of solutions of - stochastic differential equations." SIAM Journal on Numerical Analysis 48.3 - (2010): 922-952. 
- """ - - def __init__(self, f, g, dt=None, name=None, show_code=False, - var_type=None, intg_type=None, wiener_type=None, state_delays=None): - super(SRK2W1, self).__init__(f=f, g=g, dt=dt, show_code=show_code, name=name, - var_type=var_type, intg_type=intg_type, - wiener_type=wiener_type, state_delays=state_delays) - assert self.wiener_type == constants.SCALAR_WIENER - self.build() - - def build(self): - self.code_lines.append(f' {constants.DT}_sqrt = {constants.DT} ** 0.5') - - # 2.1 noise - _noise_terms(self.code_lines, self.variables, triple_integral=True) - - # 2.2 stage 1 - _state1(self.code_lines, self.variables, self.parameters) - - # 2.3 stage 2 - # ---- - # H0s2 = x + dt * f_H0s1 - # H1s2 = x + dt * 0.25 * f_H0s1 - dt_sqrt * 0.5 * g_H1s1 - # f_H0s2 = f(H0s2, t + dt, *args) - # g_H1s2 = g(H1s2, t + 0.25 * dt, *args) - all_H0s2, all_H1s2 = [], [] - for var in self.variables: - self.code_lines.append(f' {var}_H0s2 = {var} + {constants.DT} * {var}_f_H0s1') - all_H0s2.append(f'{var}_H0s2') - self.code_lines.append(f' {var}_H1s2 = {var} + {constants.DT} * 0.25 * {var}_f_H0s1 - ' - f'dt_sqrt * 0.5 * {var}_g_H1s1') - all_H1s2.append(f'{var}_H1s2') - all_H0s2.append(f't + {constants.DT}') # t - all_H1s2.append(f't + 0.25 * {constants.DT}') # t - f_names = [f'{var}_f_H0s2' for var in self.variables] - self.code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s2 + self.parameters[1:])})') - g_names = [f'{var}_g_H1s2' for var in self.variables] - self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + self.parameters[1:])})') - self.code_lines.append(' ') - - # 2.4 state 3 - # --- - # H0s3 = x + dt * (0.25 * f_H0s1 + 0.25 * f_H0s2) + (g_H1s1 + 0.5 * g_H1s2) * I10 / dt - # H1s3 = x + dt * f_H0s1 + dt_sqrt * g_H1s1 - # f_H0s3 = g(H0s3, t + 0.5 * dt, *args) - # g_H1s3 = g(H1s3, t + dt, *args) - all_H0s3, all_H1s3 = [], [] - for var in self.variables: - self.code_lines.append(f' {var}_H0s3 = {var} + {constants.DT} * (0.25 * {var}_f_H0s1 + 0.25 
* {var}_f_H0s2) + ' - f'({var}_g_H1s1 + 0.5 * {var}_g_H1s2) * {var}_I10 / {constants.DT}') - all_H0s3.append(f'{var}_H0s3') - self.code_lines.append(f' {var}_H1s3 = {var} + {constants.DT} * {var}_f_H0s1 + dt_sqrt * {var}_g_H1s1') - all_H1s3.append(f'{var}_H1s3') - all_H0s3.append(f't + 0.5 * {constants.DT}') # t - all_H1s3.append(f't + {constants.DT}') # t - f_names = [f'{var}_f_H0s3' for var in self.variables] - g_names = [f'{var}_g_H1s3' for var in self.variables] - self.code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s3 + self.parameters[1:])})') - self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s3 + self.parameters[1:])})') - self.code_lines.append(' ') - - # 2.5 state 4 - # ---- - # H1s4 = x + dt * 0.25 * f_H0s3 + dt_sqrt * (2 * g_H1s1 - g_H1s2 + 0.5 * g_H1s3) - # g_H1s4 = g(H1s4, t + 0.25 * dt, *args) - all_H1s4 = [] - for var in self.variables: - self.code_lines.append(f' {var}_H1s4 = {var} + 0.25 * {constants.DT} * {var}_f_H0s1 + dt_sqrt * ' - f'(2 * {var}_g_H1s1 - {var}_g_H1s2 + 0.5 * {var}_g_H1s3)') - all_H1s4.append(f'{var}_H1s4') - all_H1s4.append(f't + 0.25 * {constants.DT}') # t - g_names = [f'{var}_g_H1s4' for var in self.variables] - self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s4 + self.parameters[1:])})') - self.code_lines.append(' ') - - # 2.6 final stage - # ---- - # f1 = f_H0s1 / 6 + f_H0s2 / 6 + f_H0s3 * 2 / 3 - # g1 = - I1 + I11 / dt_sqrt + 2 * I10 / dt - 2 * I111 / dt - # g2 = I1 * 4 / 3 - I11 / dt_sqrt * 4 / 3 - I10 / dt * 4 / 3 + I111 / dt * 5 / 3 - # g3 = I1 * 2 / 3 + I11 / dt_sqrt / 3 - I10 / dt * 2 / 3 - I111 / dt * 2 / 3 - # g4 = I111 / dt - # y1 = x + dt * f1 + g1 * g_H1s1 + g2 * g_H1s2 + g3 * g_H1s3 + g4 * g_H1s4 - for var in self.variables: - self.code_lines.append(f' {var}_f1 = {var}_f_H0s1/6 + {var}_f_H0s2/6 + {var}_f_H0s3*2/3') - self.code_lines.append( - f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + 2 * {var}_I10/{constants.DT} - 2 * {var}_I111/{constants.DT}') - 
self.code_lines.append(f' {var}_g2 = {var}_I1 * 4/3 - {var}_I11 / dt_sqrt * 4/3 - ' - f'{var}_I10 / {constants.DT} * 4/3 + {var}_I111 / {constants.DT} * 5/3') - self.code_lines.append(f' {var}_g3 = {var}_I1 * 2/3 + {var}_I11/dt_sqrt/3 - ' - f'{var}_I10 / {constants.DT} * 2/3 - {var}_I111 / {constants.DT} * 2/3') - self.code_lines.append(f' {var}_g4 = {var}_I111 / {constants.DT}') - self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f1 + {var}_g1 * {var}_g_H1s1 + ' - f'{var}_g2 * {var}_g_H1s2 + {var}_g3 * {var}_g_H1s3 + {var}_g4 * {var}_g_H1s4') - self.code_lines.append(' ') - - # returns - new_vars = [f'{var}_new' for var in self.variables] - self.code_lines.append(f' return {", ".join(new_vars)}') - - # return and compile - self.integral = utils.compile_code( - code_scope={k: v for k, v in self.code_scope.items()}, - code_lines=self.code_lines, - show_code=self.show_code, - func_name=self.func_name) + r"""Order 1.5 Strong SRK Methods for SDEs with Scalar Noise. + + This method has have strong orders :math:`(p_d, p_s) = (3.0,1.5)`. + + The Butcher table is: + + .. math:: + + \begin{array}{c|cccc|cccc|ccc|} + 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ + 1 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ + 1 / 2 & 1 / 4 & 1 / 4 & 0 & 0 & 1 & 1 / 2 & 0 & 0 & & & & \\ + 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ + \hline 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & & & & \\ + 1 / 4 & 1 / 4 & 0 & 0 & 0 & -1 / 2 & 0 & 0 & 0 & & & & \\ + 1 & 1 & 0 & 0 & 0 & 1 & 0 & 0 & 0 & & & & \\ + 1 / 4 & 0 & 0 & 1 / 4 & 0 & 2 & -1 & 1 / 2 & 0 & & & & \\ + \hline & 1 / 6 & 1 / 6 & 2 / 3 & 0 & -1 & 4 / 3 & 2 / 3 & 0 & -1 & -4 / 3 & 1 / 3 & 0 \\ + \hline & & & & &2 & -4 / 3 & -2 / 3 & 0 & -2 & 5 / 3 & -2 / 3 & 1 + \end{array} + + + References:: + + .. [1] Rößler, Andreas. "Strong and weak approximation methods for stochastic differential + equations—some recent developments." Recent developments in applied probability and + statistics. Physica-Verlag HD, 2010. 127-153. + .. 
[2] Rößler, Andreas. "Runge–Kutta methods for the strong approximation of solutions of + stochastic differential equations." SIAM Journal on Numerical Analysis 48.3 + (2010): 922-952. + """ + + def __init__(self, f, g, dt=None, name=None, show_code=False, + var_type=None, intg_type=None, wiener_type=None, state_delays=None): + super(SRK2W1, self).__init__(f=f, g=g, dt=dt, show_code=show_code, name=name, + var_type=var_type, intg_type=intg_type, + wiener_type=wiener_type, state_delays=state_delays) + assert self.wiener_type == constants.SCALAR_WIENER + self.build() + + def build(self): + self.code_lines.append(f' {constants.DT}_sqrt = {constants.DT} ** 0.5') + + # 2.1 noise + _noise_terms(self.code_lines, self.variables, triple_integral=True) + + # 2.2 stage 1 + _state1(self.code_lines, self.variables, self.parameters) + + # 2.3 stage 2 + # ---- + # H0s2 = x + dt * f_H0s1 + # H1s2 = x + dt * 0.25 * f_H0s1 - dt_sqrt * 0.5 * g_H1s1 + # f_H0s2 = f(H0s2, t + dt, *args) + # g_H1s2 = g(H1s2, t + 0.25 * dt, *args) + all_H0s2, all_H1s2 = [], [] + for var in self.variables: + self.code_lines.append(f' {var}_H0s2 = {var} + {constants.DT} * {var}_f_H0s1') + all_H0s2.append(f'{var}_H0s2') + self.code_lines.append(f' {var}_H1s2 = {var} + {constants.DT} * 0.25 * {var}_f_H0s1 - ' + f'dt_sqrt * 0.5 * {var}_g_H1s1') + all_H1s2.append(f'{var}_H1s2') + all_H0s2.append(f't + {constants.DT}') # t + all_H1s2.append(f't + 0.25 * {constants.DT}') # t + f_names = [f'{var}_f_H0s2' for var in self.variables] + self.code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s2 + self.parameters[1:])})') + g_names = [f'{var}_g_H1s2' for var in self.variables] + self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + self.parameters[1:])})') + self.code_lines.append(' ') + + # 2.4 state 3 + # --- + # H0s3 = x + dt * (0.25 * f_H0s1 + 0.25 * f_H0s2) + (g_H1s1 + 0.5 * g_H1s2) * I10 / dt + # H1s3 = x + dt * f_H0s1 + dt_sqrt * g_H1s1 + # f_H0s3 = g(H0s3, t + 0.5 * dt, *args) + # 
g_H1s3 = g(H1s3, t + dt, *args) + all_H0s3, all_H1s3 = [], [] + for var in self.variables: + self.code_lines.append( + f' {var}_H0s3 = {var} + {constants.DT} * (0.25 * {var}_f_H0s1 + 0.25 * {var}_f_H0s2) + ' + f'({var}_g_H1s1 + 0.5 * {var}_g_H1s2) * {var}_I10 / {constants.DT}') + all_H0s3.append(f'{var}_H0s3') + self.code_lines.append(f' {var}_H1s3 = {var} + {constants.DT} * {var}_f_H0s1 + dt_sqrt * {var}_g_H1s1') + all_H1s3.append(f'{var}_H1s3') + all_H0s3.append(f't + 0.5 * {constants.DT}') # t + all_H1s3.append(f't + {constants.DT}') # t + f_names = [f'{var}_f_H0s3' for var in self.variables] + g_names = [f'{var}_g_H1s3' for var in self.variables] + self.code_lines.append(f' {", ".join(f_names)} = f({", ".join(all_H0s3 + self.parameters[1:])})') + self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s3 + self.parameters[1:])})') + self.code_lines.append(' ') + + # 2.5 state 4 + # ---- + # H1s4 = x + dt * 0.25 * f_H0s3 + dt_sqrt * (2 * g_H1s1 - g_H1s2 + 0.5 * g_H1s3) + # g_H1s4 = g(H1s4, t + 0.25 * dt, *args) + all_H1s4 = [] + for var in self.variables: + self.code_lines.append(f' {var}_H1s4 = {var} + 0.25 * {constants.DT} * {var}_f_H0s1 + dt_sqrt * ' + f'(2 * {var}_g_H1s1 - {var}_g_H1s2 + 0.5 * {var}_g_H1s3)') + all_H1s4.append(f'{var}_H1s4') + all_H1s4.append(f't + 0.25 * {constants.DT}') # t + g_names = [f'{var}_g_H1s4' for var in self.variables] + self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s4 + self.parameters[1:])})') + self.code_lines.append(' ') + + # 2.6 final stage + # ---- + # f1 = f_H0s1 / 6 + f_H0s2 / 6 + f_H0s3 * 2 / 3 + # g1 = - I1 + I11 / dt_sqrt + 2 * I10 / dt - 2 * I111 / dt + # g2 = I1 * 4 / 3 - I11 / dt_sqrt * 4 / 3 - I10 / dt * 4 / 3 + I111 / dt * 5 / 3 + # g3 = I1 * 2 / 3 + I11 / dt_sqrt / 3 - I10 / dt * 2 / 3 - I111 / dt * 2 / 3 + # g4 = I111 / dt + # y1 = x + dt * f1 + g1 * g_H1s1 + g2 * g_H1s2 + g3 * g_H1s3 + g4 * g_H1s4 + for var in self.variables: + self.code_lines.append(f' {var}_f1 = 
{var}_f_H0s1/6 + {var}_f_H0s2/6 + {var}_f_H0s3*2/3') + self.code_lines.append( + f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + 2 * {var}_I10/{constants.DT} - 2 * {var}_I111/{constants.DT}') + self.code_lines.append(f' {var}_g2 = {var}_I1 * 4/3 - {var}_I11 / dt_sqrt * 4/3 - ' + f'{var}_I10 / {constants.DT} * 4/3 + {var}_I111 / {constants.DT} * 5/3') + self.code_lines.append(f' {var}_g3 = {var}_I1 * 2/3 + {var}_I11/dt_sqrt/3 - ' + f'{var}_I10 / {constants.DT} * 2/3 - {var}_I111 / {constants.DT} * 2/3') + self.code_lines.append(f' {var}_g4 = {var}_I111 / {constants.DT}') + self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f1 + {var}_g1 * {var}_g_H1s1 + ' + f'{var}_g2 * {var}_g_H1s2 + {var}_g3 * {var}_g_H1s3 + {var}_g4 * {var}_g_H1s4') + self.code_lines.append(' ') + + # returns + new_vars = [f'{var}_new' for var in self.variables] + self.code_lines.append(f' return {", ".join(new_vars)}') + + # return and compile + self.integral = utils.compile_code( + code_scope={k: v for k, v in self.code_scope.items()}, + code_lines=self.code_lines, + show_code=self.show_code, + func_name=self.func_name) register_sde_integrator('srk2w1', SRK2W1) class KlPl(SDEIntegrator): - def __init__(self, f, g, dt=None, name=None, show_code=False, - var_type=None, intg_type=None, wiener_type=None, state_delays=None): - super(KlPl, self).__init__(f=f, g=g, dt=dt, show_code=show_code, name=name, - var_type=var_type, intg_type=intg_type, - wiener_type=wiener_type, state_delays=state_delays) - assert self.wiener_type == constants.SCALAR_WIENER - self.build() - - def build(self): - self.code_lines.append(f' {constants.DT}_sqrt = {constants.DT} ** 0.5') - - # 2.1 noise - _noise_terms(self.code_lines, self.variables, triple_integral=False) - - # 2.2 stage 1 - _state1(self.code_lines, self.variables, self.parameters) - - # 2.3 stage 2 - # ---- - # H1s2 = x + dt * f_H0s1 + dt_sqrt * g_H1s1 - # g_H1s2 = g(H1s2, t0, *args) - all_H1s2 = [] - for var in self.variables: - 
self.code_lines.append(f' {var}_H1s2 = {var} + {constants.DT} * {var}_f_H0s1 + dt_sqrt * {var}_g_H1s1') - all_H1s2.append(f'{var}_H1s2') - g_names = [f'{var}_g_H1s2' for var in self.variables] - self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + self.parameters)})') - self.code_lines.append(' ') - - # 2.4 final stage - # ---- - # g1 = (I1 - I11 / dt_sqrt + I10 / dt) - # g2 = I11 / dt_sqrt - # y1 = x + dt * f_H0s1 + g1 * g_H1s1 + g2 * g_H1s2 - for var in self.variables: - self.code_lines.append(f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + {var}_I10/{constants.DT}') - self.code_lines.append(f' {var}_g2 = {var}_I11 / dt_sqrt') - self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f_H0s1 + ' - f'{var}_g1 * {var}_g_H1s1 + {var}_g2 * {var}_g_H1s2') - self.code_lines.append(' ') - - # returns - new_vars = [f'{var}_new' for var in self.variables] - self.code_lines.append(f' return {", ".join(new_vars)}') - - # return and compile - self.integral = utils.compile_code( - code_scope={k: v for k, v in self.code_scope.items()}, - code_lines=self.code_lines, - show_code=self.show_code, - func_name=self.func_name) + def __init__(self, f, g, dt=None, name=None, show_code=False, + var_type=None, intg_type=None, wiener_type=None, state_delays=None): + super(KlPl, self).__init__(f=f, g=g, dt=dt, show_code=show_code, name=name, + var_type=var_type, intg_type=intg_type, + wiener_type=wiener_type, state_delays=state_delays) + assert self.wiener_type == constants.SCALAR_WIENER + self.build() + + def build(self): + self.code_lines.append(f' {constants.DT}_sqrt = {constants.DT} ** 0.5') + + # 2.1 noise + _noise_terms(self.code_lines, self.variables, triple_integral=False) + + # 2.2 stage 1 + _state1(self.code_lines, self.variables, self.parameters) + + # 2.3 stage 2 + # ---- + # H1s2 = x + dt * f_H0s1 + dt_sqrt * g_H1s1 + # g_H1s2 = g(H1s2, t0, *args) + all_H1s2 = [] + for var in self.variables: + self.code_lines.append(f' {var}_H1s2 = {var} + 
{constants.DT} * {var}_f_H0s1 + dt_sqrt * {var}_g_H1s1') + all_H1s2.append(f'{var}_H1s2') + g_names = [f'{var}_g_H1s2' for var in self.variables] + self.code_lines.append(f' {", ".join(g_names)} = g({", ".join(all_H1s2 + self.parameters)})') + self.code_lines.append(' ') + + # 2.4 final stage + # ---- + # g1 = (I1 - I11 / dt_sqrt + I10 / dt) + # g2 = I11 / dt_sqrt + # y1 = x + dt * f_H0s1 + g1 * g_H1s1 + g2 * g_H1s2 + for var in self.variables: + self.code_lines.append(f' {var}_g1 = -{var}_I1 + {var}_I11/dt_sqrt + {var}_I10/{constants.DT}') + self.code_lines.append(f' {var}_g2 = {var}_I11 / dt_sqrt') + self.code_lines.append(f' {var}_new = {var} + {constants.DT} * {var}_f_H0s1 + ' + f'{var}_g1 * {var}_g_H1s1 + {var}_g2 * {var}_g_H1s2') + self.code_lines.append(' ') + + # returns + new_vars = [f'{var}_new' for var in self.variables] + self.code_lines.append(f' return {", ".join(new_vars)}') + + # return and compile + self.integral = utils.compile_code( + code_scope={k: v for k, v in self.code_scope.items()}, + code_lines=self.code_lines, + show_code=self.show_code, + func_name=self.func_name) register_sde_integrator('klpl', KlPl) diff --git a/brainpy/_src/integrators/sde/srk_strong.py b/brainpy/_src/integrators/sde/srk_strong.py index 5d798db90..f75f4e284 100644 --- a/brainpy/_src/integrators/sde/srk_strong.py +++ b/brainpy/_src/integrators/sde/srk_strong.py @@ -4,39 +4,38 @@ from brainpy._src.integrators import constants, utils __all__ = [ - 'srk1_strong', + 'srk1_strong', ] - _SDE_UNKNOWN_NO = 0 def basic_info(f, g): - vdt = 'dt' - if f.__name__.isidentifier(): - func_name = f.__name__ - elif g.__name__.isidentifier(): - func_name = g.__name__ - else: - global _SDE_UNKNOWN_NO - func_name = f'unknown_sde{_SDE_UNKNOWN_NO}' - func_new_name = constants.SDE_INT + func_name - variables, parameters, arguments = utils.get_args(f) - return vdt, variables, parameters, arguments, func_new_name + vdt = 'dt' + if f.__name__.isidentifier(): + func_name = f.__name__ + elif 
g.__name__.isidentifier(): + func_name = g.__name__ + else: + global _SDE_UNKNOWN_NO + func_name = f'unknown_sde{_SDE_UNKNOWN_NO}' + func_new_name = constants.SDE_INT + func_name + variables, parameters, arguments = utils.get_args(f) + return vdt, variables, parameters, arguments, func_new_name def _vector_wiener_terms(code_lines, sde_type, vdt, shape_D, shape_m): - if sde_type == constants.ITO_SDE: - I2 = f'0.5*(_term3 - {vdt} * math.eye({shape_m})) + _a*0.5*{vdt}/math.pi' - elif sde_type == constants.STRA_SDE: - I2 = f'0.5*_term3 + _a*0.5*dt/math.pi' - else: - raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') + if sde_type == constants.ITO_SDE: + I2 = f'0.5*(_term3 - {vdt} * math.eye({shape_m})) + _a*0.5*{vdt}/math.pi' + elif sde_type == constants.STRA_SDE: + I2 = f'0.5*_term3 + _a*0.5*dt/math.pi' + else: + raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') - if shape_D: - shape_D = shape_D + '+' + if shape_D: + shape_D = shape_D + '+' - noise_string = f''' + noise_string = f''' # Noise Terms # # ----------- # @@ -56,8 +55,8 @@ def _vector_wiener_terms(code_lines, sde_type, vdt, shape_D, shape_m): _term3 = math.matmul(_I1_rs, math.reshape(_I1, {shape_D}(1, {shape_m}))) _I2 = {I2} ''' - noise_lines = noise_string.split('\n') - code_lines.extend(noise_lines) + noise_lines = noise_string.split('\n') + code_lines.extend(noise_lines) # ---------- @@ -66,89 +65,89 @@ def _vector_wiener_terms(code_lines, sde_type, vdt, shape_D, shape_m): def _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt): - # shape information - # ----- - all_f = [f'f_{var}' for var in variables] - all_g = [f'g_{var}' for var in variables] - noise_string = f''' + # shape information + # ----- + all_f = [f'f_{var}' for var in variables] + all_g = [f'g_{var}' for var in variables] + noise_string = f''' {", ".join(all_f)} = f({", ".join(variables + parameters)}) # 
shape = (..) {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (.., m) noise_shape = math.shape(g_x1) _D = noise_shape[:-1] _m = noise_shape[-1] ''' - code_lines.extend(noise_string.split("\n")) - - # noise terms - _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m') - - # numerical integration - # step 1 - # --- - # g_x1_rs = math.reshape(g_x1, _D + (1, _m)) - # g_x2_rs = math.reshape(g_x2, _D + (1, _m)) - for var in variables: - code_lines.append(f" g_{var}_rs = math.reshape(g_{var}, _D+(1, _m))") - # step 2 - # --- - # g_H1_x1 = math.reshape(math.matmul(g_x1_rs, _I2) / dt_sqrt, _D + (_m,)) - # g_H1_x2 = math.reshape(math.matmul(g_x2_rs, _I2) / dt_sqrt, _D + (_m,)) - for var in variables: - code_lines.append(f' g_H1_{var} = math.reshape(math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt, _D + (_m,))') - # step 3 - # --- - # x1_rs = math.reshape(x1, _D + (1,)) - # x2_rs = math.reshape(x2, _D + (1,)) - for var in variables: - code_lines.append(f' {var}_rs = math.reshape({var}, _D + (1,))') - # step 4 - # --- - # H2_x1 = x1_rs + g_H1_x1 - # H3_x1 = x1_rs - g_H1_x1 - for var in variables: - code_lines.append(f' H2_{var} = {var}_rs + g_H1_{var}') - code_lines.append(f' H3_{var} = {var}_rs - g_H1_{var}') - code_lines.append(' ') - # step 5 - # --- - # _g_x1 = math.matmul(g_x1_rs, _I1_rs) - for var in variables: - code_lines.append(f' _g_{var} = math.matmul(g_{var}_rs, _I1_rs)') - # step 6 - # ---- - # x1_new = x1 + f_x1 + _g_x1[..., 0, 0] - for var in variables: - code_lines.append(f' {var}_new = {var} + f_{var} + _g_{var}[..., 0, 0]') - # for _k in range(_m): - code_lines.append('for _k in range(_m):') - # g_x1_H2, g_x2_H2 = g(H2_x1[..., _k], H2_x2[..., _k], t, *args) - all_H2 = [f'H2_{var}[..., _k]' for var in variables] - all_g_H2 = [f'g_{var}_H2' for var in variables] - code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})') - # g_x1_H3, g_x2_H3 = g(H3_x1[..., _k], H3_x2[..., _k], t, *args) - all_H3 = 
[f'H3_{var}[..., _k]' for var in variables] - all_g_H3 = [f'g_{var}_H3' for var in variables] - code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})') - # x1_new += 0.5 * dt_sqrt * (g_x1_H2[..., _k] - g_x1_H3[..., _k]) - # x2_new += 0.5 * dt_sqrt * (g_x2_H2[..., _k] - g_x2_H3[..., _k]) - for var in variables: - code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[..., _k] - g_{var}_H3[..., _k])') + code_lines.extend(noise_string.split("\n")) + + # noise terms + _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m') + + # numerical integration + # step 1 + # --- + # g_x1_rs = math.reshape(g_x1, _D + (1, _m)) + # g_x2_rs = math.reshape(g_x2, _D + (1, _m)) + for var in variables: + code_lines.append(f" g_{var}_rs = math.reshape(g_{var}, _D+(1, _m))") + # step 2 + # --- + # g_H1_x1 = math.reshape(math.matmul(g_x1_rs, _I2) / dt_sqrt, _D + (_m,)) + # g_H1_x2 = math.reshape(math.matmul(g_x2_rs, _I2) / dt_sqrt, _D + (_m,)) + for var in variables: + code_lines.append(f' g_H1_{var} = math.reshape(math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt, _D + (_m,))') + # step 3 + # --- + # x1_rs = math.reshape(x1, _D + (1,)) + # x2_rs = math.reshape(x2, _D + (1,)) + for var in variables: + code_lines.append(f' {var}_rs = math.reshape({var}, _D + (1,))') + # step 4 + # --- + # H2_x1 = x1_rs + g_H1_x1 + # H3_x1 = x1_rs - g_H1_x1 + for var in variables: + code_lines.append(f' H2_{var} = {var}_rs + g_H1_{var}') + code_lines.append(f' H3_{var} = {var}_rs - g_H1_{var}') + code_lines.append(' ') + # step 5 + # --- + # _g_x1 = math.matmul(g_x1_rs, _I1_rs) + for var in variables: + code_lines.append(f' _g_{var} = math.matmul(g_{var}_rs, _I1_rs)') + # step 6 + # ---- + # x1_new = x1 + f_x1 + _g_x1[..., 0, 0] + for var in variables: + code_lines.append(f' {var}_new = {var} + f_{var} + _g_{var}[..., 0, 0]') + # for _k in range(_m): + code_lines.append('for _k in range(_m):') + # g_x1_H2, g_x2_H2 = g(H2_x1[..., _k], H2_x2[..., _k], t, *args) 
+ all_H2 = [f'H2_{var}[..., _k]' for var in variables] + all_g_H2 = [f'g_{var}_H2' for var in variables] + code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})') + # g_x1_H3, g_x2_H3 = g(H3_x1[..., _k], H3_x2[..., _k], t, *args) + all_H3 = [f'H3_{var}[..., _k]' for var in variables] + all_g_H3 = [f'g_{var}_H3' for var in variables] + code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})') + # x1_new += 0.5 * dt_sqrt * (g_x1_H2[..., _k] - g_x1_H3[..., _k]) + # x2_new += 0.5 * dt_sqrt * (g_x2_H2[..., _k] - g_x2_H3[..., _k]) + for var in variables: + code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[..., _k] - g_{var}_H3[..., _k])') def _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt): - if sde_type == constants.ITO_SDE: - I2 = f'0.5 * (_I1 * _I1 - {vdt})' - elif sde_type == constants.STRA_SDE: - I2 = f'0.5 * _I1 * _I1' - else: - raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') - - # shape info - # ----- - all_f = [f'f_{var}' for var in variables] - all_g = [f'g_{var}' for var in variables] - - code_string = f''' + if sde_type == constants.ITO_SDE: + I2 = f'0.5 * (_I1 * _I1 - {vdt})' + elif sde_type == constants.STRA_SDE: + I2 = f'0.5 * _I1 * _I1' + else: + raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') + + # shape info + # ----- + all_f = [f'f_{var}' for var in variables] + all_g = [f'g_{var}' for var in variables] + + code_string = f''' {", ".join(all_f)} = f({", ".join(variables + parameters)}) # shape = (..) {", ".join(all_g)} = g({", ".join(variables + parameters)}) # shape = (..) @@ -157,39 +156,39 @@ def _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, param # double Ito integrals _I2 = {I2} # shape = (..) 
''' - code_splits = code_string.split('\n') - code_lines.extend(code_splits) - - # numerical integration - # ----- - # H1 - for var in variables: - code_lines.append(f' g_H1_{var} = g_{var} * _I2 / {vdt}_sqrt # shape (.., )') - # H2 - all_H2 = [f'H2_{var}' for var in variables] - for var in variables: - code_lines.append(f' H2_{var} = {var} + g_H1_{var} # shape (.., )') - all_g_H2 = [f'g_{var}_H2' for var in variables] - code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})') - code_lines.append(f' ') - # H3 - all_H3 = [f'H3_{var}' for var in variables] - for var in variables: - code_lines.append(f' H3_{var} = {var} - g_H1_{var} # shape (.., )') - all_g_H3 = [f'g_{var}_H3' for var in variables] - code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})') - code_lines.append(f' ') - # final results - for var in variables: - code_lines.append(f' {var}_new = {var} + f_{var} + g_{var} * _I1 ' - f'+ 0.5 * {vdt}_sqrt * (g_{var}_H2 - g_{var}_H3)') + code_splits = code_string.split('\n') + code_lines.extend(code_splits) + + # numerical integration + # ----- + # H1 + for var in variables: + code_lines.append(f' g_H1_{var} = g_{var} * _I2 / {vdt}_sqrt # shape (.., )') + # H2 + all_H2 = [f'H2_{var}' for var in variables] + for var in variables: + code_lines.append(f' H2_{var} = {var} + g_H1_{var} # shape (.., )') + all_g_H2 = [f'g_{var}_H2' for var in variables] + code_lines.append(f' {", ".join(all_g_H2)} = g({", ".join(all_H2 + parameters)})') + code_lines.append(f' ') + # H3 + all_H3 = [f'H3_{var}' for var in variables] + for var in variables: + code_lines.append(f' H3_{var} = {var} - g_H1_{var} # shape (.., )') + all_g_H3 = [f'g_{var}_H3' for var in variables] + code_lines.append(f' {", ".join(all_g_H3)} = g({", ".join(all_H3 + parameters)})') + code_lines.append(f' ') + # final results + for var in variables: + code_lines.append(f' {var}_new = {var} + f_{var} + g_{var} * _I1 ' + f'+ 0.5 * {vdt}_sqrt * (g_{var}_H2 - 
g_{var}_H3)') def _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt): - # shape information - all_f = [f'f_{var}' for var in variables] - all_g = [f'g_{var}' for var in variables] - code1 = f''' + # shape information + all_f = [f'f_{var}' for var in variables] + all_g = [f'g_{var}' for var in variables] + code1 = f''' # shape info # # ---------- # @@ -198,64 +197,64 @@ def _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, paramet noise_shape = math.shape(g_x1) _m = noise_shape[0] ''' - code_lines.extend(code1.split('\n')) - - # noise term - _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='', shape_m='_m') - - # numerical integration - - # p1 - # --- - # g_x1_rs = math.reshape(g_x1, (1, _m)) - # g_x2_rs = math.reshape(g_x2, (1, _m)) - for var in variables: - code_lines.append(f' g_{var}_rs = math.reshape(g_{var}, (1, _m))') - - # p2 - # --- - # g_H1_x1 = math.matmul(g_x1_rs, _I2) / dt_sqrt # shape (1, m) - # g_H1_x2 = math.matmul(g_x2_rs, _I2) / dt_sqrt # shape (1, m) - for var in variables: - code_lines.append(f' g_H1_{var} = math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt # shape (1, m)') - - # p3 - # --- - # H2_x1 = x1 + g_H1_x1[0] # shape (m) - # H3_x1 = x1 - g_H1_x1[0] # shape (m) - for var in variables: - code_lines.append(f' H2_{var} = {var} + g_H1_{var}[0] # shape (m)') - code_lines.append(' ') - - # p4 - # --- - # g1_x1 = math.matmul(g_x1_rs, _I1_rs) # shape (1, 1) - # x1_new = x1 + f_x1 + g1_x1[0, 0] # shape () - for var in variables: - code_lines.append(f' g1_{var} = math.matmul(g_{var}_rs, _I1_rs) # shape (1, 1)') - code_lines.append(f' {var}_new = {var} + f_{var} + g1_{var}[0, 0] # shape ()') - - # p5 - # --- - # for _k in range(_m): - # g_x1_H2, g_x2_H2 = g(H2_x1[_k], H2_x2[_k], t, *args) - # g_x1_H3, g_x2_H3 = g(H3_x1[_k], H3_x2[_k], t, *args) - # x1_new += 0.5 * dt_sqrt * (g_x1_H2[_k] - g_x1_H3[_k]) - # x2_new += 0.5 * dt_sqrt * (g_x2_H2[_k] - g_x2_H3[_k]) - code_lines.append(' for _k in 
range(_m):') - all_h2_k = [f'H2_{var}[_k]' for var in variables] - all_g_h2 = [f'g_{var}_H2' for var in variables] - code_lines.append(f' {", ".join(all_g_h2)} = g({", ".join(all_h2_k + parameters)})') - all_h3_k = [f'H3_{var}[_k]' for var in variables] - all_g_h3 = [f'g_{var}_H3' for var in variables] - code_lines.append(f' {", ".join(all_g_h3)} = g({", ".join(all_h3_k + parameters)})') - for var in variables: - code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[_k] - g_{var}_H3[_k])') + code_lines.extend(code1.split('\n')) + + # noise term + _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='', shape_m='_m') + + # numerical integration + + # p1 + # --- + # g_x1_rs = math.reshape(g_x1, (1, _m)) + # g_x2_rs = math.reshape(g_x2, (1, _m)) + for var in variables: + code_lines.append(f' g_{var}_rs = math.reshape(g_{var}, (1, _m))') + + # p2 + # --- + # g_H1_x1 = math.matmul(g_x1_rs, _I2) / dt_sqrt # shape (1, m) + # g_H1_x2 = math.matmul(g_x2_rs, _I2) / dt_sqrt # shape (1, m) + for var in variables: + code_lines.append(f' g_H1_{var} = math.matmul(g_{var}_rs, _I2) / {vdt}_sqrt # shape (1, m)') + + # p3 + # --- + # H2_x1 = x1 + g_H1_x1[0] # shape (m) + # H3_x1 = x1 - g_H1_x1[0] # shape (m) + for var in variables: + code_lines.append(f' H2_{var} = {var} + g_H1_{var}[0] # shape (m)') + code_lines.append(' ') + + # p4 + # --- + # g1_x1 = math.matmul(g_x1_rs, _I1_rs) # shape (1, 1) + # x1_new = x1 + f_x1 + g1_x1[0, 0] # shape () + for var in variables: + code_lines.append(f' g1_{var} = math.matmul(g_{var}_rs, _I1_rs) # shape (1, 1)') + code_lines.append(f' {var}_new = {var} + f_{var} + g1_{var}[0, 0] # shape ()') + + # p5 + # --- + # for _k in range(_m): + # g_x1_H2, g_x2_H2 = g(H2_x1[_k], H2_x2[_k], t, *args) + # g_x1_H3, g_x2_H3 = g(H3_x1[_k], H3_x2[_k], t, *args) + # x1_new += 0.5 * dt_sqrt * (g_x1_H2[_k] - g_x1_H3[_k]) + # x2_new += 0.5 * dt_sqrt * (g_x2_H2[_k] - g_x2_H3[_k]) + code_lines.append(' for _k in range(_m):') + all_h2_k = [f'H2_{var}[_k]' 
for var in variables] + all_g_h2 = [f'g_{var}_H2' for var in variables] + code_lines.append(f' {", ".join(all_g_h2)} = g({", ".join(all_h2_k + parameters)})') + all_h3_k = [f'H3_{var}[_k]' for var in variables] + all_g_h3 = [f'g_{var}_H3' for var in variables] + code_lines.append(f' {", ".join(all_g_h3)} = g({", ".join(all_h3_k + parameters)})') + for var in variables: + code_lines.append(f' {var}_new += 0.5 * {vdt}_sqrt * (g_{var}_H2[_k] - g_{var}_H3[_k])') def _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt): - # shape information - code1 = f''' + # shape information + code1 = f''' # shape infor # # ----------- # @@ -266,13 +265,13 @@ def _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, paramet _m = _shape[-1] _D = _shape[1:-1] ''' - code_lines.extend(code1.split('\n')) + code_lines.extend(code1.split('\n')) - # noise term - _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m') + # noise term + _vector_wiener_terms(code_lines, sde_type, vdt, shape_D='_D', shape_m='_m') - # numerical integration - code2 = f''' + # numerical integration + code2 = f''' # numerical integration # # --------------------- # @@ -290,18 +289,18 @@ def _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, paramet y += 0.5 * dt_sqrt * g(H2[..., _k], t, *args)[..., _k] y -= 0.5 * dt_sqrt * g(H3[..., _k], t, *args)[..., _k] ''' - code_lines.extend(code2.split('\n')) + code_lines.extend(code2.split('\n')) def _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt): - if sde_type == constants.ITO_SDE: - I2 = f'0.5 * (_I1 * _I1 - {vdt})' - elif sde_type == constants.STRA_SDE: - I2 = f'0.5 * _I1 * _I1' - else: - raise ValueError(f'Unknown SDE_INT type: {sde_type}. 
We only supports {constants.SUPPORTED_INTG_TYPE}.') - - code_string = f''' + if sde_type == constants.ITO_SDE: + I2 = f'0.5 * (_I1 * _I1 - {vdt})' + elif sde_type == constants.STRA_SDE: + I2 = f'0.5 * _I1 * _I1' + else: + raise ValueError(f'Unknown SDE_INT type: {sde_type}. We only supports {constants.SUPPORTED_INTG_TYPE}.') + + code_string = f''' f_x = f({", ".join(variables + parameters)}) # shape = (d, ..) g_x = g({", ".join(variables + parameters)}) # shape = (d, ..) _shape = math.shape(g_x) @@ -324,126 +323,126 @@ def _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, paramet x_new += 0.5 * {vdt}_sqrt * g(H2, {", ".join(parameters)}) x_new -= 0.5 * {vdt}_sqrt * g(H3, {", ".join(parameters)}) ''' - code_splits = code_string.split('\n') - code_lines.extend(code_splits) + code_splits = code_string.split('\n') + code_lines.extend(code_splits) def _srk1_wrapper(f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter): - vdt, variables, parameters, arguments, func_name = basic_info(f=f, g=g) + vdt, variables, parameters, arguments, func_name = basic_info(f=f, g=g) + + # 1. code scope + code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, + 'math': math, 'num_iter': num_iter} + + # 2. code lines + code_lines = [f'def {func_name}({", ".join(arguments)}):'] + + if var_type == constants.SYSTEM_VAR: + if len(variables) > 1: + raise ValueError(f'SDE_INT with {constants.SYSTEM_VAR} variable type only ' + f'supports one system variable. 
But we got {variables}.') + + if wiener_type == constants.SCALAR_WIENER: + _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) + elif wiener_type == constants.VECTOR_WIENER: + _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt) + else: + raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' + f'supports {constants.SUPPORTED_WIENER_TYPE}') + + elif var_type == constants.SCALAR_VAR: + if wiener_type == constants.SCALAR_WIENER: + _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) + elif wiener_type == constants.VECTOR_WIENER: + _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt) + else: + raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' + f'supports {constants.SUPPORTED_WIENER_TYPE}') + + elif var_type == constants.POP_VAR: + if wiener_type == constants.SCALAR_WIENER: + _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) + elif wiener_type == constants.VECTOR_WIENER: + _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt) + else: + raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' + f'supports {constants.SUPPORTED_WIENER_TYPE}') - # 1. code scope - code_scope = {'f': f, 'g': g, vdt: dt, f'{vdt}_sqrt': dt ** 0.5, - 'math': math, 'num_iter': num_iter} - - # 2. code lines - code_lines = [f'def {func_name}({", ".join(arguments)}):'] - - if var_type == constants.SYSTEM_VAR: - if len(variables) > 1: - raise ValueError(f'SDE_INT with {constants.SYSTEM_VAR} variable type only ' - f'supports one system variable. 
But we got {variables}.') - - if wiener_type == constants.SCALAR_WIENER: - _srk1_system_var_with_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) - elif wiener_type == constants.VECTOR_WIENER: - _srk1_system_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt) - else: - raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' - f'supports {constants.SUPPORTED_WIENER_TYPE}') - - elif var_type == constants.SCALAR_VAR: - if wiener_type == constants.SCALAR_WIENER: - _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) - elif wiener_type == constants.VECTOR_WIENER: - _srk1_scalar_var_with_vector_wiener(sde_type, code_lines, variables, parameters, vdt) else: - raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' - f'supports {constants.SUPPORTED_WIENER_TYPE}') - - elif var_type == constants.POP_VAR: - if wiener_type == constants.SCALAR_WIENER: - _srk2_pop_or_scalar_var_scalar_wiener(sde_type, code_lines, variables, parameters, vdt) - elif wiener_type == constants.VECTOR_WIENER: - _srk2_pop_var_vector_wiener(sde_type, code_lines, variables, parameters, vdt) - else: - raise ValueError(f'Unknown Wiener type: {wiener_type}, we only ' - f'supports {constants.SUPPORTED_WIENER_TYPE}') - - else: - raise ValueError(f'Unknown var type: {var_type}, we only ' - f'supports {constants.SUPPORTED_VAR_TYPE}') - # returns - new_vars = [f'{var}_new' for var in variables] - code_lines.append(f' return {", ".join(new_vars)}') + raise ValueError(f'Unknown var type: {var_type}, we only ' + f'supports {constants.SUPPORTED_VAR_TYPE}') + # returns + new_vars = [f'{var}_new' for var in variables] + code_lines.append(f' return {", ".join(new_vars)}') - # return and compile - utils.compile_code(code_lines, code_scope, show_code, variables) - return code_scope[func_name] + # return and compile + utils.compile_code(code_lines, code_scope, show_code, variables) + return code_scope[func_name] def _srk2_wrapper(): - 
pass + pass def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter): - """The brainpy_object function to format a SRK method. - - Parameters:: - - f : callable - The drift function of the SDE_INT. - g : callable - The diffusion function of the SDE_INT. - dt : float - The numerical precision. - sde_type : str - "utils.ITO_SDE" : Ito's Stochastic Calculus. - "utils.STRA_SDE" : Stratonovich's Stochastic Calculus. - wiener_type : str - var_type : str - "scalar" : with the shape of (). - "population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...). - "system": with the shape of (d, ), (d, N), or (d, N1, N2). - show_code : bool - Whether show the formatted code. - - Returns:: - - numerical_func : callable - The numerical function. - """ - - sde_type = constants.ITO_SDE if sde_type is None else sde_type - assert sde_type in constants.SUPPORTED_INTG_TYPE, f'Currently, BrainPy only support SDE_INT types: ' \ - f'{constants.SUPPORTED_INTG_TYPE}. But we got {sde_type}.' - - var_type = constants.POP_VAR if var_type is None else var_type - assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \ - f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.' - - wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type - assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \ - f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \ - f'But we got {wiener_type}.' 
- - show_code = False if show_code is None else show_code - dt = math.get_dt() if dt is None else dt - num_iter = 10 if num_iter is None else num_iter - - if f is not None and g is not None: - return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, - var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) - - elif f is not None: - return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, - var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) - - elif g is not None: - return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, - var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) - - else: - raise ValueError('Must provide "f" or "g".') + """The brainpy_object function to format a SRK method. + + Parameters:: + + f : callable + The drift function of the SDE_INT. + g : callable + The diffusion function of the SDE_INT. + dt : float + The numerical precision. + sde_type : str + "utils.ITO_SDE" : Ito's Stochastic Calculus. + "utils.STRA_SDE" : Stratonovich's Stochastic Calculus. + wiener_type : str + var_type : str + "scalar" : with the shape of (). + "population" : with the shape of (N,) or (N1, N2) or (N1, N2, ...). + "system": with the shape of (d, ), (d, N), or (d, N1, N2). + show_code : bool + Whether show the formatted code. + + Returns:: + + numerical_func : callable + The numerical function. + """ + + sde_type = constants.ITO_SDE if sde_type is None else sde_type + assert sde_type in constants.SUPPORTED_INTG_TYPE, f'Currently, BrainPy only support SDE_INT types: ' \ + f'{constants.SUPPORTED_INTG_TYPE}. But we got {sde_type}.' + + var_type = constants.POP_VAR if var_type is None else var_type + assert var_type in constants.SUPPORTED_VAR_TYPE, f'Currently, BrainPy only supports variable types: ' \ + f'{constants.SUPPORTED_VAR_TYPE}. But we got {var_type}.' 
+ + wiener_type = constants.SCALAR_WIENER if wiener_type is None else wiener_type + assert wiener_type in constants.SUPPORTED_WIENER_TYPE, f'Currently, BrainPy only supports Wiener ' \ + f'Process types: {constants.SUPPORTED_WIENER_TYPE}. ' \ + f'But we got {wiener_type}.' + + show_code = False if show_code is None else show_code + dt = math.get_dt() if dt is None else dt + num_iter = 10 if num_iter is None else num_iter + + if f is not None and g is not None: + return wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, + var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) + + elif f is not None: + return lambda g: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, + var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) + + elif g is not None: + return lambda f: wrapper(f=f, g=g, dt=dt, show_code=show_code, sde_type=sde_type, + var_type=var_type, wiener_type=wiener_type, num_iter=num_iter) + + else: + raise ValueError('Must provide "f" or "g".') # ------------------ @@ -452,10 +451,10 @@ def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code, num_ite def srk1_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None): - return _wrap(_srk1_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type, - wiener_type=wiener_type, show_code=show_code, num_iter=num_iter) + return _wrap(_srk1_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type, + wiener_type=wiener_type, show_code=show_code, num_iter=num_iter) def srk2_strong(f=None, g=None, dt=None, sde_type=None, var_type=None, wiener_type=None, num_iter=None, show_code=None): - return _wrap(_srk2_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type, - wiener_type=wiener_type, show_code=show_code, num_iter=num_iter) + return _wrap(_srk2_wrapper, f=f, g=g, dt=dt, sde_type=sde_type, var_type=var_type, + wiener_type=wiener_type, show_code=show_code, num_iter=num_iter) diff --git 
a/brainpy/_src/integrators/sde/tests/test_normal.py b/brainpy/_src/integrators/sde/tests/test_normal.py index 73a2734b0..4c73c8c28 100644 --- a/brainpy/_src/integrators/sde/tests/test_normal.py +++ b/brainpy/_src/integrators/sde/tests/test_normal.py @@ -3,134 +3,130 @@ import unittest +import matplotlib.pyplot as plt + import brainpy as bp import brainpy.math as bm -import matplotlib.pyplot as plt from brainpy._src.integrators.sde.normal import ExponentialEuler show = False class TestExpEuler(unittest.TestCase): - def test1(self): - p = 0.1 - - def lorenz_g(x, y, z, t, **kwargs): - return p * x, p * y, p * z - - dx = lambda x, t, y, sigma=10: sigma * (y - x) - dy = lambda y, t, x, z, rho=28: x * (rho - z) - y - dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z - - bm.random.seed() - intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]), - g=lorenz_g, - intg_type=bp.integrators.ITO_SDE, - wiener_type=bp.integrators.SCALAR_WIENER, - var_type=bp.integrators.POP_VAR, - show_code=True) - runner = bp.IntegratorRunner(intg, - monitors=['x', 'y', 'z'], - dt=0.001, inits=[1., 1., 0.]) - runner.run(100.) - - plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) - if show: - plt.show() - plt.close() - - - - def test2(self): - p = 0.1 - p2 = 0.02 - - def lorenz_g(x, y, z, t, **kwargs): - return bp.math.asarray([p * x, p2 * x]), \ - bp.math.asarray([p * y, p2 * y]), \ - bp.math.asarray([p * z, p2 * z]) - - dx = lambda x, t, y, sigma=10: sigma * (y - x) - dy = lambda y, t, x, z, rho=28: x * (rho - z) - y - dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z - - bm.random.seed() - intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]), - g=lorenz_g, - intg_type=bp.integrators.ITO_SDE, - wiener_type=bp.integrators.VECTOR_WIENER, - var_type=bp.integrators.POP_VAR, - show_code=True) - runner = bp.IntegratorRunner(intg, monitors=['x', 'y', 'z'], - dt=0.001, inits=[1., 1., 0.], jit=False) - with self.assertRaises(ValueError): - runner.run(100.) 
- - - def test3(self): - p = 0.1 - p2 = 0.02 - - def lorenz_g(x, y, z, t, **kwargs): - return bp.math.asarray([p * x, p2 * x]).T, \ - bp.math.asarray([p * y, p2 * y]).T, \ - bp.math.asarray([p * z, p2 * z]).T - - bm.random.seed() - dx = lambda x, t, y, sigma=10: sigma * (y - x) - dy = lambda y, t, x, z, rho=28: x * (rho - z) - y - dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z - - intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]), - g=lorenz_g, - intg_type=bp.integrators.ITO_SDE, - wiener_type=bp.integrators.VECTOR_WIENER, - var_type=bp.integrators.POP_VAR, - show_code=True) - runner = bp.IntegratorRunner(intg, - monitors=['x', 'y', 'z'], - dt=0.001, - inits=[1., 1., 0.], - jit=True) - runner.run(100.) - - plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) - if show: - plt.show() - plt.close() - + def test1(self): + p = 0.1 + + def lorenz_g(x, y, z, t, **kwargs): + return p * x, p * y, p * z + + dx = lambda x, t, y, sigma=10: sigma * (y - x) + dy = lambda y, t, x, z, rho=28: x * (rho - z) - y + dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z + + bm.random.seed() + intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]), + g=lorenz_g, + intg_type=bp.integrators.ITO_SDE, + wiener_type=bp.integrators.SCALAR_WIENER, + var_type=bp.integrators.POP_VAR, + show_code=True) + runner = bp.IntegratorRunner(intg, + monitors=['x', 'y', 'z'], + dt=0.001, inits=[1., 1., 0.]) + runner.run(100.) 
+ + plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) + if show: + plt.show() + plt.close() + + def test2(self): + p = 0.1 + p2 = 0.02 + + def lorenz_g(x, y, z, t, **kwargs): + return bp.math.asarray([p * x, p2 * x]), \ + bp.math.asarray([p * y, p2 * y]), \ + bp.math.asarray([p * z, p2 * z]) + + dx = lambda x, t, y, sigma=10: sigma * (y - x) + dy = lambda y, t, x, z, rho=28: x * (rho - z) - y + dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z + + bm.random.seed() + intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]), + g=lorenz_g, + intg_type=bp.integrators.ITO_SDE, + wiener_type=bp.integrators.VECTOR_WIENER, + var_type=bp.integrators.POP_VAR, + show_code=True) + runner = bp.IntegratorRunner(intg, monitors=['x', 'y', 'z'], + dt=0.001, inits=[1., 1., 0.], jit=False) + with self.assertRaises(ValueError): + runner.run(100.) + + def test3(self): + p = 0.1 + p2 = 0.02 + + def lorenz_g(x, y, z, t, **kwargs): + return bp.math.asarray([p * x, p2 * x]).T, \ + bp.math.asarray([p * y, p2 * y]).T, \ + bp.math.asarray([p * z, p2 * z]).T + + bm.random.seed() + dx = lambda x, t, y, sigma=10: sigma * (y - x) + dy = lambda y, t, x, z, rho=28: x * (rho - z) - y + dz = lambda z, t, x, y, beta=8 / 3: x * y - beta * z + + intg = ExponentialEuler(f=bp.JointEq([dx, dy, dz]), + g=lorenz_g, + intg_type=bp.integrators.ITO_SDE, + wiener_type=bp.integrators.VECTOR_WIENER, + var_type=bp.integrators.POP_VAR, + show_code=True) + runner = bp.IntegratorRunner(intg, + monitors=['x', 'y', 'z'], + dt=0.001, + inits=[1., 1., 0.], + jit=True) + runner.run(100.) 
+ + plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) + if show: + plt.show() + plt.close() class TestMilstein(unittest.TestCase): - def test1(self): - p = 0.1 - sigma = 10 - rho = 28 - beta = 8 / 3 - - gx = lambda x, t, y: p * x - gy = lambda y, t, x, z: p * y - gz = lambda z, t, x, y: p * z - - fx = lambda x, t, y: sigma * (y - x) - fy = lambda y, t, x, z: x * (rho - z) - y - fz = lambda z, t, x, y: x * y - beta * z - - bm.random.seed() - intg = bp.sdeint(f=bp.JointEq(fx, fy, fz), - g=bp.JointEq(gx, gy, gz), - intg_type=bp.integrators.ITO_SDE, - wiener_type=bp.integrators.SCALAR_WIENER, - var_type=bp.integrators.POP_VAR, - method='milstein') - runner = bp.IntegratorRunner(intg, - monitors=['x', 'y', 'z'], - dt=0.001, inits=[1., 1., 0.], - jit=True) - runner.run(100.) - - plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) - if show: - plt.show() - plt.close() - + def test1(self): + p = 0.1 + sigma = 10 + rho = 28 + beta = 8 / 3 + + gx = lambda x, t, y: p * x + gy = lambda y, t, x, z: p * y + gz = lambda z, t, x, y: p * z + + fx = lambda x, t, y: sigma * (y - x) + fy = lambda y, t, x, z: x * (rho - z) - y + fz = lambda z, t, x, y: x * y - beta * z + + bm.random.seed() + intg = bp.sdeint(f=bp.JointEq(fx, fy, fz), + g=bp.JointEq(gx, gy, gz), + intg_type=bp.integrators.ITO_SDE, + wiener_type=bp.integrators.SCALAR_WIENER, + var_type=bp.integrators.POP_VAR, + method='milstein') + runner = bp.IntegratorRunner(intg, + monitors=['x', 'y', 'z'], + dt=0.001, inits=[1., 1., 0.], + jit=True) + runner.run(100.) 
+ + plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten()) + if show: + plt.show() + plt.close() diff --git a/brainpy/_src/integrators/sde/tests/test_sde_scalar.py b/brainpy/_src/integrators/sde/tests/test_sde_scalar.py index fe582ee6b..3ac83e62b 100644 --- a/brainpy/_src/integrators/sde/tests/test_sde_scalar.py +++ b/brainpy/_src/integrators/sde/tests/test_sde_scalar.py @@ -17,63 +17,63 @@ def lorenz_f(x, y, z, t): - dx = sigma * (y - x) - dy = x * (rho - z) - y - dz = x * y - beta * z - return dx, dy, dz + dx = sigma * (y - x) + dy = x * (rho - z) - y + dz = x * y - beta * z + return dx, dy, dz def lorenz_g(x, y, z, t): - return p * x, p * y, p * z + return p * x, p * y, p * z def lorenz_system(method, **kwargs): - bp.math.random.seed() - integral = bp.math.jit(method(f=lorenz_f, - g=lorenz_g, - show_code=True, - dt=0.005, - **kwargs)) - - times = np.arange(0, 10, 0.01) - mon1 = [] - mon2 = [] - mon3 = [] - x, y, z = 1, 1, 1 - for t in times: - x, y, z = integral(x, y, z, t) - mon1.append(x) - mon2.append(y) - mon3.append(z) - mon1 = bp.math.array(mon1).to_numpy() - mon2 = bp.math.array(mon2).to_numpy() - mon3 = bp.math.array(mon3).to_numpy() - - fig = plt.figure() - ax = fig.add_subplot(111, projection='3d') - plt.plot(mon1, mon2, mon3) - ax.set_xlabel('x') - ax.set_xlabel('y') - ax.set_xlabel('z') - plt.show(block=block) - plt.close(fig) + bp.math.random.seed() + integral = bp.math.jit(method(f=lorenz_f, + g=lorenz_g, + show_code=True, + dt=0.005, + **kwargs)) + + times = np.arange(0, 10, 0.01) + mon1 = [] + mon2 = [] + mon3 = [] + x, y, z = 1, 1, 1 + for t in times: + x, y, z = integral(x, y, z, t) + mon1.append(x) + mon2.append(y) + mon3.append(z) + mon1 = bp.math.array(mon1).to_numpy() + mon2 = bp.math.array(mon2).to_numpy() + mon3 = bp.math.array(mon3).to_numpy() + + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + plt.plot(mon1, mon2, mon3) + ax.set_xlabel('x') + ax.set_xlabel('y') + ax.set_xlabel('z') + plt.show(block=block) + 
plt.close(fig) class TestScalarWienerIntegral(unittest.TestCase): - def test_srk1w1_try1(self): - lorenz_system(sde.SRK1W1) + def test_srk1w1_try1(self): + lorenz_system(sde.SRK1W1) - def test_srk1w1_try2(self): - with pytest.raises(AssertionError): - lorenz_system(sde.SRK1W1, wiener_type=bp.integrators.VECTOR_WIENER) + def test_srk1w1_try2(self): + with pytest.raises(AssertionError): + lorenz_system(sde.SRK1W1, wiener_type=bp.integrators.VECTOR_WIENER) - def test_srk2w1(self): - lorenz_system(sde.SRK2W1) + def test_srk2w1(self): + lorenz_system(sde.SRK2W1) - def test_euler(self): - lorenz_system(sde.Euler, intg_type=bp.integrators.ITO_SDE) - lorenz_system(sde.Euler, intg_type=bp.integrators.STRA_SDE) + def test_euler(self): + lorenz_system(sde.Euler, intg_type=bp.integrators.ITO_SDE) + lorenz_system(sde.Euler, intg_type=bp.integrators.STRA_SDE) - def test_milstein(self): - lorenz_system(sde.MilsteinGradFree, intg_type=bp.integrators.ITO_SDE) - lorenz_system(sde.MilsteinGradFree, intg_type=bp.integrators.STRA_SDE) + def test_milstein(self): + lorenz_system(sde.MilsteinGradFree, intg_type=bp.integrators.ITO_SDE) + lorenz_system(sde.MilsteinGradFree, intg_type=bp.integrators.STRA_SDE) diff --git a/brainpy/_src/integrators/tests/test_integ_runner.py b/brainpy/_src/integrators/tests/test_integ_runner.py index 98c71241c..134199379 100644 --- a/brainpy/_src/integrators/tests/test_integ_runner.py +++ b/brainpy/_src/integrators/tests/test_integ_runner.py @@ -5,105 +5,103 @@ import matplotlib.pyplot as plt import brainpy as bp -import brainpy.math as bm show = False + class TestIntegratorRunnerForODEs(TestCase): - def test_ode(self): - - sigma = 10 - beta = 8 / 3 - rho = 28 - - @bp.odeint(method='rk4', dt=0.001) - def lorenz(x, y, z, t): - dx = sigma * (y - x) - dy = x * (rho - z) - y - dz = x * y - beta * z - return dx, dy, dz - - runner = bp.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.]) - runner.run(100.) 
- fig = plt.figure() - fig.add_subplot(111, projection='3d') - plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0], ) - if show:plt.show() - - runner = bp.IntegratorRunner(lorenz, - monitors=['x', 'y', 'z'], - inits=[1., (1., 0.), (1., 0.)]) - runner.run(100.) - for i in range(2): - fig = plt.figure() - fig.add_subplot(111, projection='3d') - plt.plot(runner.mon.x[:, i], runner.mon.y[:, i], runner.mon.z[:, i]) - plt.show() - - plt.close() - - def test_ode2(self): - a, b, tau = 0.7, 0.8, 12.5 - dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext - dw = lambda w, t, V: (V + a - b * w) / tau - fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) - - runner = bp.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.]) - runner.run(100., args=dict(Iext=1.5)) - bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') - bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=show) - plt.close() - - - def test_ode3(self): - a, b, tau = 0.7, 0.8, 12.5 - dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext - dw = lambda w, t, V: (V + a - b * w) / tau - fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) - - Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 500, 200], return_length=True) - runner = bp.IntegratorRunner(fhn, - monitors=['V', 'w'], - inits=[1., 1.]) - runner.run(duration, dyn_args=dict(Iext=Iext)) - bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') - bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=show) - plt.close() - - def test_ode_continuous_run(self): - a, b, tau = 0.7, 0.8, 12.5 - dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext - dw = lambda w, t, V: (V + a - b * w) / tau - fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) - - runner = bp.IntegratorRunner(fhn, - monitors=['V', 'w'], - inits=[1., 1.]) - Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 200, 200], return_length=True) - runner.run(duration, 
dyn_args=dict(Iext=Iext)) - bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') - bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w') - - Iext, duration = bp.inputs.section_input([0.5], [200], return_length=True) - runner.run(duration, dyn_args=dict(Iext=Iext)) - bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V-run2') - bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w-run2', show=show) - plt.close() - - - def test_ode_dyn_args(self): - a, b, tau = 0.7, 0.8, 12.5 - dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext - dw = lambda w, t, V: (V + a - b * w) / tau - fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) - - Iext, duration = bp.inputs.section_input([0., 1., 0.5], - [200, 500, 199], - return_length=True) - runner = bp.IntegratorRunner(fhn, - monitors=['V', 'w'], - inits=[1., 1.]) - with self.assertRaises(ValueError): - runner.run(duration + 1, dyn_args=dict(Iext=Iext)) - - plt.close() + def test_ode(self): + + sigma = 10 + beta = 8 / 3 + rho = 28 + + @bp.odeint(method='rk4', dt=0.001) + def lorenz(x, y, z, t): + dx = sigma * (y - x) + dy = x * (rho - z) - y + dz = x * y - beta * z + return dx, dy, dz + + runner = bp.IntegratorRunner(lorenz, monitors=['x', 'y', 'z'], inits=[1., 1., 1.]) + runner.run(100.) + fig = plt.figure() + fig.add_subplot(111, projection='3d') + plt.plot(runner.mon.x[:, 0], runner.mon.y[:, 0], runner.mon.z[:, 0], ) + if show: plt.show() + + runner = bp.IntegratorRunner(lorenz, + monitors=['x', 'y', 'z'], + inits=[1., (1., 0.), (1., 0.)]) + runner.run(100.) 
+ for i in range(2): + fig = plt.figure() + fig.add_subplot(111, projection='3d') + plt.plot(runner.mon.x[:, i], runner.mon.y[:, i], runner.mon.z[:, i]) + plt.show() + + plt.close() + + def test_ode2(self): + a, b, tau = 0.7, 0.8, 12.5 + dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext + dw = lambda w, t, V: (V + a - b * w) / tau + fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) + + runner = bp.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.]) + runner.run(100., args=dict(Iext=1.5)) + bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') + bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=show) + plt.close() + + def test_ode3(self): + a, b, tau = 0.7, 0.8, 12.5 + dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext + dw = lambda w, t, V: (V + a - b * w) / tau + fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) + + Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 500, 200], return_length=True) + runner = bp.IntegratorRunner(fhn, + monitors=['V', 'w'], + inits=[1., 1.]) + runner.run(duration, dyn_args=dict(Iext=Iext)) + bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') + bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w', show=show) + plt.close() + + def test_ode_continuous_run(self): + a, b, tau = 0.7, 0.8, 12.5 + dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext + dw = lambda w, t, V: (V + a - b * w) / tau + fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) + + runner = bp.IntegratorRunner(fhn, + monitors=['V', 'w'], + inits=[1., 1.]) + Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 200, 200], return_length=True) + runner.run(duration, dyn_args=dict(Iext=Iext)) + bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') + bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w') + + Iext, duration = bp.inputs.section_input([0.5], [200], return_length=True) + runner.run(duration, 
dyn_args=dict(Iext=Iext)) + bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V-run2') + bp.visualize.line_plot(runner.mon.ts, runner.mon['w'], legend='w-run2', show=show) + plt.close() + + def test_ode_dyn_args(self): + a, b, tau = 0.7, 0.8, 12.5 + dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext + dw = lambda w, t, V: (V + a - b * w) / tau + fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) + + Iext, duration = bp.inputs.section_input([0., 1., 0.5], + [200, 500, 199], + return_length=True) + runner = bp.IntegratorRunner(fhn, + monitors=['V', 'w'], + inits=[1., 1.]) + with self.assertRaises(ValueError): + runner.run(duration + 1, dyn_args=dict(Iext=Iext)) + + plt.close() diff --git a/brainpy/_src/integrators/tests/test_joint_eq.py b/brainpy/_src/integrators/tests/test_joint_eq.py index 017afcba1..cdff39b35 100644 --- a/brainpy/_src/integrators/tests/test_joint_eq.py +++ b/brainpy/_src/integrators/tests/test_joint_eq.py @@ -1,46 +1,46 @@ # -*- coding: utf-8 -*- +import unittest + import brainpy.math as bm from brainpy._src.integrators.joint_eq import _get_args, JointEq - from brainpy.errors import DiffEqError -import unittest class TestGetArgs(unittest.TestCase): - def test_POSITIONAL_OR_KEYWORD(self): - def f(a, b, t, c, d=1.): - pass + def test_POSITIONAL_OR_KEYWORD(self): + def f(a, b, t, c, d=1.): + pass - print(_get_args(f)) + print(_get_args(f)) - def test_VAR_POSITIONAL(self): - def f(a, b, t, *c, d=1.): - pass + def test_VAR_POSITIONAL(self): + def f(a, b, t, *c, d=1.): + pass - with self.assertRaises(DiffEqError): - _get_args(f) + with self.assertRaises(DiffEqError): + _get_args(f) - def test_KEYWORD_ONLY(self): - def f(a, b, t, *, d=1.): - pass + def test_KEYWORD_ONLY(self): + def f(a, b, t, *, d=1.): + pass - with self.assertRaises(DiffEqError): - _get_args(f) + with self.assertRaises(DiffEqError): + _get_args(f) - # def test_POSITIONAL_ONLY(self): - # def f(a, b, t, /, d=1.): - # pass - # - # with 
self.assertRaises(DiffEqError): - # _get_args(f) + # def test_POSITIONAL_ONLY(self): + # def f(a, b, t, /, d=1.): + # pass + # + # with self.assertRaises(DiffEqError): + # _get_args(f) - def test_VAR_KEYWORD(self): - def f(a, b, t, **kwargs): - pass + def test_VAR_KEYWORD(self): + def f(a, b, t, **kwargs): + pass - with self.assertRaises(DiffEqError): - _get_args(f) + with self.assertRaises(DiffEqError): + _get_args(f) ENa, gNa = 50., 120. @@ -50,69 +50,68 @@ def f(a, b, t, **kwargs): def dm(m, t, V): - alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) - beta = 4.0 * bm.exp(-(V + 65) / 18) - return alpha * (1 - m) - beta * m + alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10)) + beta = 4.0 * bm.exp(-(V + 65) / 18) + return alpha * (1 - m) - beta * m def dh(h, t, V): - alpha = 0.07 * bm.exp(-(V + 65) / 20.) - beta = 1 / (1 + bm.exp(-(V + 35) / 10)) - return alpha * (1 - h) - beta * h + alpha = 0.07 * bm.exp(-(V + 65) / 20.) + beta = 1 / (1 + bm.exp(-(V + 35) / 10)) + return alpha * (1 - h) - beta * h def dn(n, t, V): - alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) - beta = 0.125 * bm.exp(-(V + 65) / 80) - dndt = alpha * (1 - n) - beta * n - return dndt + alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10)) + beta = 0.125 * bm.exp(-(V + 65) / 80) + dndt = alpha * (1 - n) - beta * n + return dndt def dV(V, t, m, h, n, I): - I_Na = (gNa * m ** 3.0 * h) * (V - ENa) - I_K = (gK * n ** 4.0) * (V - EK) - I_leak = gL * (V - EL) - dVdt = (- I_Na - I_K - I_leak + I) / C - return dVdt + I_Na = (gNa * m ** 3.0 * h) * (V - ENa) + I_K = (gK * n ** 4.0) * (V - EK) + I_leak = gL * (V - EL) + dVdt = (- I_Na - I_K - I_leak + I) / C + return dVdt class TestJointEqs(unittest.TestCase): - # def test_variables1(self): - # je = JointEq([dV, dn]) - # with self.assertRaises(DiffEqError): - # je(10., 1., 0., I=0.1) - - def test_variables2(self): - # with self.assertRaises(DiffEqError): - EQ = JointEq((dV,)) - EQ = JointEq(dV) - - def test_call1(self): - je1 = JointEq([dV, 
dn]) - res1 = je1(10., 1., 0., I=0.1, m=0.5, h=0.5) - je2 = JointEq(dV, dn) - res2 = je2(10., 1., 0., I=0.1, m=0.5, h=0.5) - self.assertTrue(res1 == res2) - - def test_do_not_change_par_position(self): - EQ = JointEq((dV,)) - self.assertEqual(EQ(10., 0., 0.1, 0.2, 0.3, 0.), - EQ(V=10., t=0., m=0.1, h=0.2, n=0.3, I=0.)) - - def test_return_is_list(self): - EQ = JointEq((dV,)) - self.assertTrue(isinstance(EQ(V=10., t=0., m=0.1, h=0.2, n=0.3, I=0.), list)) - EQ = JointEq(dV) - self.assertTrue(isinstance(EQ(V=10., t=0., m=0.1, h=0.2, n=0.3, I=0.), list)) - - def test_nested_joint_eq1(self): - EQ1 = JointEq((dm, dh)) - EQ2 = JointEq((EQ1, dn)) - EQ3 = JointEq((EQ2, dV)) - print(EQ3(m=0.1, h=0.2, n=0.3, V=10., t=0., I=0.)) - - EQ1 = JointEq(dm, dh) - EQ2 = JointEq(EQ1, dn) - EQ3 = JointEq(EQ2, dV) - print(EQ3(m=0.1, h=0.2, n=0.3, V=10., t=0., I=0.)) - + # def test_variables1(self): + # je = JointEq([dV, dn]) + # with self.assertRaises(DiffEqError): + # je(10., 1., 0., I=0.1) + + def test_variables2(self): + # with self.assertRaises(DiffEqError): + EQ = JointEq((dV,)) + EQ = JointEq(dV) + + def test_call1(self): + je1 = JointEq([dV, dn]) + res1 = je1(10., 1., 0., I=0.1, m=0.5, h=0.5) + je2 = JointEq(dV, dn) + res2 = je2(10., 1., 0., I=0.1, m=0.5, h=0.5) + self.assertTrue(res1 == res2) + + def test_do_not_change_par_position(self): + EQ = JointEq((dV,)) + self.assertEqual(EQ(10., 0., 0.1, 0.2, 0.3, 0.), + EQ(V=10., t=0., m=0.1, h=0.2, n=0.3, I=0.)) + + def test_return_is_list(self): + EQ = JointEq((dV,)) + self.assertTrue(isinstance(EQ(V=10., t=0., m=0.1, h=0.2, n=0.3, I=0.), list)) + EQ = JointEq(dV) + self.assertTrue(isinstance(EQ(V=10., t=0., m=0.1, h=0.2, n=0.3, I=0.), list)) + + def test_nested_joint_eq1(self): + EQ1 = JointEq((dm, dh)) + EQ2 = JointEq((EQ1, dn)) + EQ3 = JointEq((EQ2, dV)) + print(EQ3(m=0.1, h=0.2, n=0.3, V=10., t=0., I=0.)) + + EQ1 = JointEq(dm, dh) + EQ2 = JointEq(EQ1, dn) + EQ3 = JointEq(EQ2, dV) + print(EQ3(m=0.1, h=0.2, n=0.3, V=10., t=0., I=0.)) 
diff --git a/brainpy/_src/integrators/tests/test_to_math_expr.py b/brainpy/_src/integrators/tests/test_to_math_expr.py index aecf83230..12fbfc9bf 100644 --- a/brainpy/_src/integrators/tests/test_to_math_expr.py +++ b/brainpy/_src/integrators/tests/test_to_math_expr.py @@ -18,31 +18,31 @@ class EINet3(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=4000, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.N) - self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=4000, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:3200]) - self.I(spk[3200:]) - self.delay(self.N(input)) - return self.N.spike.value + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=4000, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=4000, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value def test1(): - model = EINet3() + model = EINet3() - bp.integrators.compile_integrators(model.step_run, 0, 0.) 
- for intg in model.nodes().subset(bp.Integrator).values(): - print(intg.to_math_expr()) + bp.integrators.compile_integrators(model.step_run, 0, 0.) + for intg in model.nodes().subset(bp.Integrator).values(): + print(intg.to_math_expr()) diff --git a/brainpy/_src/integrators/utils.py b/brainpy/_src/integrators/utils.py index a8f1ad0ab..adaae74ae 100644 --- a/brainpy/_src/integrators/utils.py +++ b/brainpy/_src/integrators/utils.py @@ -5,139 +5,139 @@ from pprint import pprint import jax.numpy as jnp -import brainpy.math as bm -from brainpy.errors import UnsupportedError +import brainpy.math as bm from brainpy import errors +from brainpy.errors import UnsupportedError __all__ = [ - 'get_args', - 'check_kws', - 'compile_code', - 'check_inits', - 'format_args', + 'get_args', + 'check_kws', + 'compile_code', + 'check_inits', + 'format_args', ] def check_kws(parameters, keywords): - for key, meaning in keywords.items(): - if key in parameters: - raise errors.CodeError(f'"{key}" is a keyword for ' - f'numerical solvers in BrainPy, denoting ' - f'"{meaning}". Please change another name.') + for key, meaning in keywords.items(): + if key in parameters: + raise errors.CodeError(f'"{key}" is a keyword for ' + f'numerical solvers in BrainPy, denoting ' + f'"{meaning}". Please change another name.') def get_args(f): - """Get the function arguments. 
- - >>> def f1(a, b, t, *args, c=1): pass - >>> get_args(f1) - (['a', 'b'], ['t', '*args', 'c'], ['a', 'b', 't', '*args', 'c=1']) - - >>> def f2(a, b, *args, c=1, **kwargs): pass - >>> get_args(f2) - ValueError: Do not support dict of keyword arguments: **kwargs - - >>> def f3(a, b, t, c=1, d=2): pass - >>> get_args(f4) - (['a', 'b'], ['t', 'c', 'd'], ['a', 'b', 't', 'c=1', 'd=2']) - - >>> def f4(a, b, t, *args): pass - >>> get_args(f4) - (['a', 'b'], ['t', '*args'], ['a', 'b', 't', '*args']) - - >>> scope = {} - >>> exec(compile('def f5(a, b, t, *args): pass', '', 'exec'), scope) - >>> get_args(scope['f5']) - (['a', 'b'], ['t', '*args'], ['a', 'b', 't', '*args']) - - Parameters:: - - f : callable - The function. - - Returns:: - - args : tuple - The variable names, the other arguments, and the original args. - """ - - # get the function arguments - reduced_args = [] - args = [] - - for name, par in inspect.signature(f).parameters.items(): - if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: - reduced_args.append(par.name) - - elif par.kind is inspect.Parameter.VAR_POSITIONAL: - reduced_args.append(f'*{par.name}') - - elif par.kind is inspect.Parameter.KEYWORD_ONLY: - raise errors.DiffEqError(f'In BrainPy, numerical integrators do not support KEYWORD_ONLY ' - f'parameters, e.g., * (error in {f}).') - elif par.kind is inspect.Parameter.POSITIONAL_ONLY: - raise errors.DiffEqError(f'In BrainPy, numerical integrators do not support POSITIONAL_ONLY ' - f'parameters, e.g., / (error in {f}).') - elif par.kind is inspect.Parameter.VAR_KEYWORD: # TODO - raise errors.DiffEqError(f'In BrainPy, numerical integrators do not support VAR_KEYWORD ' - f'arguments: {str(par)} (error in {f}).') + """Get the function arguments. 
+ + >>> def f1(a, b, t, *args, c=1): pass + >>> get_args(f1) + (['a', 'b'], ['t', '*args', 'c'], ['a', 'b', 't', '*args', 'c=1']) + + >>> def f2(a, b, *args, c=1, **kwargs): pass + >>> get_args(f2) + ValueError: Do not support dict of keyword arguments: **kwargs + + >>> def f3(a, b, t, c=1, d=2): pass + >>> get_args(f4) + (['a', 'b'], ['t', 'c', 'd'], ['a', 'b', 't', 'c=1', 'd=2']) + + >>> def f4(a, b, t, *args): pass + >>> get_args(f4) + (['a', 'b'], ['t', '*args'], ['a', 'b', 't', '*args']) + + >>> scope = {} + >>> exec(compile('def f5(a, b, t, *args): pass', '', 'exec'), scope) + >>> get_args(scope['f5']) + (['a', 'b'], ['t', '*args'], ['a', 'b', 't', '*args']) + + Parameters:: + + f : callable + The function. + + Returns:: + + args : tuple + The variable names, the other arguments, and the original args. + """ + + # get the function arguments + reduced_args = [] + args = [] + + for name, par in inspect.signature(f).parameters.items(): + if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + reduced_args.append(par.name) + + elif par.kind is inspect.Parameter.VAR_POSITIONAL: + reduced_args.append(f'*{par.name}') + + elif par.kind is inspect.Parameter.KEYWORD_ONLY: + raise errors.DiffEqError(f'In BrainPy, numerical integrators do not support KEYWORD_ONLY ' + f'parameters, e.g., * (error in {f}).') + elif par.kind is inspect.Parameter.POSITIONAL_ONLY: + raise errors.DiffEqError(f'In BrainPy, numerical integrators do not support POSITIONAL_ONLY ' + f'parameters, e.g., / (error in {f}).') + elif par.kind is inspect.Parameter.VAR_KEYWORD: # TODO + raise errors.DiffEqError(f'In BrainPy, numerical integrators do not support VAR_KEYWORD ' + f'arguments: {str(par)} (error in {f}).') + else: + raise errors.DiffEqError(f'Unknown argument type: {par.kind} (error in {f}).') + + args.append(str(par)) + + # variable names + vars = [] + for a in reduced_args: + if a == 't': + break + vars.append(a) else: - raise errors.DiffEqError(f'Unknown argument type: {par.kind} (error 
in {f}).') - - args.append(str(par)) - - # variable names - vars = [] - for a in reduced_args: - if a == 't': - break - vars.append(a) - else: - raise ValueError('Do not find time variable "t".') - pars = reduced_args[len(vars):] - return vars, pars, args + raise ValueError('Do not find time variable "t".') + pars = reduced_args[len(vars):] + return vars, pars, args def compile_code(code_lines, code_scope, func_name, show_code=False): - code = '\n'.join(code_lines) - if show_code: - print(code) - print() - pprint(code_scope) - print() - exec(compile(code, '', 'exec'), code_scope) - new_f = code_scope[func_name] - return new_f + code = '\n'.join(code_lines) + if show_code: + print(code) + print() + pprint(code_scope) + print() + exec(compile(code, '', 'exec'), code_scope) + new_f = code_scope[func_name] + return new_f def check_inits(inits, variables): - if isinstance(inits, (tuple, list, bm.BaseArray, jnp.ndarray)): - assert len(inits) == len(variables), (f'Then number of variables is {len(variables)}, ' - f'however we only got {len(inits)} initial values.') - inits = {v: inits[i] for i, v in enumerate(variables)} - elif isinstance(inits, dict): - assert len(inits) == len(variables), (f'Then number of variables is {len(variables)}, ' - f'however we only got {len(inits)} initial values.') - else: - raise UnsupportedError('Only supports dict/sequence of data for initial values. 
' - f'But we got {type(inits)}: {inits}') - for key in list(inits.keys()): - if key not in variables: - raise ValueError(f'"{key}" is not defined in variables: {variables}') - val = inits[key] - if isinstance(val, (float, int)): - inits[key] = bm.asarray([val], dtype=bm.float_) - return inits + if isinstance(inits, (tuple, list, bm.BaseArray, jnp.ndarray)): + assert len(inits) == len(variables), (f'Then number of variables is {len(variables)}, ' + f'however we only got {len(inits)} initial values.') + inits = {v: inits[i] for i, v in enumerate(variables)} + elif isinstance(inits, dict): + assert len(inits) == len(variables), (f'Then number of variables is {len(variables)}, ' + f'however we only got {len(inits)} initial values.') + else: + raise UnsupportedError('Only supports dict/sequence of data for initial values. ' + f'But we got {type(inits)}: {inits}') + for key in list(inits.keys()): + if key not in variables: + raise ValueError(f'"{key}" is not defined in variables: {variables}') + val = inits[key] + if isinstance(val, (float, int)): + inits[key] = bm.asarray([val], dtype=bm.float_) + return inits def format_args(args, kwargs, arguments): - all_args = dict() - for i, arg in enumerate(args): - all_args[arguments[i]] = arg - for key, arg in kwargs.items(): - if key in all_args: - raise ValueError(f'{key} has been provided in *args, ' - f'but we detect it again in **kwargs.') - all_args[key] = arg - return all_args + all_args = dict() + for i, arg in enumerate(args): + all_args[arguments[i]] = arg + for key, arg in kwargs.items(): + if key in all_args: + raise ValueError(f'{key} has been provided in *args, ' + f'but we detect it again in **kwargs.') + all_args[key] = arg + return all_args diff --git a/brainpy/_src/losses/base.py b/brainpy/_src/losses/base.py index a01e2aee8..149fa6a92 100644 --- a/brainpy/_src/losses/base.py +++ b/brainpy/_src/losses/base.py @@ -3,22 +3,22 @@ from brainpy._src.dnn.base import Layer __all__ = [ - 'Loss', - 'WeightedLoss', + 
'Loss', + 'WeightedLoss', ] class Loss(Layer): - reduction: str + reduction: str - def __init__(self, reduction: str = 'mean') -> None: - super().__init__() - self.reduction = reduction + def __init__(self, reduction: str = 'mean') -> None: + super().__init__() + self.reduction = reduction class WeightedLoss(Loss): - weight: Optional + weight: Optional - def __init__(self, weight: Optional = None, reduction: str = 'mean') -> None: - super().__init__(reduction) - self.weight = weight + def __init__(self, weight: Optional = None, reduction: str = 'mean') -> None: + super().__init__(reduction) + self.weight = weight diff --git a/brainpy/_src/losses/comparison.py b/brainpy/_src/losses/comparison.py index ac025b41d..8ceb5856d 100644 --- a/brainpy/_src/losses/comparison.py +++ b/brainpy/_src/losses/comparison.py @@ -17,870 +17,867 @@ from .utils import _reduce, _multi_return, _is_leaf __all__ = [ - 'CrossEntropyLoss', 'cross_entropy_loss', + 'CrossEntropyLoss', 'cross_entropy_loss', - 'cross_entropy_sparse', - 'cross_entropy_sigmoid', + 'cross_entropy_sparse', + 'cross_entropy_sigmoid', - 'NLLLoss', 'nll_loss', - 'L1Loss', 'l1_loss', + 'NLLLoss', 'nll_loss', + 'L1Loss', 'l1_loss', - 'l2_loss', - 'huber_loss', + 'l2_loss', + 'huber_loss', - 'MAELoss', 'mean_absolute_error', - 'MSELoss', 'mean_squared_error', + 'MAELoss', 'mean_absolute_error', + 'MSELoss', 'mean_squared_error', - 'mean_squared_log_error', - 'binary_logistic_loss', - 'multiclass_logistic_loss', - 'sigmoid_binary_cross_entropy', - 'softmax_cross_entropy', - 'log_cosh_loss', - 'ctc_loss_with_forward_probs', - 'ctc_loss', - 'multi_margin_loss', + 'mean_squared_log_error', + 'binary_logistic_loss', + 'multiclass_logistic_loss', + 'sigmoid_binary_cross_entropy', + 'softmax_cross_entropy', + 'log_cosh_loss', + 'ctc_loss_with_forward_probs', + 'ctc_loss', + 'multi_margin_loss', ] class CrossEntropyLoss(WeightedLoss): - r"""This criterion computes the cross entropy loss between input logits - and target. 
+ r"""This criterion computes the cross entropy loss between input logits + and target. - It is useful when training a classification problem with `C` classes. - If provided, the optional argument :attr:`weight` should be a 1D `Tensor` - assigning weight to each of the classes. - This is particularly useful when you have an unbalanced training set. + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Tensor` + assigning weight to each of the classes. + This is particularly useful when you have an unbalanced training set. - The `input` is expected to contain the unnormalized logits for each class (which do `not` need - to be positive or sum to 1, in general). - `input` has to be a Tensor of size :math:`(C)` for unbatched input, - :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the - `K`-dimensional case. The last being useful for higher dimension inputs, such - as computing cross entropy loss per-pixel for 2D images. + The `input` is expected to contain the unnormalized logits for each class (which do `not` need + to be positive or sum to 1, in general). + `input` has to be a Tensor of size :math:`(C)` for unbatched input, + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the + `K`-dimensional case. The last being useful for higher dimension inputs, such + as computing cross entropy loss per-pixel for 2D images. - The `target` that this criterion expects should contain either: + The `target` that this criterion expects should contain either: - - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if - `ignore_index` is specified, this loss also accepts this class index (this index - may not necessarily be in the class range). The unreduced (i.e. 
with :attr:`reduction` - set to ``'none'``) loss for this case can be described as: + - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if + `ignore_index` is specified, this loss also accepts this class index (this index + may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction` + set to ``'none'``) loss for this case can be described as: - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} - \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\} + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} + \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\} - where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, - :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as - :math:`d_1, ..., d_k` for the `K`-dimensional case. If - :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then - .. math:: - \ell(x, y) = \begin{cases} - \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & - \text{if reduction} = \text{`mean';}\\ - \sum_{n=1}^N l_n, & - \text{if reduction} = \text{`sum'.} - \end{cases} + .. 
math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} - Note that this case is equivalent to the combination of :class:`~torch.nn.LogSoftmax` and - :class:`~torch.nn.NLLLoss`. + Note that this case is equivalent to the combination of :class:`~torch.nn.LogSoftmax` and + :class:`~torch.nn.NLLLoss`. - - Probabilities for each class; useful when labels beyond a single class per minibatch item - are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with - :attr:`reduction` set to ``'none'``) loss for this case can be described as: + - Probabilities for each class; useful when labels beyond a single class per minibatch item + are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with + :attr:`reduction` set to ``'none'``) loss for this case can be described as: - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c} + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c} - where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, - :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as - :math:`d_1, ..., d_k` for the `K`-dimensional case. If - :attr:`reduction` is not ``'none'`` (default ``'mean'``), then - - .. math:: - \ell(x, y) = \begin{cases} - \frac{\sum_{n=1}^N l_n}{N}, & - \text{if reduction} = \text{`mean';}\\ - \sum_{n=1}^N l_n, & - \text{if reduction} = \text{`sum'.} - \end{cases} - - .. note:: - The performance of this criterion is generally better when `target` contains class - indices, as this allows for optimized computation. 
Consider providing `target` as - class probabilities only when a single class label per minibatch item is too restrictive. - - Args: - weight (Tensor, optional): a manual rescaling weight given to each class. - If given, has to be a Tensor of size `C` - size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, - the losses are averaged over each loss element in the batch. Note that for - some losses, there are multiple elements per sample. If the field :attr:`size_average` - is set to ``False``, the losses are instead summed for each minibatch. Ignored - when :attr:`reduce` is ``False``. Default: ``True`` - ignore_index (int, optional): Specifies a target value that is ignored - and does not contribute to the input gradient. When :attr:`size_average` is - ``True``, the loss is averaged over non-ignored targets. Note that - :attr:`ignore_index` is only applicable when the target contains class indices. - reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the - losses are averaged or summed over observations for each minibatch depending - on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per - batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will - be applied, ``'mean'``: the weighted mean of the output is taken, - ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in - the meantime, specifying either of those two args will override - :attr:`reduction`. Default: ``'mean'`` - label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount - of smoothing when computing the loss, where 0.0 means no smoothing. 
The targets - become a mixture of the original ground truth and a uniform distribution as described in - `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. - - Shape: - - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` - in the case of `K`-dimensional loss. - - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with - :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. - If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`. - - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` - in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar. - - - where: + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as + :math:`d_1, ..., d_k` for the `K`-dimensional case. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then .. 
math:: - \begin{aligned} - C ={} & \text{number of classes} \\ - N ={} & \text{batch size} \\ - \end{aligned} - - Examples:: - - >>> # Example of target with class indices - >>> loss = nn.CrossEntropyLoss() - >>> input = torch.randn(3, 5, requires_grad=True) - >>> target = torch.empty(3, dtype=torch.long).random_(5) - >>> output = loss(input, target) - >>> output.backward() - >>> - >>> # Example of target with class probabilities - >>> input = torch.randn(3, 5, requires_grad=True) - >>> target = torch.randn(3, 5).softmax(dim=1) - >>> output = loss(input, target) - >>> output.backward() - """ - __constants__ = ['ignore_index', 'reduction', 'label_smoothing'] - ignore_index: int - label_smoothing: float - - def __init__(self, weight: Optional[ArrayType] = None, ignore_index: int = -100, - reduction: str = 'mean', label_smoothing: float = 0.0) -> None: - super().__init__(weight, reduction) - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - - def update(self, input: ArrayType, target: ArrayType) -> ArrayType: - return cross_entropy_loss(input, target, weight=self.weight, reduction=self.reduction) + \ell(x, y) = \begin{cases} + \frac{\sum_{n=1}^N l_n}{N}, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + The performance of this criterion is generally better when `target` contains class + indices, as this allows for optimized computation. Consider providing `target` as + class probabilities only when a single class label per minibatch item is too restrictive. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. 
If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + + Shape: + - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with + :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. 
+ If containing class probabilities, same shape as the input and each value should be between :math:`[0, 1]`. + - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar. + + + where: + + .. math:: + \begin{aligned} + C ={} & \text{number of classes} \\ + N ={} & \text{batch size} \\ + \end{aligned} + + Examples:: + + >>> # Example of target with class indices + >>> loss = nn.CrossEntropyLoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.empty(3, dtype=torch.long).random_(5) + >>> output = loss(input, target) + >>> output.backward() + >>> + >>> # Example of target with class probabilities + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5).softmax(dim=1) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['ignore_index', 'reduction', 'label_smoothing'] + ignore_index: int + label_smoothing: float + + def __init__(self, weight: Optional[ArrayType] = None, ignore_index: int = -100, + reduction: str = 'mean', label_smoothing: float = 0.0) -> None: + super().__init__(weight, reduction) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + + def update(self, input: ArrayType, target: ArrayType) -> ArrayType: + return cross_entropy_loss(input, target, weight=self.weight, reduction=self.reduction) def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'): - r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class. - - It is useful when training a classification problem with `C` classes. - If provided, the optional argument :attr:`weight` should be a 1D `Array` - assigning weight to each of the classes. This is particularly useful when - you have an unbalanced training set. - - The ``input`` is expected to contain raw, unnormalized scores for each class. 
- ``input`` has to be an array of size either :math:`(minibatch, C)` or - :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1` for the - `K`-dimensional case (described later). - - This criterion expects a class index in the range :math:`[0, C-1]` as the - `target` for each value of a 1D tensor of size `minibatch`. - - The loss can be described as: - - .. math:: - \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) - = -x[class] + \log\left(\sum_j \exp(x[j])\right) - - or in the case of the :attr:`weight` argument being specified: - - .. math:: - \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right) - - Can also be used for higher dimension inputs, such as 2D images, by providing - an input of size :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1`, - where :math:`K` is the number of dimensions, and a target of appropriate shape. - - Parameters:: - - predicts : ArrayType - :math:`(N, C)` where `C = number of classes`, or - :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` - in the case of `K`-dimensional loss. - targets : ArrayType - :math:`(N, C)` or :math:`(N)` where each value is - :math:`0 \leq \text{targets}[i] \leq C-1`, or - :math:`(d_1, d_2, ..., d_K, N, C)` or :math:`(d_1, d_2, ..., d_K, N)` - with :math:`K \geq 1` in the case of K-dimensional loss. - weight : ArrayType, optional - A manual rescaling weight given to each class. If given, has to be an array of size `C`. - reduction : str, optional - Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. - - ``'none'``: no reduction will be applied, - - ``'mean'``: the weighted mean of the output is taken, - - ``'sum'``: the output will be summed. - - Returns:: - - output : scalar, ArrayType - If :attr:`reduction` is ``'none'``, then the same size as the target: - :math:`(N)`, or :math:`(d_1, d_2, ..., d_K, N)` with :math:`K \geq 1` - in the case of K-dimensional loss. 
- """ - - def _cel(_pred, _tar): - if bm.ndim(_tar) + 1 == bm.ndim(_pred): - _tar = bm.one_hot(_tar, _pred.shape[-1]) - loss = logsumexp(bm.as_jax(_pred), axis=-1) - (_pred * _tar).sum(axis=-1) - if weight is not None: - loss *= weight - return _reduce(outputs=loss, reduction=reduction) - - r = tree_map(_cel, predicts, targets, is_leaf=_is_leaf) - return _multi_return(r) + r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Array` + assigning weight to each of the classes. This is particularly useful when + you have an unbalanced training set. + + The ``input`` is expected to contain raw, unnormalized scores for each class. + ``input`` has to be an array of size either :math:`(minibatch, C)` or + :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1` for the + `K`-dimensional case (described later). + + This criterion expects a class index in the range :math:`[0, C-1]` as the + `target` for each value of a 1D tensor of size `minibatch`. + + The loss can be described as: + + .. math:: + \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) + = -x[class] + \log\left(\sum_j \exp(x[j])\right) + + or in the case of the :attr:`weight` argument being specified: + + .. math:: + \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right) + + Can also be used for higher dimension inputs, such as 2D images, by providing + an input of size :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1`, + where :math:`K` is the number of dimensions, and a target of appropriate shape. + + Parameters:: + + predicts : ArrayType + :math:`(N, C)` where `C = number of classes`, or + :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. 
+ targets : ArrayType + :math:`(N, C)` or :math:`(N)` where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(d_1, d_2, ..., d_K, N, C)` or :math:`(d_1, d_2, ..., d_K, N)` + with :math:`K \geq 1` in the case of K-dimensional loss. + weight : ArrayType, optional + A manual rescaling weight given to each class. If given, has to be an array of size `C`. + reduction : str, optional + Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. + - ``'none'``: no reduction will be applied, + - ``'mean'``: the weighted mean of the output is taken, + - ``'sum'``: the output will be summed. + + Returns:: + + output : scalar, ArrayType + If :attr:`reduction` is ``'none'``, then the same size as the target: + :math:`(N)`, or :math:`(d_1, d_2, ..., d_K, N)` with :math:`K \geq 1` + in the case of K-dimensional loss. + """ + + def _cel(_pred, _tar): + if bm.ndim(_tar) + 1 == bm.ndim(_pred): + _tar = bm.one_hot(_tar, _pred.shape[-1]) + loss = logsumexp(bm.as_jax(_pred), axis=-1) - (_pred * _tar).sum(axis=-1) + if weight is not None: + loss *= weight + return _reduce(outputs=loss, reduction=reduction) + + r = tree_map(_cel, predicts, targets, is_leaf=_is_leaf) + return _multi_return(r) def cross_entropy_sparse(predicts, targets): - r"""Computes the softmax cross-entropy loss. + r"""Computes the softmax cross-entropy loss. - Args: - predicts: (batch, ..., #class) tensor of logits. - targets: (batch, ...) integer tensor of label indexes in {0, ...,#nclass-1} or just a single integer. + Args: + predicts: (batch, ..., #class) tensor of logits. + targets: (batch, ...) integer tensor of label indexes in {0, ...,#nclass-1} or just a single integer. - Returns: - (batch, ...) tensor of the cross-entropy for each entry. - """ + Returns: + (batch, ...) tensor of the cross-entropy for each entry. 
+ """ - def crs(_prd, _tar): - if isinstance(_tar, int): - logits = _prd[..., _tar] - else: - logits = jnp.take_along_axis(_prd, _tar, -1).squeeze(-1) - return logsumexp(bm.as_jax(_prd), axis=-1) - logits + def crs(_prd, _tar): + if isinstance(_tar, int): + logits = _prd[..., _tar] + else: + logits = jnp.take_along_axis(_prd, _tar, -1).squeeze(-1) + return logsumexp(bm.as_jax(_prd), axis=-1) - logits - r = tree_map(crs, predicts, targets, is_leaf=_is_leaf) - return _multi_return(r) + r = tree_map(crs, predicts, targets, is_leaf=_is_leaf) + return _multi_return(r) def cross_entropy_sigmoid(predicts, targets): - """Computes the sigmoid cross-entropy loss. - - Args: - predicts: (batch, ..., #class) tensor of logits. - targets: (batch, ..., #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1) - - Returns: - (batch, ...) tensor of the cross-entropies for each entry. - """ - r = tree_map( - lambda pred, tar: bm.as_jax( - bm.maximum(pred, 0) - pred * tar + bm.log(1 + bm.exp(-bm.abs(pred))) - ), - predicts, - targets, - is_leaf=_is_leaf - ) - return _multi_return(r) + """Computes the sigmoid cross-entropy loss. + + Args: + predicts: (batch, ..., #class) tensor of logits. + targets: (batch, ..., #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1) + + Returns: + (batch, ...) tensor of the cross-entropies for each entry. + """ + r = tree_map( + lambda pred, tar: bm.as_jax( + bm.maximum(pred, 0) - pred * tar + bm.log(1 + bm.exp(-bm.abs(pred))) + ), + predicts, + targets, + is_leaf=_is_leaf + ) + return _multi_return(r) class NLLLoss(Loss): - r"""The negative log likelihood loss. - - The negative log likelihood loss. It is useful to train a classification - problem with `C` classes. - - If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning - weight to each of the classes. This is particularly useful when you have an - unbalanced training set. 
- - The `input` given through a forward call is expected to contain - log-probabilities of each class. `input` has to be a Tensor of size either - :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` - with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for - higher dimension inputs, such as computing NLL loss per-pixel for 2D images. - - Obtaining log-probabilities in a neural network is easily achieved by - adding a `LogSoftmax` layer in the last layer of your network. - You may use `CrossEntropyLoss` instead, if you prefer not to add an extra - layer. - - The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` - where `C = number of classes`; if `ignore_index` is specified, this loss also accepts - this class index (this index may not necessarily be in the class range). - - The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: - - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = - w_{y_n} x_{n,y_n}, \quad - w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, - - where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and - :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` - (default ``'mean'``), then - - .. math:: - \ell(x, y) = \begin{cases} - \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & - \text{if reduction} = \text{`mean';}\\ - \sum_{n=1}^N l_n, & - \text{if reduction} = \text{`sum'.} - \end{cases} - - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will - be applied, ``'mean'``: the weighted mean of the output is taken, - ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in - the meantime, specifying either of those two args will override - :attr:`reduction`. 
Default: ``'mean'`` - - Shape: - - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or - :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` - in the case of `K`-dimensional loss. - - Target: :math:`(N)` or :math:`()`, where each value is - :math:`0 \leq \text{targets}[i] \leq C-1`, or - :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of - K-dimensional loss. - - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or - :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. - Otherwise, scalar. - - """ - - def __init__(self, reduction: str = 'mean'): - super().__init__(reduction=reduction) - - def update(self, input, target): - return nll_loss(input, target, reduction=self.reduction) + r"""The negative log likelihood loss. + The negative log likelihood loss. It is useful to train a classification + problem with `C` classes. -def nll_loss(input, target, reduction: str = 'mean'): - r""" - The negative log likelihood loss. - - The negative log likelihood loss. It is useful to train a classification - problem with `C` classes. - - If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning - weight to each of the classes. This is particularly useful when you have an - unbalanced training set. - - The `input` given through a forward call is expected to contain - log-probabilities of each class. `input` has to be a Tensor of size either - :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` - with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for - higher dimension inputs, such as computing NLL loss per-pixel for 2D images. - - Obtaining log-probabilities in a neural network is easily achieved by - adding a `LogSoftmax` layer in the last layer of your network. - You may use `CrossEntropyLoss` instead, if you prefer not to add an extra - layer. 
- - The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` - where `C = number of classes`; if `ignore_index` is specified, this loss also accepts - this class index (this index may not necessarily be in the class range). - - The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: - - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = - w_{y_n} x_{n,y_n}, \quad - w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, - - where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and - :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` - (default ``'mean'``), then - - .. math:: - \ell(x, y) = \begin{cases} - \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & - \text{if reduction} = \text{`mean';}\\ - \sum_{n=1}^N l_n, & - \text{if reduction} = \text{`sum'.} - \end{cases} - - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will - be applied, ``'mean'``: the weighted mean of the output is taken, - ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in - the meantime, specifying either of those two args will override - :attr:`reduction`. Default: ``'mean'`` - - Shape: - - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or - :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` - in the case of `K`-dimensional loss. - - Target: :math:`(N)` or :math:`()`, where each value is - :math:`0 \leq \text{targets}[i] \leq C-1`, or - :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of - K-dimensional loss. - - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or - :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. - Otherwise, scalar. 
- - """ - assert target.ndim + 1 == input.ndim - input = bm.as_jax(input) - target = bm.as_jax(target) - loss = input[jnp.arange(len(target)), target] - if reduction == 'mean': - return loss.mean() - elif reduction == 'sum': - return loss.sum() - elif reduction == 'none': - return loss - elif reduction is None: - return loss - else: - raise ValueError + If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning + weight to each of the classes. This is particularly useful when you have an + unbalanced training set. + The `input` given through a forward call is expected to contain + log-probabilities of each class. `input` has to be a Tensor of size either + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for + higher dimension inputs, such as computing NLL loss per-pixel for 2D images. -class L1Loss(Loss): - r"""Creates a criterion that measures the mean absolute error (MAE) between each element in - the input :math:`x` and target :math:`y`. + Obtaining log-probabilities in a neural network is easily achieved by + adding a `LogSoftmax` layer in the last layer of your network. + You may use `CrossEntropyLoss` instead, if you prefer not to add an extra + layer. - The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` + where `C = number of classes`; if `ignore_index` is specified, this loss also accepts + this class index (this index may not necessarily be in the class range). - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = \left| x_n - y_n \right|, + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: - where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` - (default ``'mean'``), then: + .. 
math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} x_{n,y_n}, \quad + w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, - .. math:: - \ell(x, y) = - \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} - \end{cases} + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and + :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then - :math:`x` and :math:`y` are tensors of arbitrary shapes with a total - of :math:`n` elements each. + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or + :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: :math:`(N)` or :math:`()`, where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of + K-dimensional loss. + - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. + Otherwise, scalar. - The sum operation still operates over all the elements, and divides by :math:`n`. 
+ """ - The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + def __init__(self, reduction: str = 'mean'): + super().__init__(reduction=reduction) - Supports real-valued and complex-valued inputs. + def update(self, input, target): + return nll_loss(input, target, reduction=self.reduction) - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in the meantime, - specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Target: :math:`(*)`, same shape as the input. - - Output: scalar. If :attr:`reduction` is ``'none'``, then - :math:`(*)`, same shape as the input. +def nll_loss(input, target, reduction: str = 'mean'): + r""" + The negative log likelihood loss. - Examples:: + The negative log likelihood loss. It is useful to train a classification + problem with `C` classes. - >>> loss = nn.L1Loss() - >>> input = bm.random.randn(3, 5) - >>> target = bm.random.randn(3, 5) - >>> output = loss(input, target) - >>> output.backward() - """ - __constants__ = ['reduction'] + If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning + weight to each of the classes. This is particularly useful when you have an + unbalanced training set. - def __init__(self, reduction: str = 'mean') -> None: - super().__init__(reduction=reduction) + The `input` given through a forward call is expected to contain + log-probabilities of each class. 
`input` has to be a Tensor of size either + :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` + with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for + higher dimension inputs, such as computing NLL loss per-pixel for 2D images. - def update(self, input: ArrayType, target: ArrayType) -> ArrayType: - return l1_loss(input, target, reduction=self.reduction) + Obtaining log-probabilities in a neural network is easily achieved by + adding a `LogSoftmax` layer in the last layer of your network. + You may use `CrossEntropyLoss` instead, if you prefer not to add an extra + layer. + The `target` that this loss expects should be a class index in the range :math:`[0, C-1]` + where `C = number of classes`; if `ignore_index` is specified, this loss also accepts + this class index (this index may not necessarily be in the class range). -def l1_loss(logits, targets, reduction='sum'): - r"""Creates a criterion that measures the mean absolute error (MAE) between each element in - the logits :math:`x` and targets :math:`y`. It is useful in regression problems. + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: - The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} x_{n,y_n}, \quad + w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = \left| x_n - y_n \right|, + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and + :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then - where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` - (default ``'mean'``), then: + .. 
math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or + :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. + - Target: :math:`(N)` or :math:`()`, where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of + K-dimensional loss. + - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or + :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. + Otherwise, scalar. - .. math:: - \ell(x, y) = - \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} - \end{cases} + """ + assert target.ndim + 1 == input.ndim + input = bm.as_jax(input) + target = bm.as_jax(target) + loss = input[jnp.arange(len(target)), target] + if reduction == 'mean': + return loss.mean() + elif reduction == 'sum': + return loss.sum() + elif reduction == 'none': + return loss + elif reduction is None: + return loss + else: + raise ValueError - :math:`x` and :math:`y` are tensors of arbitrary shapes with a total - of :math:`n` elements each. - The sum operation still operates over all the elements, and divides by :math:`n`. 
+class L1Loss(Loss): + r"""Creates a criterion that measures the mean absolute error (MAE) between each element in + the input :math:`x` and target :math:`y`. - The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: - Supports real-valued and complex-valued inputs. + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|, - Parameters:: + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: - logits : ArrayType - :math:`(N, *)` where :math:`*` means, any number of additional dimensions. - targets : ArrayType - :math:`(N, *)`, same shape as the input. - reduction : str - Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. - Default: ``'mean'``. - - ``'none'``: no reduction will be applied, - - ``'mean'``: the sum of the output will be divided by the number of elements in the output, - - ``'sum'``: the output will be summed. Note: :attr:`size_average` + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The sum operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + + Supports real-valued and complex-valued inputs. + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. 
Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. + - Output: scalar. If :attr:`reduction` is ``'none'``, then + :math:`(*)`, same shape as the input. + + Examples:: + + >>> loss = nn.L1Loss() + >>> input = bm.random.randn(3, 5) + >>> target = bm.random.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['reduction'] - Returns:: + def __init__(self, reduction: str = 'mean') -> None: + super().__init__(reduction=reduction) - output : scalar. - If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input. - """ + def update(self, input: ArrayType, target: ArrayType) -> ArrayType: + return l1_loss(input, target, reduction=self.reduction) - def loss(pred, tar): - diff = (pred - tar).reshape((pred.shape[0], -1)) - norm = jnp.linalg.norm(bm.as_jax(diff), ord=1, axis=1, keepdims=False) - return _reduce(outputs=norm, reduction=reduction) - r = tree_map(loss, logits, targets, is_leaf=_is_leaf) - return _multi_return(r) +def l1_loss(logits, targets, reduction='sum'): + r"""Creates a criterion that measures the mean absolute error (MAE) between each element in + the logits :math:`x` and targets :math:`y`. It is useful in regression problems. + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: -def l2_loss(predicts, targets): - r"""Computes the L2 loss. + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left| x_n - y_n \right|, - The 0.5 term is standard in "Pattern Recognition and Machine Learning" - by Bishop [1]_, but not "The Elements of Statistical Learning" by Tibshirani. + where :math:`N` is the batch size. 
If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: - Parameters:: + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} - predicts: ArrayType - A vector of arbitrary shape. - targets: ArrayType - A vector of shape compatible with predictions. + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. - Returns:: + The sum operation still operates over all the elements, and divides by :math:`n`. - loss : float - A scalar value containing the l2 loss. + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. - References:: + Supports real-valued and complex-valued inputs. - .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning. - """ - r = tree_map(lambda pred, tar: 0.5 * (pred - tar) ** 2, - predicts, - targets) - return _multi_return(r) + Parameters:: + logits : ArrayType + :math:`(N, *)` where :math:`*` means, any number of additional dimensions. + targets : ArrayType + :math:`(N, *)`, same shape as the input. + reduction : str + Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. + Default: ``'mean'``. + - ``'none'``: no reduction will be applied, + - ``'mean'``: the sum of the output will be divided by the number of elements in the output, + - ``'sum'``: the output will be summed. Note: :attr:`size_average` + Returns:: -class MAELoss(Loss): - def __init__(self, axis=None, reduction: str = 'mean'): - super().__init__(reduction=reduction) - self.axis = axis + output : scalar. + If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input. 
+ """ - def update(self, input, target): - return mean_absolute_error(input, target, self.axis, reduction=self.reduction) + def loss(pred, tar): + diff = (pred - tar).reshape((pred.shape[0], -1)) + norm = jnp.linalg.norm(bm.as_jax(diff), ord=1, axis=1, keepdims=False) + return _reduce(outputs=norm, reduction=reduction) + r = tree_map(loss, logits, targets, is_leaf=_is_leaf) + return _multi_return(r) +def l2_loss(predicts, targets): + r"""Computes the L2 loss. -def mean_absolute_error(x, y, axis=None, reduction: str = 'mean'): - r"""Computes the mean absolute error between x and y. + The 0.5 term is standard in "Pattern Recognition and Machine Learning" + by Bishop [1]_, but not "The Elements of Statistical Learning" by Tibshirani. - Args: - x: a tensor of shape (d0, .. dN-1). - y: a tensor of shape (d0, .. dN-1). - axis: a sequence of the dimensions to keep, use `None` to return a scalar value. + Parameters:: - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. - """ - r = tree_map(lambda a, b: _reduce(bm.abs(a - b), reduction=reduction, axis=axis), - x, - y, - is_leaf=_is_leaf) - return _multi_return(r) + predicts: ArrayType + A vector of arbitrary shape. + targets: ArrayType + A vector of shape compatible with predictions. + Returns:: -class MSELoss(Loss): - r"""Creates a criterion that measures the mean squared error (squared L2 norm) between - each element in the input :math:`x` and target :math:`y`. + loss : float + A scalar value containing the l2 loss. - The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + References:: - .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = \left( x_n - y_n \right)^2, + .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning. + """ + r = tree_map(lambda pred, tar: 0.5 * (pred - tar) ** 2, + predicts, + targets) + return _multi_return(r) - where :math:`N` is the batch size. 
If :attr:`reduction` is not ``'none'`` - (default ``'mean'``), then: - .. math:: - \ell(x, y) = - \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ - \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} - \end{cases} +class MAELoss(Loss): + def __init__(self, axis=None, reduction: str = 'mean'): + super().__init__(reduction=reduction) + self.axis = axis + + def update(self, input, target): + return mean_absolute_error(input, target, self.axis, reduction=self.reduction) - :math:`x` and :math:`y` are tensors of arbitrary shapes with a total - of :math:`n` elements each. - The mean operation still operates over all the elements, and divides by :math:`n`. +def mean_absolute_error(x, y, axis=None, reduction: str = 'mean'): + r"""Computes the mean absolute error between x and y. - The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + Args: + x: a tensor of shape (d0, .. dN-1). + y: a tensor of shape (d0, .. dN-1). + axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in the meantime, - specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. + """ + r = tree_map(lambda a, b: _reduce(bm.abs(a - b), reduction=reduction, axis=axis), + x, + y, + is_leaf=_is_leaf) + return _multi_return(r) + + +class MSELoss(Loss): + r"""Creates a criterion that measures the mean squared error (squared L2 norm) between + each element in the input :math:`x` and target :math:`y`. 
+ + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Target: :math:`(*)`, same shape as the input. + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = \left( x_n - y_n \right)^2, - Examples:: + where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: - >>> loss = nn.MSELoss() - >>> input = torch.randn(3, 5, requires_grad=True) - >>> target = torch.randn(3, 5) - >>> output = loss(input, target) - >>> output.backward() - """ - __constants__ = ['reduction'] + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The mean operation still operates over all the elements, and divides by :math:`n`. + + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + + Args: + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Target: :math:`(*)`, same shape as the input. 
+ + Examples:: + + >>> loss = nn.MSELoss() + >>> input = torch.randn(3, 5, requires_grad=True) + >>> target = torch.randn(3, 5) + >>> output = loss(input, target) + >>> output.backward() + """ + __constants__ = ['reduction'] - def __init__(self, reduction: str = 'mean') -> None: - super().__init__(reduction=reduction) + def __init__(self, reduction: str = 'mean') -> None: + super().__init__(reduction=reduction) - def update(self, input: ArrayType, target: ArrayType) -> ArrayType: - return mean_squared_error(input, target, reduction=self.reduction) + def update(self, input: ArrayType, target: ArrayType) -> ArrayType: + return mean_squared_error(input, target, reduction=self.reduction) def mean_squared_error(predicts, targets, axis=None, reduction: str = 'mean'): - r"""Computes the mean squared error between x and y. + r"""Computes the mean squared error between x and y. - Args: - predicts: a tensor of shape (d0, .. dN-1). - targets: a tensor of shape (d0, .. dN-1). - axis: a sequence of the dimensions to keep, use `None` to return a scalar value. + Args: + predicts: a tensor of shape (d0, .. dN-1). + targets: a tensor of shape (d0, .. dN-1). + axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. - """ - r = tree_map(lambda a, b: _reduce((a - b) ** 2, reduction, axis=axis), - predicts, - targets, - is_leaf=_is_leaf) - return _multi_return(r) + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. + """ + r = tree_map(lambda a, b: _reduce((a - b) ** 2, reduction, axis=axis), + predicts, + targets, + is_leaf=_is_leaf) + return _multi_return(r) def mean_squared_log_error(predicts, targets, axis=None, reduction: str = 'mean'): - r"""Computes the mean squared logarithmic error between y_true and y_pred. + r"""Computes the mean squared logarithmic error between y_true and y_pred. 
- Args: - targets: a tensor of shape (d0, .. dN-1). - predicts: a tensor of shape (d0, .. dN-1). - keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. + Args: + targets: a tensor of shape (d0, .. dN-1). + predicts: a tensor of shape (d0, .. dN-1). + keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. - """ - r = tree_map(lambda a, b: _reduce((jnp.log1p(a) - jnp.log1p(b)) ** 2, reduction, axis=axis), - predicts, - targets, - is_leaf=_is_leaf) - return _multi_return(r) + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. + """ + r = tree_map(lambda a, b: _reduce((jnp.log1p(a) - jnp.log1p(b)) ** 2, reduction, axis=axis), + predicts, + targets, + is_leaf=_is_leaf) + return _multi_return(r) def huber_loss(predicts, targets, delta: float = 1.0): - r"""Huber loss. + r"""Huber loss. - Huber loss is similar to L2 loss close to zero, L1 loss away from zero. - If gradient descent is applied to the `huber loss`, it is equivalent to - clipping gradients of an `l2_loss` to `[-delta, delta]` in the backward pass. + Huber loss is similar to L2 loss close to zero, L1 loss away from zero. + If gradient descent is applied to the `huber loss`, it is equivalent to + clipping gradients of an `l2_loss` to `[-delta, delta]` in the backward pass. - Parameters:: + Parameters:: - predicts: ArrayType - predictions - targets: ArrayType - ground truth - delta: float - radius of quadratic behavior + predicts: ArrayType + predictions + targets: ArrayType + ground truth + delta: float + radius of quadratic behavior - Returns:: + Returns:: - loss : float - The loss value. + loss : float + The loss value. - References:: + References:: - .. [1] https://en.wikipedia.org/wiki/Huber_loss - """ + .. 
[1] https://en.wikipedia.org/wiki/Huber_loss + """ - def _loss(y_predict, y_target): - # 0.5 * err^2 if |err| <= d - # 0.5 * d^2 + d * (|err| - d) if |err| > d - diff = bm.abs(y_predict - y_target) - r = bm.where(diff > delta, - delta * (diff - .5 * delta), - 0.5 * diff ** 2) - return bm.as_jax(r) + def _loss(y_predict, y_target): + # 0.5 * err^2 if |err| <= d + # 0.5 * d^2 + d * (|err| - d) if |err| > d + diff = bm.abs(y_predict - y_target) + r = bm.where(diff > delta, + delta * (diff - .5 * delta), + 0.5 * diff ** 2) + return bm.as_jax(r) - r = tree_map(_loss, targets, predicts, is_leaf=_is_leaf) - return _multi_return(r) + r = tree_map(_loss, targets, predicts, is_leaf=_is_leaf) + return _multi_return(r) def binary_logistic_loss(predicts: float, targets: int, ) -> float: - """Binary logistic loss. + """Binary logistic loss. - Args: - targets: ground-truth integer label (0 or 1). - predicts: score produced by the model (float). + Args: + targets: ground-truth integer label (0 or 1). + predicts: score produced by the model (float). - Returns: - loss value - """ - # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1]. - # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba), - # where xlogx(proba) = proba * log(proba). - r = tree_map(lambda a, b: bm.softplus(a) - b * a, - predicts, - targets, - is_leaf=_is_leaf) - return _multi_return(r) + Returns: + loss value + """ + # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1]. + # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba), + # where xlogx(proba) = proba * log(proba). + r = tree_map(lambda a, b: bm.softplus(a) - b * a, + predicts, + targets, + is_leaf=_is_leaf) + return _multi_return(r) def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float: - """Multiclass logistic loss. + """Multiclass logistic loss. - Args: - label: ground-truth integer label, between 0 and n_classes - 1. - logits: scores produced by the model, shape = (n_classes, ). 
+ Args: + label: ground-truth integer label, between 0 and n_classes - 1. + logits: scores produced by the model, shape = (n_classes, ). - Returns: - loss value - """ + Returns: + loss value + """ - def loss(pred, tar): - pred = bm.as_jax(pred) - one_hot = bm.one_hot(tar, pred.shape[0]) - return logsumexp(pred) - jnp.dot(pred, one_hot) + def loss(pred, tar): + pred = bm.as_jax(pred) + one_hot = bm.one_hot(tar, pred.shape[0]) + return logsumexp(pred) - jnp.dot(pred, one_hot) - r = tree_map(loss, logits, label, is_leaf=_is_leaf) - return _multi_return(r) + r = tree_map(loss, logits, label, is_leaf=_is_leaf) + return _multi_return(r) def sigmoid_binary_cross_entropy(logits, labels): - """Computes sigmoid cross entropy given logits and multiple class labels. - Measures the probability error in discrete classification tasks in which - each class is an independent binary prediction and different classes are - not mutually exclusive. This may be used for multilabel image classification - for instance a model may predict that an image contains both a cat and a dog. + """Computes sigmoid cross entropy given logits and multiple class labels. + Measures the probability error in discrete classification tasks in which + each class is an independent binary prediction and different classes are + not mutually exclusive. This may be used for multilabel image classification + for instance a model may predict that an image contains both a cat and a dog. - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + References: + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - Args: - logits: unnormalized log probabilities. - labels: the probability for that class. + Args: + logits: unnormalized log probabilities. + labels: the probability for that class. - Returns: - a sigmoid cross entropy loss. - """ + Returns: + a sigmoid cross entropy loss. 
+ """ - def loss(pred, tar): - log_p = bm.log_sigmoid(pred) - # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable - log_not_p = bm.log_sigmoid(-pred) - return -tar * log_p - (1. - tar) * log_not_p + def loss(pred, tar): + log_p = bm.log_sigmoid(pred) + # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable + log_not_p = bm.log_sigmoid(-pred) + return -tar * log_p - (1. - tar) * log_not_p - r = tree_map(loss, logits, labels, is_leaf=_is_leaf) - return _multi_return(r) + r = tree_map(loss, logits, labels, is_leaf=_is_leaf) + return _multi_return(r) def softmax_cross_entropy(logits, labels): - """Computes the softmax cross entropy between sets of logits and labels. - Measures the probability error in discrete classification tasks in which - the classes are mutually exclusive (each entry is in exactly one class). - For example, each CIFAR-10 image is labeled with one and only one label: - an image can be a dog or a truck, but not both. - - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - - Args: - logits: unnormalized log probabilities. - labels: a valid probability distribution (non-negative, sum to 1), e.g a - one hot encoding of which class is the correct one for each input. - - Returns: - the cross entropy loss. - """ - r = tree_map(lambda pred, tar: -jnp.sum(tar * bm.log_softmax(pred, axis=-1), axis=-1), - logits, - labels, - is_leaf=_is_leaf) - return _multi_return(r) + """Computes the softmax cross entropy between sets of logits and labels. + Measures the probability error in discrete classification tasks in which + the classes are mutually exclusive (each entry is in exactly one class). + For example, each CIFAR-10 image is labeled with one and only one label: + an image can be a dog or a truck, but not both. + + References: + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + + Args: + logits: unnormalized log probabilities. 
+ labels: a valid probability distribution (non-negative, sum to 1), e.g a + one hot encoding of which class is the correct one for each input. + + Returns: + the cross entropy loss. + """ + r = tree_map(lambda pred, tar: -jnp.sum(tar * bm.log_softmax(pred, axis=-1), axis=-1), + logits, + labels, + is_leaf=_is_leaf) + return _multi_return(r) def log_cosh_loss(predicts, targets): - r"""Calculates the log-cosh loss for a set of predictions. + r"""Calculates the log-cosh loss for a set of predictions. - log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` - for large x. It is a twice differentiable alternative to the Huber loss. + log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` + for large x. It is a twice differentiable alternative to the Huber loss. - References: - [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) + References: + [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) - Args: - predicts: a vector of arbitrary shape. - targets: a vector of shape compatible with predictions; if not provided - then it is assumed to be zero. + Args: + predicts: a vector of arbitrary shape. + targets: a vector of shape compatible with predictions; if not provided + then it is assumed to be zero. - Returns: - the log-cosh loss. - """ + Returns: + the log-cosh loss. 
+ """ - def loss(pred, tar): - errors = bm.as_jax(pred - tar) - return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) + def loss(pred, tar): + errors = bm.as_jax(pred - tar) + return jnp.logaddexp(errors, -errors) - jnp.log(2.0).astype(errors.dtype) - r = tree_map(loss, predicts, targets, is_leaf=_is_leaf) - return _multi_return(r) + r = tree_map(loss, predicts, targets, is_leaf=_is_leaf) + return _multi_return(r) def ctc_loss_with_forward_probs( @@ -891,128 +888,128 @@ def ctc_loss_with_forward_probs( blank_id: int = 0, log_epsilon: float = -1e5 ) -> Tuple[ArrayType, ArrayType, ArrayType]: - r"""Computes CTC loss and CTC forward-probabilities. - - The CTC loss is a loss function based on log-likelihoods of the model that - introduces a special blank symbol :math:`\phi` to represent variable-length - output sequences. - - Forward probabilities returned by this function, as auxiliary results, are - grouped into two part: blank alpha-probability and non-blank alpha - probability. Those are defined as follows: - - .. math:: - \alpha_{\mathrm{BLANK}}(t, n) = - \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ - \alpha_{\mathrm{LABEL}}(t, n) = - \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). - - Here, :math:`\pi` denotes the alignment sequence in the reference - [Graves et al, 2006] that is blank-inserted representations of ``labels``. - The return values are the logarithms of the above probabilities. - - References: - [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891) - - Args: - logits: (B, T, K)-array containing logits of each class where B denotes - the batch size, T denotes the max time frames in ``logits``, and K - denotes the number of classes including a class for blanks. - logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each - element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` - denotes that ``logits[b, t, :]`` are padded values. 
- labels: (B, N)-array containing reference integer labels where N denotes - the max time frames in the label sequence. - label_paddings: (B, N)-array. Padding indicators for ``labels``. Each - element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` - denotes that ``labels[b, n]`` is a padded label. In the current - implementation, ``labels`` must be right-padded, i.e. each row - ``labelpaddings[b, :]`` must be repetition of zeroes, followed by - repetition of ones. - blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as - probabilities of blank symbols. - log_epsilon: Numerically-stable approximation of log(+0). - - Returns: - A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here, - ``loss_value`` is a (B,)-array containing the loss values for each sequence - in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are - (T, B, N+1)-arrays where the (t, b, n)-th element denotes - \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th - sequence in the batch. - """ - assert logits.ndim == 3 - assert labels.ndim == 2 - batchsize, unused_maxinputlen, num_classes = logits.shape - batchsize_of_labels, maxlabellen = labels.shape - assert (batchsize == batchsize_of_labels) - assert (labels.shape == label_paddings.shape) - assert (logits.shape[:2] == logit_paddings.shape) - - logits = logits.value if isinstance(logits, bm.BaseArray) else logits - logit_paddings = logit_paddings.value if isinstance(logit_paddings, bm.BaseArray) else logit_paddings - labels = labels.value if isinstance(labels, bm.BaseArray) else labels - label_paddings = label_paddings.value if isinstance(label_paddings, bm.BaseArray) else label_paddings - - logprobs = bm.log_softmax(logits).value - labellens = maxlabellen - jnp.sum(label_paddings, axis=1).astype(jnp.int32) - - # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
- repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) - repeat = jnp.pad(repeat, ((0, 0), (0, 1))) - - logprobs_phi = logprobs[:, :, blank_id:blank_id + 1] # [B, T, 1] - logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] - - one_hot = bm.one_hot(labels, num_classes=num_classes) # [B, N, K] - logprobs_emit = jnp.einsum('btk,bnk->btn', logprobs, one_hot) - logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] - - logalpha_phi_init = jnp.ones( - (batchsize, maxlabellen + 1)) * log_epsilon # [B, N] - logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) - logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon - - def update_phi_score(phi, added_score): - # Update `phi[:, 1:]`` with adding `added_score` in log space. - return jnp.concatenate( - [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1) - - def loop_body(prev, x): - prev_phi, prev_emit = prev - # emit-to-phi epsilon transition, except if the next label is repetition - prev_phi_orig = prev_phi - prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) - - logprob_emit, logprob_phi, pad = x - - # phi-to-emit transition - next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, - prev_emit + logprob_emit) - # self-loop transition - next_phi = prev_phi + logprob_phi - # emit-to-phi blank transition only when the next label is repetition - next_phi = update_phi_score( - next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) - - pad = pad.reshape((batchsize, 1)) - next_emit = pad * prev_emit + (1.0 - pad) * next_emit - next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi - - return (next_phi, next_emit), (next_phi, next_emit) - - xs = (logprobs_emit, logprobs_phi, logit_paddings.transpose((1, 0))) - _, (logalpha_phi, logalpha_emit) = scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) - - # last row needs to be updated with the last epsilon transition - logalpha_phi_last = update_phi_score(logalpha_phi[-1], 
logalpha_emit[-1]) - logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) - - # extract per_seq_loss - one_hot = bm.one_hot(labellens, num_classes=maxlabellen + 1).value # [B, N+1] - per_seq_loss = -jnp.einsum('bn,bn->b', logalpha_phi_last, one_hot) - - return per_seq_loss, logalpha_phi, logalpha_emit + r"""Computes CTC loss and CTC forward-probabilities. + + The CTC loss is a loss function based on log-likelihoods of the model that + introduces a special blank symbol :math:`\phi` to represent variable-length + output sequences. + + Forward probabilities returned by this function, as auxiliary results, are + grouped into two part: blank alpha-probability and non-blank alpha + probability. Those are defined as follows: + + .. math:: + \alpha_{\mathrm{BLANK}}(t, n) = + \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ + \alpha_{\mathrm{LABEL}}(t, n) = + \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). + + Here, :math:`\pi` denotes the alignment sequence in the reference + [Graves et al, 2006] that is blank-inserted representations of ``labels``. + The return values are the logarithms of the above probabilities. + + References: + [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891) + + Args: + logits: (B, T, K)-array containing logits of each class where B denotes + the batch size, T denotes the max time frames in ``logits``, and K + denotes the number of classes including a class for blanks. + logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` + denotes that ``logits[b, t, :]`` are padded values. + labels: (B, N)-array containing reference integer labels where N denotes + the max time frames in the label sequence. + label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` + denotes that ``labels[b, n]`` is a padded label. 
In the current + implementation, ``labels`` must be right-padded, i.e. each row + ``labelpaddings[b, :]`` must be repetition of zeroes, followed by + repetition of ones. + blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as + probabilities of blank symbols. + log_epsilon: Numerically-stable approximation of log(+0). + + Returns: + A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here, + ``loss_value`` is a (B,)-array containing the loss values for each sequence + in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are + (T, B, N+1)-arrays where the (t, b, n)-th element denotes + \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th + sequence in the batch. + """ + assert logits.ndim == 3 + assert labels.ndim == 2 + batchsize, unused_maxinputlen, num_classes = logits.shape + batchsize_of_labels, maxlabellen = labels.shape + assert (batchsize == batchsize_of_labels) + assert (labels.shape == label_paddings.shape) + assert (logits.shape[:2] == logit_paddings.shape) + + logits = logits.value if isinstance(logits, bm.BaseArray) else logits + logit_paddings = logit_paddings.value if isinstance(logit_paddings, bm.BaseArray) else logit_paddings + labels = labels.value if isinstance(labels, bm.BaseArray) else labels + label_paddings = label_paddings.value if isinstance(label_paddings, bm.BaseArray) else label_paddings + + logprobs = bm.log_softmax(logits).value + labellens = maxlabellen - jnp.sum(label_paddings, axis=1).astype(jnp.int32) + + # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1]. 
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32) + repeat = jnp.pad(repeat, ((0, 0), (0, 1))) + + logprobs_phi = logprobs[:, :, blank_id:blank_id + 1] # [B, T, 1] + logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1] + + one_hot = bm.one_hot(labels, num_classes=num_classes) # [B, N, K] + logprobs_emit = jnp.einsum('btk,bnk->btn', logprobs, one_hot) + logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N] + + logalpha_phi_init = jnp.ones( + (batchsize, maxlabellen + 1)) * log_epsilon # [B, N] + logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0) + logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon + + def update_phi_score(phi, added_score): + # Update `phi[:, 1:]`` with adding `added_score` in log space. + return jnp.concatenate( + [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1) + + def loop_body(prev, x): + prev_phi, prev_emit = prev + # emit-to-phi epsilon transition, except if the next label is repetition + prev_phi_orig = prev_phi + prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat) + + logprob_emit, logprob_phi, pad = x + + # phi-to-emit transition + next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, + prev_emit + logprob_emit) + # self-loop transition + next_phi = prev_phi + logprob_phi + # emit-to-phi blank transition only when the next label is repetition + next_phi = update_phi_score( + next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)) + + pad = pad.reshape((batchsize, 1)) + next_emit = pad * prev_emit + (1.0 - pad) * next_emit + next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi + + return (next_phi, next_emit), (next_phi, next_emit) + + xs = (logprobs_emit, logprobs_phi, logit_paddings.transpose((1, 0))) + _, (logalpha_phi, logalpha_emit) = scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs) + + # last row needs to be updated with the last epsilon transition + logalpha_phi_last = update_phi_score(logalpha_phi[-1], 
logalpha_emit[-1]) + logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last) + + # extract per_seq_loss + one_hot = bm.one_hot(labellens, num_classes=maxlabellen + 1).value # [B, N+1] + per_seq_loss = -jnp.einsum('bn,bn->b', logalpha_phi_last, one_hot) + + return per_seq_loss, logalpha_phi, logalpha_emit def ctc_loss(logits: ArrayType, @@ -1021,77 +1018,77 @@ def ctc_loss(logits: ArrayType, label_paddings: ArrayType, blank_id: int = 0, log_epsilon: float = -1e5) -> ArrayType: - """Computes CTC loss. - - See docstring for ``ctc_loss_with_forward_probs`` for details. - - Args: - logits: (B, T, K)-array containing logits of each class where B denotes - the batch size, T denotes the max time frames in ``logits``, and K - denotes the number of classes including a class for blanks. - logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each - element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` - denotes that ``logits[b, t, :]`` are padded values. - labels: (B, N)-array containing reference integer labels where N denotes - the max time frames in the label sequence. - label_paddings: (B, N)-array. Padding indicators for ``labels``. Each - element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` - denotes that ``labels[b, n]`` is a padded label. In the current - implementation, ``labels`` must be right-padded, i.e. each row - ``labelpaddings[b, :]`` must be repetition of zeroes, followed by - repetition of ones. - blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as - probabilities of blank symbols. - log_epsilon: Numerically-stable approximation of log(+0). - - Returns: - (B,)-array containing loss values for each sequence in the batch. - """ - per_seq_loss, _, _ = ctc_loss_with_forward_probs( - logits, logit_paddings, labels, label_paddings, - blank_id=blank_id, log_epsilon=log_epsilon) - return per_seq_loss + """Computes CTC loss. + + See docstring for ``ctc_loss_with_forward_probs`` for details. 
+ + Args: + logits: (B, T, K)-array containing logits of each class where B denotes + the batch size, T denotes the max time frames in ``logits``, and K + denotes the number of classes including a class for blanks. + logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` + denotes that ``logits[b, t, :]`` are padded values. + labels: (B, N)-array containing reference integer labels where N denotes + the max time frames in the label sequence. + label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` + denotes that ``labels[b, n]`` is a padded label. In the current + implementation, ``labels`` must be right-padded, i.e. each row + ``labelpaddings[b, :]`` must be repetition of zeroes, followed by + repetition of ones. + blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as + probabilities of blank symbols. + log_epsilon: Numerically-stable approximation of log(+0). + + Returns: + (B,)-array containing loss values for each sequence in the batch. + """ + per_seq_loss, _, _ = ctc_loss_with_forward_probs( + logits, logit_paddings, labels, label_paddings, + blank_id=blank_id, log_epsilon=log_epsilon) + return per_seq_loss def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'): - r"""Computes multi-class margin loss, also called multi-class hinge loss. - - This loss function is often used in multi-class classification problems. - It is a type of hinge loss that tries to ensure the correct class score is greater than the scores of other classes by a margin. - - The loss function for sample :math:`i` is: - - .. 
math:: - \ell(x, y) = \sum_{j \neq y_i} \max(0, x_{y_j} - x_{y_i} + \text{margin}) - - where :math:`x` is the input, :math:`y` is the target, and :math:`y_i` is the index of the true class, - and :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` - and :math:`i \neq y`. - - Args: - predicts: :math:`(N, C)` where `C = number of classes`. - target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. - margin (float, optional): Has a default value of :math:`1`. - p (float, optional): Has a default value of :math:`1`. - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will - be applied, ``'mean'``: the sum of the output will be divided by the - number of elements in the output, ``'sum'``: the output will be summed. - Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, - and in the meantime, specifying either of those two args will override :attr:`reduction`. - Default: ``'mean'`` - - Returns: - a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`. - """ - assert p == 1 or p == 2, 'p should be 1 or 2' - batch_size = predicts.shape[0] - correct_scores = predicts[jnp.arange(batch_size), targets] - margins = jnp.power(jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p) - margins = margins.at[jnp.arange(batch_size), targets].set(0) - if reduction == 'mean': - return jnp.sum(margins) / batch_size - elif reduction == 'sum': - return jnp.sum(margins) - elif reduction == 'none': - return margins + r"""Computes multi-class margin loss, also called multi-class hinge loss. + + This loss function is often used in multi-class classification problems. + It is a type of hinge loss that tries to ensure the correct class score is greater than the scores of other classes by a margin. + + The loss function for sample :math:`i` is: + + .. 
math:: + \ell(x, y) = \sum_{j \neq y_i} \max(0, x_{y_j} - x_{y_i} + \text{margin}) + + where :math:`x` is the input, :math:`y` is the target, and :math:`y_i` is the index of the true class, + and :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` + and :math:`i \neq y`. + + Args: + predicts: :math:`(N, C)` where `C = number of classes`. + target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. + margin (float, optional): Has a default value of :math:`1`. + p (float, optional): Has a default value of :math:`1`. + reduction (str, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the sum of the output will be divided by the + number of elements in the output, ``'sum'``: the output will be summed. + Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, + and in the meantime, specifying either of those two args will override :attr:`reduction`. + Default: ``'mean'`` + + Returns: + a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`. 
+ """ + assert p == 1 or p == 2, 'p should be 1 or 2' + batch_size = predicts.shape[0] + correct_scores = predicts[jnp.arange(batch_size), targets] + margins = jnp.power(jnp.maximum(0, predicts - correct_scores[:, jnp.newaxis] + margin), p) + margins = margins.at[jnp.arange(batch_size), targets].set(0) + if reduction == 'mean': + return jnp.sum(margins) / batch_size + elif reduction == 'sum': + return jnp.sum(margins) + elif reduction == 'none': + return margins diff --git a/brainpy/_src/losses/regularization.py b/brainpy/_src/losses/regularization.py index 320a7d052..c0c53faa2 100644 --- a/brainpy/_src/losses/regularization.py +++ b/brainpy/_src/losses/regularization.py @@ -1,79 +1,79 @@ # -*- coding: utf-8 -*- +import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_map -import jax.numpy as jnp import brainpy.math as bm from .utils import _is_leaf, _multi_return __all__ = [ - 'l2_norm', - 'mean_absolute', - 'mean_square', - 'log_cosh', - 'smooth_labels', + 'l2_norm', + 'mean_absolute', + 'mean_square', + 'log_cosh', + 'smooth_labels', ] def l2_norm(x, axis=None): - """Computes the L2 loss. + """Computes the L2 loss. - Args: - x: n-dimensional tensor of floats. + Args: + x: n-dimensional tensor of floats. - Returns: - scalar tensor containing the l2 loss of x. - """ - leaves, _ = tree_flatten(x) - return jnp.sqrt(jnp.sum(jnp.asarray([jnp.vdot(x, x) for x in leaves]), axis=axis)) + Returns: + scalar tensor containing the l2 loss of x. + """ + leaves, _ = tree_flatten(x) + return jnp.sqrt(jnp.sum(jnp.asarray([jnp.vdot(x, x) for x in leaves]), axis=axis)) def mean_absolute(outputs, axis=None): - r"""Computes the mean absolute error between x and y. + r"""Computes the mean absolute error between x and y. - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. 
- """ - r = tree_map(lambda a: bm.mean(bm.abs(a), axis=axis), outputs, is_leaf=_is_leaf) - return _multi_return(r) + Returns: + tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. + """ + r = tree_map(lambda a: bm.mean(bm.abs(a), axis=axis), outputs, is_leaf=_is_leaf) + return _multi_return(r) def mean_square(predicts, axis=None): - r = tree_map(lambda a: bm.mean(a ** 2, axis=axis), predicts, is_leaf=_is_leaf) - return _multi_return(r) + r = tree_map(lambda a: bm.mean(a ** 2, axis=axis), predicts, is_leaf=_is_leaf) + return _multi_return(r) def log_cosh(errors): - r"""Calculates the log-cosh loss for a set of predictions. - - log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` - for large x. It is a twice differentiable alternative to the Huber loss. - References: - [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) - Args: - errors: a vector of arbitrary shape. - Returns: - the log-cosh loss. - """ - r = tree_map(lambda a: bm.logaddexp(a, -a) - bm.log(2.0).astype(a.dtype), - errors, is_leaf=_is_leaf) - return _multi_return(r) + r"""Calculates the log-cosh loss for a set of predictions. + + log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` + for large x. It is a twice differentiable alternative to the Huber loss. + References: + [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) + Args: + errors: a vector of arbitrary shape. + Returns: + the log-cosh loss. + """ + r = tree_map(lambda a: bm.logaddexp(a, -a) - bm.log(2.0).astype(a.dtype), + errors, is_leaf=_is_leaf) + return _multi_return(r) def smooth_labels(labels, alpha: float) -> jnp.ndarray: - r"""Apply label smoothing. - Label smoothing is often used in combination with a cross-entropy loss. - Smoothed labels favour small logit gaps, and it has been shown that this can - provide better model calibration by preventing overconfident predictions. 
- References: - [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) - Args: - labels: one hot labels to be smoothed. - alpha: the smoothing factor, the greedy category with be assigned - probability `(1-alpha) + alpha / num_categories` - Returns: - a smoothed version of the one hot input labels. - """ - r = tree_map(lambda tar: (1.0 - alpha) * tar + alpha / tar.shape[-1], - labels, is_leaf=lambda x: isinstance(x, bm.BaseArray)) - return _multi_return(r) + r"""Apply label smoothing. + Label smoothing is often used in combination with a cross-entropy loss. + Smoothed labels favour small logit gaps, and it has been shown that this can + provide better model calibration by preventing overconfident predictions. + References: + [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) + Args: + labels: one hot labels to be smoothed. + alpha: the smoothing factor, the greedy category with be assigned + probability `(1-alpha) + alpha / num_categories` + Returns: + a smoothed version of the one hot input labels. 
+ """ + r = tree_map(lambda tar: (1.0 - alpha) * tar + alpha / tar.shape[-1], + labels, is_leaf=lambda x: isinstance(x, bm.BaseArray)) + return _multi_return(r) diff --git a/brainpy/_src/losses/utils.py b/brainpy/_src/losses/utils.py index add8277da..dfb618c11 100644 --- a/brainpy/_src/losses/utils.py +++ b/brainpy/_src/losses/utils.py @@ -9,28 +9,28 @@ def _is_leaf(x): - return isinstance(x, bm.BaseArray) + return isinstance(x, bm.BaseArray) def _reduce(outputs, reduction, axis=None): - if reduction == 'mean': - return outputs.mean(axis) - elif reduction == 'sum': - return outputs.sum(axis) - elif reduction == 'none': - return outputs - else: - raise UnsupportedError(_reduction_error % reduction) + if reduction == 'mean': + return outputs.mean(axis) + elif reduction == 'sum': + return outputs.sum(axis) + elif reduction == 'none': + return outputs + else: + raise UnsupportedError(_reduction_error % reduction) def _multi_return(r): - if isinstance(r, jax.Array): - return r - elif isinstance(r, bm.BaseArray): - return r.value - else: - leaves = tree_flatten(r)[0] - r = leaves[0] - for leaf in leaves[1:]: - r += leaf - return r + if isinstance(r, jax.Array): + return r + elif isinstance(r, bm.BaseArray): + return r.value + else: + leaves = tree_flatten(r)[0] + r = leaves[0] + for leaf in leaves[1:]: + r += leaf + return r diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index a28ba7d84..72b9b8c5c 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -19,7 +19,6 @@ Details in the following. """ - # necessity to wrap the jax.numpy.ndarray: # 1. for parameters and variables which want to # modify in the JIT mode, this wrapper is necessary. @@ -30,31 +29,25 @@ # -# data structure -from .ndarray import * -from .delayvars import * - +from . import activations +from . import random, linalg, fft +from . import surrogate, event, sparse, jitconn # functions from .activations import * -from . 
import activations - +from .compat_numpy import * +from .compat_tensorflow import * # high-level numpy operations from .datatypes import * +from .delayvars import * +from .environment import * from .interoperability import * -from .compat_numpy import * -from .compat_tensorflow import * +# environment settings +from .modes import * +# data structure +from .ndarray import * +# Variable and Objects for object-oriented JAX transformations +from .object_transform import * from .others import * -from . import random, linalg, fft - # operators from .pre_syn_post import * -from . import surrogate, event, sparse, jitconn - -# Variable and Objects for object-oriented JAX transformations -from .object_transform import * - -# environment settings -from .modes import * -from .environment import * from .scales import * - diff --git a/brainpy/_src/math/_utils.py b/brainpy/_src/math/_utils.py index bc0ebb9dc..64e32ff68 100644 --- a/brainpy/_src/math/_utils.py +++ b/brainpy/_src/math/_utils.py @@ -5,56 +5,56 @@ from jax.tree_util import tree_map -from .ndarray import Array, _return, BaseArray +from .ndarray import _return, BaseArray def _as_jax_array_(obj): - return obj.value if isinstance(obj, BaseArray) else obj + return obj.value if isinstance(obj, BaseArray) else obj def _is_leaf(a): - return isinstance(a, BaseArray) + return isinstance(a, BaseArray) def _compatible_with_brainpy_array( fun: Callable, module: str = '' ): - func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun - - @functools.wraps(func_to_wrap) - def new_fun(*args, **kwargs): - args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) - out = None - if len(kwargs): - # compatible with PyTorch syntax - if 'dim' in kwargs: - kwargs['axis'] = kwargs.pop('dim') - if 'keepdim' in kwargs: - kwargs['keepdims'] = kwargs.pop('keepdim') - # compatible with TensorFlow syntax - if 'keep_dims' in kwargs: - kwargs['keepdims'] = kwargs.pop('keep_dims') - # compatible with NumPy/PyTorch syntax - if 'out' in 
kwargs: - out = kwargs.pop('out') - if not isinstance(out, BaseArray): - raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') - # format - kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) - r = fun(*args, **kwargs) - if out is None: - return tree_map(_return, r) - else: - out.value = r - - new_fun.__doc__ = ( - f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' - f'while it is compatible with brainpy Array/Variable. \n\n' - f'Note that this function is also compatible with:\n\n' - f'1. NumPy or PyTorch syntax when receiving ``out`` argument.\n' - f'2. PyTorch syntax when receiving ``keepdim`` or ``dim`` argument.\n' - f'3. TensorFlow syntax when receiving ``keep_dims`` argument.' - ) - - return new_fun + func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun + + @functools.wraps(func_to_wrap) + def new_fun(*args, **kwargs): + args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) + out = None + if len(kwargs): + # compatible with PyTorch syntax + if 'dim' in kwargs: + kwargs['axis'] = kwargs.pop('dim') + if 'keepdim' in kwargs: + kwargs['keepdims'] = kwargs.pop('keepdim') + # compatible with TensorFlow syntax + if 'keep_dims' in kwargs: + kwargs['keepdims'] = kwargs.pop('keep_dims') + # compatible with NumPy/PyTorch syntax + if 'out' in kwargs: + out = kwargs.pop('out') + if not isinstance(out, BaseArray): + raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') + # format + kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) + r = fun(*args, **kwargs) + if out is None: + return tree_map(_return, r) + else: + out.value = r + + new_fun.__doc__ = ( + f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' + f'while it is compatible with brainpy Array/Variable. \n\n' + f'Note that this function is also compatible with:\n\n' + f'1. NumPy or PyTorch syntax when receiving ``out`` argument.\n' + f'2. 
PyTorch syntax when receiving ``keepdim`` or ``dim`` argument.\n' + f'3. TensorFlow syntax when receiving ``keep_dims`` argument.' + ) + + return new_fun diff --git a/brainpy/_src/math/activations.py b/brainpy/_src/math/activations.py index af391c8e9..21755596c 100644 --- a/brainpy/_src/math/activations.py +++ b/brainpy/_src/math/activations.py @@ -22,474 +22,474 @@ from .random import uniform __all__ = [ - 'celu', - 'elu', - 'gelu', - 'glu', - 'prelu', - 'hard_tanh', - 'hard_sigmoid', - 'tanh_shrink', - 'hard_silu', - 'hard_swish', - 'hard_shrink', - 'leaky_relu', - 'log_sigmoid', - 'log_softmax', - 'one_hot', - 'normalize', - 'relu', - 'relu6', - 'rrelu', - 'sigmoid', - 'soft_sign', - 'softmax', - 'softmin', - 'softplus', - 'soft_shrink', - 'silu', - 'mish', - 'swish', - 'selu', - 'identity', + 'celu', + 'elu', + 'gelu', + 'glu', + 'prelu', + 'hard_tanh', + 'hard_sigmoid', + 'tanh_shrink', + 'hard_silu', + 'hard_swish', + 'hard_shrink', + 'leaky_relu', + 'log_sigmoid', + 'log_softmax', + 'one_hot', + 'normalize', + 'relu', + 'relu6', + 'rrelu', + 'sigmoid', + 'soft_sign', + 'softmax', + 'softmin', + 'softplus', + 'soft_shrink', + 'silu', + 'mish', + 'swish', + 'selu', + 'identity', ] def get(activation): - global_vars = globals() + global_vars = globals() - if activation is None: - return None + if activation is None: + return None - if isinstance(activation, str): - if activation not in global_vars: - raise ValueError(f'Unknown activation function: {activation}, \nwe only support: ' - f'{[k for k, v in global_vars.items() if not k.startswith("_") and callable(v)]}') - elif callable(activation): - return activation - else: - raise ValueError(f'Unknown activation function {activation}. 
') - return global_vars[activation] + if isinstance(activation, str): + if activation not in global_vars: + raise ValueError(f'Unknown activation function: {activation}, \nwe only support: ' + f'{[k for k, v in global_vars.items() if not k.startswith("_") and callable(v)]}') + elif callable(activation): + return activation + else: + raise ValueError(f'Unknown activation function {activation}. ') + return global_vars[activation] def identity(x): - return x.value if isinstance(x, BaseArray) else x + return x.value if isinstance(x, BaseArray) else x def celu(x, alpha=1.0): - r"""Continuously-differentiable exponential linear unit activation. + r"""Continuously-differentiable exponential linear unit activation. - Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{celu}(x) = \begin{cases} - x, & x > 0\\ - \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 - \end{cases} + .. math:: + \mathrm{celu}(x) = \begin{cases} + x, & x > 0\\ + \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 + \end{cases} - For more information, see - `Continuously Differentiable Exponential Linear Units - `_. + For more information, see + `Continuously Differentiable Exponential Linear Units + `_. - Parameters:: + Parameters:: - x : ArrayType - The input array. - alpha : ndarray, float - The default is 1.0. - """ - x = x.value if isinstance(x, BaseArray) else x - alpha = alpha.value if isinstance(alpha, BaseArray) else alpha - return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha)) + x : ArrayType + The input array. + alpha : ndarray, float + The default is 1.0. + """ + x = x.value if isinstance(x, BaseArray) else x + alpha = alpha.value if isinstance(alpha, BaseArray) else alpha + return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha)) def elu(x, alpha=1.0): - r"""Exponential linear unit activation function. + r"""Exponential linear unit activation function. 
- Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{elu}(x) = \begin{cases} - x, & x > 0\\ - \alpha \left(\exp(x) - 1\right), & x \le 0 - \end{cases} + .. math:: + \mathrm{elu}(x) = \begin{cases} + x, & x > 0\\ + \alpha \left(\exp(x) - 1\right), & x \le 0 + \end{cases} - Parameters:: + Parameters:: - x: JaxArray, jnp.ndarray - The input array. - alpha : scalar or Array - default: 1.0. - """ - x = x.value if isinstance(x, BaseArray) else x - alpha = alpha.value if isinstance(alpha, BaseArray) else alpha - safe_x = jnp.where(x > 0, 0., x) - return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x)) + x: JaxArray, jnp.ndarray + The input array. + alpha : scalar or Array + default: 1.0. + """ + x = x.value if isinstance(x, BaseArray) else x + alpha = alpha.value if isinstance(alpha, BaseArray) else alpha + safe_x = jnp.where(x > 0, 0., x) + return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x)) def gelu(x, approximate=True): - r"""Gaussian error linear unit activation function. + r"""Gaussian error linear unit activation function. - If ``approximate=False``, computes the element-wise function: + If ``approximate=False``, computes the element-wise function: - .. math:: - \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( - \frac{x}{\sqrt{2}} \right) \right) + .. math:: + \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( + \frac{x}{\sqrt{2}} \right) \right) - If ``approximate=True``, uses the approximate formulation of GELU: + If ``approximate=True``, uses the approximate formulation of GELU: - .. math:: - \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( - \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right) + .. math:: + \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( + \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right) - For more information, see `Gaussian Error Linear Units (GELUs) - `_, section 2. 
+ For more information, see `Gaussian Error Linear Units (GELUs) + `_, section 2. - Parameters:: + Parameters:: - x: ArrayType - The input array. - approximate: bool - whether to use the approximate or exact formulation. - """ - x = x.value if isinstance(x, BaseArray) else x - if approximate: - sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) - cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3)))) - y = x * cdf - else: - y = jnp.array(x * (jax.lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype) - return y + x: ArrayType + The input array. + approximate: bool + whether to use the approximate or exact formulation. + """ + x = x.value if isinstance(x, BaseArray) else x + if approximate: + sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) + cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3)))) + y = x * cdf + else: + y = jnp.array(x * (jax.lax.erf(x / np.sqrt(2)) + 1) / 2, dtype=x.dtype) + return y def glu(x, axis=-1): - r"""Gated linear unit activation function. + r"""Gated linear unit activation function. - Parameters:: + Parameters:: - x: ArrayType - The input array. - axis: int - The axis along which the split should be computed (default: -1) - """ - size = x.shape[axis] - assert size % 2 == 0, "axis size must be divisible by 2" - x = x.value if isinstance(x, BaseArray) else x - x1, x2 = jnp.split(x, 2, axis) - return x1 * sigmoid(x2) + x: ArrayType + The input array. + axis: int + The axis along which the split should be computed (default: -1) + """ + size = x.shape[axis] + assert size % 2 == 0, "axis size must be divisible by 2" + x = x.value if isinstance(x, BaseArray) else x + x1, x2 = jnp.split(x, 2, axis) + return x1 * sigmoid(x2) def hard_tanh(x, min_val=- 1.0, max_val=1.0): - r"""Hard :math:`\mathrm{tanh}` activation function. + r"""Hard :math:`\mathrm{tanh}` activation function. - Computes the element-wise function: + Computes the element-wise function: - .. 
math:: - \mathrm{hard\_tanh}(x) = \begin{cases} - -1, & x < -1\\ - x, & -1 \le x \le 1\\ - 1, & 1 < x - \end{cases} + .. math:: + \mathrm{hard\_tanh}(x) = \begin{cases} + -1, & x < -1\\ + x, & -1 \le x \le 1\\ + 1, & 1 < x + \end{cases} - Parameters:: + Parameters:: - x: ArrayType - The input array. - min_val: float - minimum value of the linear region range. Default: -1 - max_val: float - maximum value of the linear region range. Default: 1 - """ - x = x.value if isinstance(x, BaseArray) else x - return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x)) + x: ArrayType + The input array. + min_val: float + minimum value of the linear region range. Default: -1 + max_val: float + maximum value of the linear region range. Default: 1 + """ + x = x.value if isinstance(x, BaseArray) else x + return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x)) def hard_sigmoid(x): - r"""Hard Sigmoid activation function. + r"""Hard Sigmoid activation function. - Computes the element-wise function + Computes the element-wise function - .. math:: - \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6} + .. math:: + \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6} - Parameters:: + Parameters:: - x: ArrayType - The input array. - """ - return relu6(x + 3.) / 6. + x: ArrayType + The input array. + """ + return relu6(x + 3.) / 6. def tanh_shrink(x): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{Tanhshrink}(x) = x - \tanh(x) - """ - x = x.value if isinstance(x, BaseArray) else x - return x - jnp.tanh(x) + .. math:: + \text{Tanhshrink}(x) = x - \tanh(x) + """ + x = x.value if isinstance(x, BaseArray) else x + return x - jnp.tanh(x) def hard_silu(x): - r"""Hard SiLU activation function + r"""Hard SiLU activation function - Computes the element-wise function + Computes the element-wise function - .. math:: - \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x) + .. 
math:: + \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x) - Parameters:: + Parameters:: - x: ArrayType - The input array. - """ - return x * hard_sigmoid(x) + x: ArrayType + The input array. + """ + return x * hard_sigmoid(x) hard_swish = hard_silu def hard_shrink(x, lambd=0.5): - r"""Applies the Hard Shrinkage (Hardshrink) function element-wise. + r"""Applies the Hard Shrinkage (Hardshrink) function element-wise. - Hardshrink is defined as: + Hardshrink is defined as: - .. math:: - \text{HardShrink}(x) = - \begin{cases} - x, & \text{ if } x > \lambda \\ - x, & \text{ if } x < -\lambda \\ - 0, & \text{ otherwise } - \end{cases} + .. math:: + \text{HardShrink}(x) = + \begin{cases} + x, & \text{ if } x > \lambda \\ + x, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} - Args: - lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + Args: + lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - """ - x = x.value if isinstance(x, BaseArray) else x - return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.)) + """ + x = x.value if isinstance(x, BaseArray) else x + return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.)) def leaky_relu(x, negative_slope=1e-2): - r"""Leaky rectified linear unit activation function. + r"""Leaky rectified linear unit activation function. - Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{leaky\_relu}(x) = \begin{cases} - x, & x \ge 0\\ - \alpha x, & x < 0 - \end{cases} + .. math:: + \mathrm{leaky\_relu}(x) = \begin{cases} + x, & x \ge 0\\ + \alpha x, & x < 0 + \end{cases} - where :math:`\alpha` = :code:`negative_slope`. 
+ where :math:`\alpha` = :code:`negative_slope`. - Parameters:: + Parameters:: - x: ArrayType - The input array. - negative_slope : float - The scalar specifying the negative slope (default: 0.01) - """ - x = x.value if isinstance(x, BaseArray) else x - return jnp.where(x >= 0, x, negative_slope * x) + x: ArrayType + The input array. + negative_slope : float + The scalar specifying the negative slope (default: 0.01) + """ + x = x.value if isinstance(x, BaseArray) else x + return jnp.where(x >= 0, x, negative_slope * x) def softplus(x, beta: float = 1., threshold: float = 20.): - r"""Softplus activation function. + r"""Softplus activation function. - Computes the element-wise function + Computes the element-wise function - .. math:: - \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) + .. math:: + \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) - SoftPlus is a smooth approximation to the ReLU function and can be used - to constrain the output of a machine to always be positive. + SoftPlus is a smooth approximation to the ReLU function and can be used + to constrain the output of a machine to always be positive. - For numerical stability the implementation reverts to the linear function - when :math:`input \times \beta > threshold`. + For numerical stability the implementation reverts to the linear function + when :math:`input \times \beta > threshold`. - Parameters:: + Parameters:: - x: The input array. - beta: the :math:`\beta` value for the Softplus formulation. Default: 1. - threshold: values above this revert to a linear function. Default: 20. + x: The input array. + beta: the :math:`\beta` value for the Softplus formulation. Default: 1. + threshold: values above this revert to a linear function. Default: 20. 
- """ - x = x.value if isinstance(x, BaseArray) else x - return jnp.where(x > threshold / beta, x, 1 / beta * jnp.logaddexp(beta * x, 0)) + """ + x = x.value if isinstance(x, BaseArray) else x + return jnp.where(x > threshold / beta, x, 1 / beta * jnp.logaddexp(beta * x, 0)) def log_sigmoid(x): - r"""Log-sigmoid activation function. + r"""Log-sigmoid activation function. - Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x}) + .. math:: + \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x}) - Parameters:: + Parameters:: - x: ArrayType - The input array. - """ - return -softplus(-x) + x: ArrayType + The input array. + """ + return -softplus(-x) def soft_shrink(x, lambd=0.5): - r"""Applies the soft shrinkage function elementwise: - - .. math:: - \text{SoftShrinkage}(x) = - \begin{cases} - x - \lambda, & \text{ if } x > \lambda \\ - x + \lambda, & \text{ if } x < -\lambda \\ - 0, & \text{ otherwise } - \end{cases} + r"""Applies the soft shrinkage function elementwise: + + .. math:: + \text{SoftShrinkage}(x) = + \begin{cases} + x - \lambda, & \text{ if } x > \lambda \\ + x + \lambda, & \text{ if } x < -\lambda \\ + 0, & \text{ otherwise } + \end{cases} - Args: - lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + Args: + lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. - """ - x = x.value if isinstance(x, BaseArray) else x - return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.)) + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. 
+ """ + x = x.value if isinstance(x, BaseArray) else x + return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.)) def log_softmax(x, axis=-1): - r"""Log-Softmax function. + r"""Log-Softmax function. - Computes the logarithm of the :code:`softmax` function, which rescales - elements to the range :math:`[-\infty, 0)`. + Computes the logarithm of the :code:`softmax` function, which rescales + elements to the range :math:`[-\infty, 0)`. - .. math :: - \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} - \right) + .. math :: + \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} + \right) - Parameters:: + Parameters:: - x: ArrayType - The input array. - axis: int - The axis or axes along which the :code:`log_softmax` should be - computed. Either an integer or a tuple of integers. - """ - x = x.value if isinstance(x, BaseArray) else x - shifted = x - jax.lax.stop_gradient(x.max(axis, keepdims=True)) - return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis, keepdims=True)) - # exp = jnp.exp(x) - # return jnp.log(exp / exp.sum(axis=axis, keepdims=True)) + x: ArrayType + The input array. + axis: int + The axis or axes along which the :code:`log_softmax` should be + computed. Either an integer or a tuple of integers. 
+ """ + x = x.value if isinstance(x, BaseArray) else x + shifted = x - jax.lax.stop_gradient(x.max(axis, keepdims=True)) + return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis, keepdims=True)) + # exp = jnp.exp(x) + # return jnp.log(exp / exp.sum(axis=axis, keepdims=True)) def _canonicalize_axis(axis, num_dims) -> int: - """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" - axis = operator.index(axis) - if not -num_dims <= axis < num_dims: - raise ValueError( - "axis {} is out of bounds for array of dimension {}".format( - axis, num_dims)) - if axis < 0: - axis = axis + num_dims - return axis + """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" + axis = operator.index(axis) + if not -num_dims <= axis < num_dims: + raise ValueError( + "axis {} is out of bounds for array of dimension {}".format( + axis, num_dims)) + if axis < 0: + axis = axis + num_dims + return axis def one_hot(x, num_classes, *, dtype=None, axis=-1): - r"""One-hot encodes the given indicies. - - Each index in the input ``x`` is encoded as a vector of zeros of length - ``num_classes`` with the element at ``index`` set to one:: - - >>> import jax.numpy as jnp - >>> one_hot(jnp.array([0, 1, 2]), 3) - Array([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]], dtype=float32) - - Indicies outside the range [0, num_classes) will be encoded as zeros:: - - >>> import jax.numpy as jnp - >>> one_hot(jnp.array([-1, 3]), 3) - Array([[0., 0., 0.], - [0., 0., 0.]], dtype=float32) - - Args: - x: A tensor of indices. - num_classes: Number of classes in the one-hot dimension. - dtype: optional, a float dtype for the returned values (default float64 if - jax_enable_x64 is true, otherwise float32). - axis: the axis or axes along which the function should be - computed. 
- """ - num_classes = jax.core.concrete_or_error( - int, num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") - dtype = jax.dtypes.canonicalize_dtype(jnp.float64 if dtype is None else dtype) - x = jnp.asarray(x.value if isinstance(x, BaseArray) else x) - try: - output_pos_axis = _canonicalize_axis(axis, x.ndim + 1) - except TypeError: - axis_size = jax.lax.psum(1, axis) - if num_classes != axis_size: - raise ValueError(f"Expected num_classes to match the size of axis {axis}, " - f"but {num_classes} != {axis_size}") from None - axis_idx = jax.lax.axis_index(axis) - return jnp.asarray(x == axis_idx, dtype=dtype) - axis = operator.index(axis) - lhs = jax.lax.expand_dims(x, (axis,)) - rhs_shape = [1] * x.ndim - rhs_shape.insert(output_pos_axis, num_classes) - rhs = jax.lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype), - rhs_shape, - (output_pos_axis,)) - return jnp.asarray(lhs == rhs, dtype=dtype) + r"""One-hot encodes the given indicies. + + Each index in the input ``x`` is encoded as a vector of zeros of length + ``num_classes`` with the element at ``index`` set to one:: + + >>> import jax.numpy as jnp + >>> one_hot(jnp.array([0, 1, 2]), 3) + Array([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], dtype=float32) + + Indicies outside the range [0, num_classes) will be encoded as zeros:: + + >>> import jax.numpy as jnp + >>> one_hot(jnp.array([-1, 3]), 3) + Array([[0., 0., 0.], + [0., 0., 0.]], dtype=float32) + + Args: + x: A tensor of indices. + num_classes: Number of classes in the one-hot dimension. + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + axis: the axis or axes along which the function should be + computed. 
+ """ + num_classes = jax.core.concrete_or_error( + int, num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") + dtype = jax.dtypes.canonicalize_dtype(jnp.float64 if dtype is None else dtype) + x = jnp.asarray(x.value if isinstance(x, BaseArray) else x) + try: + output_pos_axis = _canonicalize_axis(axis, x.ndim + 1) + except TypeError: + axis_size = jax.lax.psum(1, axis) + if num_classes != axis_size: + raise ValueError(f"Expected num_classes to match the size of axis {axis}, " + f"but {num_classes} != {axis_size}") from None + axis_idx = jax.lax.axis_index(axis) + return jnp.asarray(x == axis_idx, dtype=dtype) + axis = operator.index(axis) + lhs = jax.lax.expand_dims(x, (axis,)) + rhs_shape = [1] * x.ndim + rhs_shape.insert(output_pos_axis, num_classes) + rhs = jax.lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype), + rhs_shape, + (output_pos_axis,)) + return jnp.asarray(lhs == rhs, dtype=dtype) def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5): - """Normalizes an array by subtracting mean and dividing by sqrt(var).""" - x = x.value if isinstance(x, BaseArray) else x - if mean is None: - mean = jnp.mean(x, axis, keepdims=True) - if variance is None: - # this definition is traditionally seen as less accurate than jnp.var's - # mean((x - mean(x))**2) but may be faster and even, given typical - # activation distributions and low-precision arithmetic, more accurate - # when used in neural network normalization layers - variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean) - y = (x - mean) * jax.lax.rsqrt(variance + epsilon) - return y + """Normalizes an array by subtracting mean and dividing by sqrt(var).""" + x = x.value if isinstance(x, BaseArray) else x + if mean is None: + mean = jnp.mean(x, axis, keepdims=True) + if variance is None: + # this definition is traditionally seen as less accurate than jnp.var's + # mean((x - mean(x))**2) but may be faster and even, given typical + # activation 
distributions and low-precision arithmetic, more accurate + # when used in neural network normalization layers + variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean) + y = (x - mean) * jax.lax.rsqrt(variance + epsilon) + return y def relu(x): - x = x.value if isinstance(x, BaseArray) else x - return _relu(x) + x = x.value if isinstance(x, BaseArray) else x + return _relu(x) @jax.custom_jvp def _relu(x: BaseArray) -> Array: - r"""Rectified linear unit activation function. + r"""Rectified linear unit activation function. - Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{relu}(x) = \max(x, 0) + .. math:: + \mathrm{relu}(x) = \max(x, 0) - except under differentiation, we take: + except under differentiation, we take: - .. math:: - \nabla \mathrm{relu}(0) = 0 + .. math:: + \nabla \mathrm{relu}(0) = 0 - For more information see - `Numerical influence of ReLU’(0) on backpropagation - `_. + For more information see + `Numerical influence of ReLU’(0) on backpropagation + `_. - Args: - x : input array - """ - return jnp.maximum(x, 0) + Args: + x : input array + """ + return jnp.maximum(x, 0) # For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm @@ -497,229 +497,229 @@ def _relu(x: BaseArray) -> Array: def relu6(x): - r"""Rectified Linear Unit 6 activation function. + r"""Rectified Linear Unit 6 activation function. - Computes the element-wise function + Computes the element-wise function - .. math:: - \mathrm{relu6}(x) = \min(\max(x, 0), 6) + .. math:: + \mathrm{relu6}(x) = \min(\max(x, 0), 6) - Parameters:: + Parameters:: - x: ArrayType - The input array. - """ - x = x.value if isinstance(x, BaseArray) else x - return jnp.minimum(jnp.maximum(x, 0), 6.) + x: ArrayType + The input array. + """ + x = x.value if isinstance(x, BaseArray) else x + return jnp.minimum(jnp.maximum(x, 0), 6.) 
def rrelu(x, lower=0.125, upper=0.3333333333333333, ): - r"""Applies the randomized leaky rectified liner unit function, element-wise, - as described in the paper: + r"""Applies the randomized leaky rectified liner unit function, element-wise, + as described in the paper: - `Empirical Evaluation of Rectified Activations in Convolutional Network`_. + `Empirical Evaluation of Rectified Activations in Convolutional Network`_. - The function is defined as: + The function is defined as: - .. math:: - \text{RReLU}(x) = - \begin{cases} - x & \text{if } x \geq 0 \\ - ax & \text{ otherwise } - \end{cases} + .. math:: + \text{RReLU}(x) = + \begin{cases} + x & \text{if } x \geq 0 \\ + ax & \text{ otherwise } + \end{cases} - where :math:`a` is randomly sampled from uniform distribution - :math:`\mathcal{U}(\text{lower}, \text{upper})`. + where :math:`a` is randomly sampled from uniform distribution + :math:`\mathcal{U}(\text{lower}, \text{upper})`. - See: https://arxiv.org/pdf/1505.00853.pdf + See: https://arxiv.org/pdf/1505.00853.pdf - Args: - lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` - upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + Args: + lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` + upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. - .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`: - https://arxiv.org/abs/1505.00853 - """ - x = x.value if isinstance(x, BaseArray) else x - a = uniform(lower, upper, size=x.shape) - return jnp.where(x >= 0., x, a * x) + .. 
_`Empirical Evaluation of Rectified Activations in Convolutional Network`: + https://arxiv.org/abs/1505.00853 + """ + x = x.value if isinstance(x, BaseArray) else x + a = uniform(lower, upper, size=x.shape) + return jnp.where(x >= 0., x, a * x) def prelu(x, a=0.25): - r"""Applies the element-wise function: + r"""Applies the element-wise function: - .. math:: - \text{PReLU}(x) = \max(0,x) + a * \min(0,x) + .. math:: + \text{PReLU}(x) = \max(0,x) + a * \min(0,x) - or + or - .. math:: - \text{PReLU}(x) = - \begin{cases} - x, & \text{ if } x \geq 0 \\ - ax, & \text{ otherwise } - \end{cases} + .. math:: + \text{PReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + ax, & \text{ otherwise } + \end{cases} - Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single - parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, - a separate :math:`a` is used for each input channel. - """ - x = x.value if isinstance(x, BaseArray) else x - return jnp.where(x >= 0., x, a * x) + Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single + parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`, + a separate :math:`a` is used for each input channel. + """ + x = x.value if isinstance(x, BaseArray) else x + return jnp.where(x >= 0., x, a * x) def sigmoid(x): - r"""Sigmoid activation function. + r"""Sigmoid activation function. - Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}} + .. math:: + \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}} - Parameters:: + Parameters:: - x: ArrayType - The input array. - """ - x = x.value if isinstance(x, BaseArray) else x - return jax.scipy.special.expit(x) + x: ArrayType + The input array. + """ + x = x.value if isinstance(x, BaseArray) else x + return jax.scipy.special.expit(x) def soft_sign(x): - r"""Soft-sign activation function. 
+ r"""Soft-sign activation function. - Computes the element-wise function + Computes the element-wise function - .. math:: - \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1} + .. math:: + \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1} - Parameters:: + Parameters:: - x: ArrayType - The input array. - """ - x = x.value if isinstance(x, BaseArray) else x - return x / (jnp.abs(x) + 1) + x: ArrayType + The input array. + """ + x = x.value if isinstance(x, BaseArray) else x + return x / (jnp.abs(x) + 1) def softmax(x, axis=-1): - r"""Softmax function. + r"""Softmax function. - Computes the function which rescales elements to the range :math:`[0, 1]` - such that the elements along :code:`axis` sum to :math:`1`. + Computes the function which rescales elements to the range :math:`[0, 1]` + such that the elements along :code:`axis` sum to :math:`1`. - .. math :: - \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + .. math :: + \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} - Parameters:: + Parameters:: - x: ArrayType - The input array. - axis: int - The axis or axes along which the softmax should be computed. The - softmax output summed across these dimensions should sum to :math:`1`. - Either an integer or a tuple of integers. - """ - x = x.value if isinstance(x, BaseArray) else x - unnormalized = jnp.exp(x - jax.lax.stop_gradient(x.max(axis, keepdims=True))) - return unnormalized / unnormalized.sum(axis, keepdims=True) + x: ArrayType + The input array. + axis: int + The axis or axes along which the softmax should be computed. The + softmax output summed across these dimensions should sum to :math:`1`. + Either an integer or a tuple of integers. 
+ """ + x = x.value if isinstance(x, BaseArray) else x + unnormalized = jnp.exp(x - jax.lax.stop_gradient(x.max(axis, keepdims=True))) + return unnormalized / unnormalized.sum(axis, keepdims=True) def softmin(x, axis=-1): - r"""Applies the Softmin function to an n-dimensional input Tensor - rescaling them so that the elements of the n-dimensional output Tensor - lie in the range `[0, 1]` and sum to 1. + r"""Applies the Softmin function to an n-dimensional input Tensor + rescaling them so that the elements of the n-dimensional output Tensor + lie in the range `[0, 1]` and sum to 1. - Softmin is defined as: + Softmin is defined as: - .. math:: - \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} + .. math:: + \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} - Shape: - - Input: :math:`(*)` where `*` means, any number of additional - dimensions - - Output: :math:`(*)`, same shape as the input + Shape: + - Input: :math:`(*)` where `*` means, any number of additional + dimensions + - Output: :math:`(*)`, same shape as the input - Args: - axis (int): A dimension along which Softmin will be computed (so every slice - along dim will sum to 1). - """ - x = x.value if isinstance(x, BaseArray) else x - unnormalized = jnp.exp(-x) - return unnormalized / unnormalized.sum(axis, keepdims=True) + Args: + axis (int): A dimension along which Softmin will be computed (so every slice + along dim will sum to 1). + """ + x = x.value if isinstance(x, BaseArray) else x + unnormalized = jnp.exp(-x) + return unnormalized / unnormalized.sum(axis, keepdims=True) soft_max = softmax def silu(x): - r"""SiLU activation function. + r"""SiLU activation function. - Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}} + .. math:: + \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}} - Parameters:: + Parameters:: - x: ArrayType - The input array. 
- """ - x = x.value if isinstance(x, BaseArray) else x - return x * sigmoid(x) + x: ArrayType + The input array. + """ + x = x.value if isinstance(x, BaseArray) else x + return x * sigmoid(x) swish = silu def mish(x): - r"""Applies the Mish function, element-wise. + r"""Applies the Mish function, element-wise. - Mish: A Self Regularized Non-Monotonic Neural Activation Function. + Mish: A Self Regularized Non-Monotonic Neural Activation Function. - .. math:: - \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) + .. math:: + \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x)) - .. note:: - See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + .. note:: + See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ - Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Output: :math:`(*)`, same shape as the input. - """ - x = x.value if isinstance(x, BaseArray) else x - return x * jnp.tanh(softplus(x)) + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + """ + x = x.value if isinstance(x, BaseArray) else x + return x * jnp.tanh(softplus(x)) def selu(x): - r"""Scaled exponential linear unit activation. + r"""Scaled exponential linear unit activation. - Computes the element-wise function: + Computes the element-wise function: - .. math:: - \mathrm{selu}(x) = \lambda \begin{cases} - x, & x > 0\\ - \alpha e^x - \alpha, & x \le 0 - \end{cases} + .. math:: + \mathrm{selu}(x) = \lambda \begin{cases} + x, & x > 0\\ + \alpha e^x - \alpha, & x \le 0 + \end{cases} - where :math:`\lambda = 1.0507009873554804934193349852946` and - :math:`\alpha = 1.6732632423543772848170429916717`. + where :math:`\lambda = 1.0507009873554804934193349852946` and + :math:`\alpha = 1.6732632423543772848170429916717`. - For more information, see - `Self-Normalizing Neural Networks - `_. + For more information, see + `Self-Normalizing Neural Networks + `_. 
- Parameters:: + Parameters:: - x: ArrayType - The input array. - """ - alpha = 1.6732632423543772848170429916717 - scale = 1.0507009873554804934193349852946 - x = x.value if isinstance(x, BaseArray) else x - safe_x = jnp.where(x > 0, 0., x) - return scale * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x)) + x: ArrayType + The input array. + """ + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + x = x.value if isinstance(x, BaseArray) else x + safe_x = jnp.where(x > 0, 0., x) + return scale * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x)) diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py index 7320494da..af3230481 100644 --- a/brainpy/_src/math/compat_numpy.py +++ b/brainpy/_src/math/compat_numpy.py @@ -11,89 +11,89 @@ from .ndarray import Array, BaseArray __all__ = [ - 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', - 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', - - # math funcs - 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', - 'add', 'reciprocal', 'negative', 'positive', 'multiply', 'divide', - 'power', 'subtract', 'true_divide', 'floor_divide', 'float_power', - 'fmod', 'mod', 'modf', 'divmod', 'remainder', 'abs', 'exp', 'exp2', - 'expm1', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', - 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', - 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', - 'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round', - 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod', - 'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum', - 'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf', - 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve', - 'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside', - 
'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle', - - # Elementwise bit operations - 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', - 'invert', 'left_shift', 'right_shift', - - # logic funcs - 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', - 'array_equal', 'isclose', 'allclose', 'logical_and', 'logical_not', - 'logical_or', 'logical_xor', 'all', 'any', "alltrue", 'sometrue', - - # array manipulation - 'shape', 'size', 'reshape', 'ravel', 'moveaxis', 'transpose', 'swapaxes', - 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', - 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', - 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', - 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', - 'argwhere', 'nonzero', 'flatnonzero', 'where', 'searchsorted', 'extract', - 'count_nonzero', 'max', 'min', 'amax', 'amin', - - # array creation - 'array_split', 'meshgrid', 'vander', - - # indexing funcs - 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', - 'triu_indices_from', 'take', 'select', - - # statistic funcs - 'nanmin', 'nanmax', 'ptp', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', - 'median', 'average', 'mean', 'std', 'var', 'nanmedian', 'nanmean', 'nanstd', 'nanvar', - 'corrcoef', 'correlate', 'cov', 'histogram', 'bincount', 'digitize', - - # window funcs - 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', - - # constants - 'e', 'pi', 'inf', - - # linear algebra - 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', - - # data types - 'dtype', 'finfo', 'iinfo', - - # more - 'product', 'row_stack', 'apply_over_axes', 'apply_along_axis', 'array_equiv', - 'array_repr', 'array_str', 'block', 'broadcast_arrays', 'broadcast_shapes', - 'broadcast_to', 'compress', 'cumproduct', 'diag_indices', 'diag_indices_from', - 'diagflat', 'diagonal', 'einsum', 'einsum_path', 'geomspace', 'gradient', - 
'histogram2d', 'histogram_bin_edges', 'histogramdd', 'i0', 'in1d', 'indices', - 'insert', 'intersect1d', 'iscomplex', 'isin', 'ix_', 'lexsort', 'load', - 'save', 'savez', 'mask_indices', 'msort', 'nan_to_num', 'nanargmax', 'setdiff1d', - 'nanargmin', 'pad', 'poly', 'polyadd', 'polyder', 'polyfit', 'polyint', - 'polymul', 'polysub', 'polyval', 'resize', 'rollaxis', 'roots', 'rot90', - 'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', - 'take_along_axis', 'can_cast', 'choose', 'copy', 'frombuffer', 'fromfile', - 'fromfunction', 'fromiter', 'fromstring', 'get_printoptions', 'iscomplexobj', - 'isneginf', 'isposinf', 'isrealobj', 'issubdtype', 'issubsctype', 'iterable', - 'packbits', 'piecewise', 'printoptions', 'set_printoptions', 'promote_types', - 'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete', - - # unique - 'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt', - 'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval', - 'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat', + 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', + + # math funcs + 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', + 'add', 'reciprocal', 'negative', 'positive', 'multiply', 'divide', + 'power', 'subtract', 'true_divide', 'floor_divide', 'float_power', + 'fmod', 'mod', 'modf', 'divmod', 'remainder', 'abs', 'exp', 'exp2', + 'expm1', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', + 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', + 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', + 'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round', + 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod', + 
'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum', + 'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf', + 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve', + 'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside', + 'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle', + + # Elementwise bit operations + 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', + 'invert', 'left_shift', 'right_shift', + + # logic funcs + 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', + 'array_equal', 'isclose', 'allclose', 'logical_and', 'logical_not', + 'logical_or', 'logical_xor', 'all', 'any', "alltrue", 'sometrue', + + # array manipulation + 'shape', 'size', 'reshape', 'ravel', 'moveaxis', 'transpose', 'swapaxes', + 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', + 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', + 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', + 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', + 'argwhere', 'nonzero', 'flatnonzero', 'where', 'searchsorted', 'extract', + 'count_nonzero', 'max', 'min', 'amax', 'amin', + + # array creation + 'array_split', 'meshgrid', 'vander', + + # indexing funcs + 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', + 'triu_indices_from', 'take', 'select', + + # statistic funcs + 'nanmin', 'nanmax', 'ptp', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + 'median', 'average', 'mean', 'std', 'var', 'nanmedian', 'nanmean', 'nanstd', 'nanvar', + 'corrcoef', 'correlate', 'cov', 'histogram', 'bincount', 'digitize', + + # window funcs + 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', + + # constants + 'e', 'pi', 'inf', + + # linear algebra + 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', + + # data types + 'dtype', 'finfo', 'iinfo', + + # more + 'product', 'row_stack', 
'apply_over_axes', 'apply_along_axis', 'array_equiv', + 'array_repr', 'array_str', 'block', 'broadcast_arrays', 'broadcast_shapes', + 'broadcast_to', 'compress', 'cumproduct', 'diag_indices', 'diag_indices_from', + 'diagflat', 'diagonal', 'einsum', 'einsum_path', 'geomspace', 'gradient', + 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'i0', 'in1d', 'indices', + 'insert', 'intersect1d', 'iscomplex', 'isin', 'ix_', 'lexsort', 'load', + 'save', 'savez', 'mask_indices', 'msort', 'nan_to_num', 'nanargmax', 'setdiff1d', + 'nanargmin', 'pad', 'poly', 'polyadd', 'polyder', 'polyfit', 'polyint', + 'polymul', 'polysub', 'polyval', 'resize', 'rollaxis', 'roots', 'rot90', + 'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', + 'take_along_axis', 'can_cast', 'choose', 'copy', 'frombuffer', 'fromfile', + 'fromfunction', 'fromiter', 'fromstring', 'get_printoptions', 'iscomplexobj', + 'isneginf', 'isposinf', 'isrealobj', 'issubdtype', 'issubsctype', 'iterable', + 'packbits', 'piecewise', 'printoptions', 'set_printoptions', 'promote_types', + 'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete', + + # unique + 'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt', + 'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval', + 'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat', ] @@ -102,129 +102,129 @@ def _return(a): - return Array(a) + return Array(a) def fill_diagonal(a, val, inplace=True): - if a.ndim < 2: - raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') - if not isinstance(a, BaseArray) and inplace: - raise ValueError('``fill_diagonal()`` is used in in-place updating, therefore ' - 'it requires a brainpy Array. 
If you want to disable ' - 'inplace updating, use ``fill_diagonal(inplace=False)``.') - val = val.value if isinstance(val, BaseArray) else val - i, j = jnp.diag_indices(_min(a.shape[-2:])) - r = as_jax(a).at[..., i, j].set(val) - if inplace: - a.value = r - else: - return r + if a.ndim < 2: + raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') + if not isinstance(a, BaseArray) and inplace: + raise ValueError('``fill_diagonal()`` is used in in-place updating, therefore ' + 'it requires a brainpy Array. If you want to disable ' + 'inplace updating, use ``fill_diagonal(inplace=False)``.') + val = val.value if isinstance(val, BaseArray) else val + i, j = jnp.diag_indices(_min(a.shape[-2:])) + r = as_jax(a).at[..., i, j].set(val) + if inplace: + a.value = r + else: + return r def zeros(shape, dtype=None): - return _return(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def ones(shape, dtype=None): - return _return(jnp.ones(shape, dtype=dtype)) + return _return(jnp.ones(shape, dtype=dtype)) def empty(shape, dtype=None): - return _return(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def zeros_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) + a = _as_jax_array_(a) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def ones_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return _return(jnp.ones_like(a, dtype=dtype, shape=shape)) + a = _as_jax_array_(a) + return _return(jnp.ones_like(a, dtype=dtype, shape=shape)) def empty_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) + a = _as_jax_array_(a) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: - a = _as_jax_array_(a) - try: - res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - 
except TypeError: - leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, BaseArray)) - leaves = [_as_jax_array_(l) for l in leaves] - a = tree_unflatten(tree, leaves) - res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - return _return(res) + a = _as_jax_array_(a) + try: + res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) + except TypeError: + leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, BaseArray)) + leaves = [_as_jax_array_(l) for l in leaves] + a = tree_unflatten(tree, leaves) + res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) + return _return(res) def asarray(a, dtype=None, order=None): - a = _as_jax_array_(a) - try: - res = jnp.asarray(a=a, dtype=dtype, order=order) - except TypeError: - leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, BaseArray)) - leaves = [_as_jax_array_(l) for l in leaves] - arrays = tree_unflatten(tree, leaves) - res = jnp.asarray(a=arrays, dtype=dtype, order=order) - return _return(res) + a = _as_jax_array_(a) + try: + res = jnp.asarray(a=a, dtype=dtype, order=order) + except TypeError: + leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, BaseArray)) + leaves = [_as_jax_array_(l) for l in leaves] + arrays = tree_unflatten(tree, leaves) + res = jnp.asarray(a=arrays, dtype=dtype, order=order) + return _return(res) def arange(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return _return(jnp.arange(*args, **kwargs)) + args = [_as_jax_array_(a) for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} + return _return(jnp.arange(*args, **kwargs)) def linspace(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - res = jnp.linspace(*args, **kwargs) - if isinstance(res, tuple): - return _return(res[0]), res[1] - else: - return _return(res) + args = [_as_jax_array_(a) for a 
in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} + res = jnp.linspace(*args, **kwargs) + if isinstance(res, tuple): + return _return(res[0]), res[1] + else: + return _return(res) def logspace(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return _return(jnp.logspace(*args, **kwargs)) + args = [_as_jax_array_(a) for a in args] + kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} + return _return(jnp.logspace(*args, **kwargs)) def asanyarray(a, dtype=None, order=None): - return asarray(a, dtype=dtype, order=order) + return asarray(a, dtype=dtype, order=order) def ascontiguousarray(a, dtype=None, order=None): - return asarray(a, dtype=dtype, order=order) + return asarray(a, dtype=dtype, order=order) def asfarray(a, dtype=None): - if not np.issubdtype(dtype, np.inexact): - dtype = np.float64 - return asarray(a, dtype=dtype) + if not np.issubdtype(dtype, np.inexact): + dtype = np.float64 + return asarray(a, dtype=dtype) def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array: - del assume_unique - ar1_flat = ravel(ar1) - ar2_flat = ravel(ar2) - # Note: an algorithm based on searchsorted has better scaling, but in practice - # is very slow on accelerators because it relies on lax control flow. If XLA - # ever supports binary search natively, we should switch to this: - # ar2_flat = jnp.sort(ar2_flat) - # ind = jnp.searchsorted(ar2_flat, ar1_flat) - # if invert: - # return ar1_flat != ar2_flat[ind] - # else: - # return ar1_flat == ar2_flat[ind] - if invert: - return asarray((ar1_flat[:, None] != ar2_flat[None, :]).all(-1)) - else: - return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1)) + del assume_unique + ar1_flat = ravel(ar1) + ar2_flat = ravel(ar2) + # Note: an algorithm based on searchsorted has better scaling, but in practice + # is very slow on accelerators because it relies on lax control flow. 
If XLA + # ever supports binary search natively, we should switch to this: + # ar2_flat = jnp.sort(ar2_flat) + # ind = jnp.searchsorted(ar2_flat, ar1_flat) + # if invert: + # return ar1_flat != ar2_flat[ind] + # else: + # return ar1_flat == ar2_flat[ind] + if invert: + return asarray((ar1_flat[:, None] != ar2_flat[None, :]).all(-1)) + else: + return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1)) # Others @@ -272,29 +272,29 @@ def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array: def msort(a): - """ - Return a copy of an array sorted along the first axis. + """ + Return a copy of an array sorted along the first axis. - Parameters:: - - a : array_like - Array to be sorted. + Parameters:: - Returns:: - - sorted_array : ndarray - Array of the same type and shape as `a`. + a : array_like + Array to be sorted. - See Also:: - - sort + Returns:: - Notes:: - - ``brainpy.math.msort(a)`` is equivalent to ``brainpy.math.sort(a, axis=0)``. + sorted_array : ndarray + Array of the same type and shape as `a`. - """ - return sort(a, axis=0) + See Also:: + + sort + + Notes:: + + ``brainpy.math.msort(a)`` is equivalent to ``brainpy.math.sort(a, axis=0)``. + + """ + return sort(a, axis=0) nan_to_num = _compatible_with_brainpy_array(jnp.nan_to_num) @@ -402,9 +402,9 @@ def msort(a): ediff1d = _compatible_with_brainpy_array(jnp.ediff1d) cross = _compatible_with_brainpy_array(jnp.cross) if jax.__version__ >= '0.4.18': - trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid) + trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid) else: - trapz = _compatible_with_brainpy_array(jnp.trapz) + trapz = _compatible_with_brainpy_array(jnp.trapz) isfinite = _compatible_with_brainpy_array(jnp.isfinite) isinf = _compatible_with_brainpy_array(jnp.isinf) isnan = _compatible_with_brainpy_array(jnp.isnan) @@ -455,84 +455,84 @@ def msort(a): def shape(a): - """ - Return the shape of an array. 
- - Parameters:: - - a : array_like - Input array. - - Returns:: - - shape : tuple of ints - The elements of the shape tuple give the lengths of the - corresponding array dimensions. - - See Also:: - - len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with - ``N>=1``. - ndarray.shape : Equivalent array method. - - Examples:: - - >>> brainpy.math.shape(brainpy.math.eye(3)) - (3, 3) - >>> brainpy.math.shape([[1, 3]]) - (1, 2) - >>> brainpy.math.shape([0]) - (1,) - >>> brainpy.math.shape(0) - () - - """ - if isinstance(a, (BaseArray, jax.Array, np.ndarray)): - return a.shape - else: - return np.shape(a) + """ + Return the shape of an array. + + Parameters:: + + a : array_like + Input array. + + Returns:: + + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + See Also:: + + len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with + ``N>=1``. + ndarray.shape : Equivalent array method. + + Examples:: + + >>> brainpy.math.shape(brainpy.math.eye(3)) + (3, 3) + >>> brainpy.math.shape([[1, 3]]) + (1, 2) + >>> brainpy.math.shape([0]) + (1,) + >>> brainpy.math.shape(0) + () + + """ + if isinstance(a, (BaseArray, jax.Array, np.ndarray)): + return a.shape + else: + return np.shape(a) def size(a, axis=None): - """ - Return the number of elements along a given axis. - - Parameters:: - - a : array_like - Input data. - axis : int, optional - Axis along which the elements are counted. By default, give - the total number of elements. - - Returns:: - - element_count : int - Number of elements along the specified axis. 
- - See Also:: - - shape : dimensions of array - Array.shape : dimensions of array - Array.size : number of elements in array - - Examples:: - - >>> a = brainpy.math.array([[1,2,3], [4,5,6]]) - >>> brainpy.math.size(a) - 6 - >>> brainpy.math.size(a, 1) - 3 - >>> brainpy.math.size(a, 0) - 2 - """ - if isinstance(a, (BaseArray, jax.Array, np.ndarray)): - if axis is None: - return a.size + """ + Return the number of elements along a given axis. + + Parameters:: + + a : array_like + Input data. + axis : int, optional + Axis along which the elements are counted. By default, give + the total number of elements. + + Returns:: + + element_count : int + Number of elements along the specified axis. + + See Also:: + + shape : dimensions of array + Array.shape : dimensions of array + Array.size : number of elements in array + + Examples:: + + >>> a = brainpy.math.array([[1,2,3], [4,5,6]]) + >>> brainpy.math.size(a) + 6 + >>> brainpy.math.size(a, 1) + 3 + >>> brainpy.math.size(a, 0) + 2 + """ + if isinstance(a, (BaseArray, jax.Array, np.ndarray)): + if axis is None: + return a.size + else: + return a.shape[axis] else: - return a.shape[axis] - else: - return np.size(a, axis=axis) + return np.size(a, axis=axis) reshape = _compatible_with_brainpy_array(jnp.reshape) @@ -673,12 +673,13 @@ def size(a, axis=None): sort_complex = _compatible_with_brainpy_array(jnp.sort_complex) unpackbits = _compatible_with_brainpy_array(jnp.unpackbits) + # Unique APIs # ----------- def asscalar(a): - return a.item() + return a.item() array_type = [[np.half, np.single, np.double, np.longdouble], @@ -693,23 +694,23 @@ def asscalar(a): def common_type(*arrays): - is_complex = False - precision = 0 - for a in arrays: - t = a.dtype.type - if iscomplexobj(a): - is_complex = True - if issubclass(t, jnp.integer): - p = 2 # array_precision[_nx.double] + is_complex = False + precision = 0 + for a in arrays: + t = a.dtype.type + if iscomplexobj(a): + is_complex = True + if issubclass(t, jnp.integer): + p = 2 # 
array_precision[_nx.double] + else: + p = array_precision.get(t, None) + if p is None: + raise TypeError("can't get common type for non-numeric array") + precision = _max(precision, p) + if is_complex: + return array_type[1][precision] else: - p = array_precision.get(t, None) - if p is None: - raise TypeError("can't get common type for non-numeric array") - precision = _max(precision, p) - if is_complex: - return array_type[1][precision] - else: - return array_type[0][precision] + return array_type[0][precision] genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs)) @@ -718,44 +719,44 @@ def common_type(*arrays): def place(arr, mask, vals): - if not isinstance(arr, BaseArray): - raise ValueError(f'Must be an instance of brainpy Array, but we got {type(arr)}') - arr[mask] = vals + if not isinstance(arr, BaseArray): + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(arr)}') + arr[mask] = vals polydiv = _compatible_with_brainpy_array(jnp.polydiv) def put(a, ind, v): - if not isinstance(a, BaseArray): - raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') - a[ind] = v + if not isinstance(a, BaseArray): + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') + a[ind] = v def putmask(a, mask, values): - if not isinstance(a, BaseArray): - raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') - if a.shape != values.shape: - raise ValueError('Only support the shapes of "a" and "values" are consistent.') - a[mask] = values + if not isinstance(a, BaseArray): + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') + if a.shape != values.shape: + raise ValueError('Only support the shapes of "a" and "values" are consistent.') + a[mask] = values def safe_eval(source): - return tree_map(Array, np.safe_eval(source)) + return tree_map(Array, np.safe_eval(source)) def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', 
header='', footer='', comments='# ', encoding=None): - X = as_numpy(X) - np.savetxt(fname, X, fmt=fmt, delimiter=delimiter, newline=newline, header=header, - footer=footer, comments=comments, encoding=encoding) + X = as_numpy(X) + np.savetxt(fname, X, fmt=fmt, delimiter=delimiter, newline=newline, header=header, + footer=footer, comments=comments, encoding=encoding) def savez_compressed(file, *args, **kwds): - args = tuple([(as_numpy(a) if isinstance(a, (jnp.ndarray, BaseArray)) else a) for a in args]) - kwds = {k: (as_numpy(v) if isinstance(v, (jnp.ndarray, BaseArray)) else v) - for k, v in kwds.items()} - np.savez_compressed(file, *args, **kwds) + args = tuple([(as_numpy(a) if isinstance(a, (jnp.ndarray, BaseArray)) else a) for a in args]) + kwds = {k: (as_numpy(v) if isinstance(v, (jnp.ndarray, BaseArray)) else v) + for k, v in kwds.items()} + np.savez_compressed(file, *args, **kwds) show_config = np.show_config @@ -763,30 +764,30 @@ def savez_compressed(file, *args, **kwds): def copyto(dst, src): - if not isinstance(dst, BaseArray): - raise ValueError('dst must be an instance of ArrayType.') - dst[:] = src + if not isinstance(dst, BaseArray): + raise ValueError('dst must be an instance of ArrayType.') + dst[:] = src def matrix(data, dtype=None): - data = array(data, copy=True, dtype=dtype) - if data.ndim > 2: - raise ValueError(f'shape too large {data.shape} to be a matrix.') - if data.ndim != 2: - for i in range(2 - data.ndim): - data = expand_dims(data, 0) - return data + data = array(data, copy=True, dtype=dtype) + if data.ndim > 2: + raise ValueError(f'shape too large {data.shape} to be a matrix.') + if data.ndim != 2: + for i in range(2 - data.ndim): + data = expand_dims(data, 0) + return data def asmatrix(data, dtype=None): - data = array(data, dtype=dtype) - if data.ndim > 2: - raise ValueError(f'shape too large {data.shape} to be a matrix.') - if data.ndim != 2: - for i in range(2 - data.ndim): - data = expand_dims(data, 0) - return data + data = 
array(data, dtype=dtype) + if data.ndim > 2: + raise ValueError(f'shape too large {data.shape} to be a matrix.') + if data.ndim != 2: + for i in range(2 - data.ndim): + data = expand_dims(data, 0) + return data def mat(data, dtype=None): - return asmatrix(data, dtype=dtype) + return asmatrix(data, dtype=dtype) diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py index 8b8684e02..d58b02a7e 100644 --- a/brainpy/_src/math/compat_pytorch.py +++ b/brainpy/_src/math/compat_pytorch.py @@ -8,30 +8,30 @@ from .ndarray import Array, _as_jax_array_, _return, _check_out __all__ = [ - 'Tensor', - 'flatten', - 'unflatten', - 'cat', - 'abs', - 'absolute', - 'acos', - 'arccos', - 'acosh', - 'arccosh', - 'add', - 'addcdiv', - 'addcmul', - 'angle', - 'asin', - 'arcsin', - 'asinh', - 'arcsin', - 'atan', - 'arctan', - 'atan2', - 'atanh', - 'clamp_max', - 'clamp_min', + 'Tensor', + 'flatten', + 'unflatten', + 'cat', + 'abs', + 'absolute', + 'acos', + 'arccos', + 'acosh', + 'arccosh', + 'add', + 'addcdiv', + 'addcmul', + 'angle', + 'asin', + 'arcsin', + 'asinh', + 'arcsin', + 'atan', + 'arctan', + 'atan2', + 'atanh', + 'clamp_max', + 'clamp_min', ] Tensor = Array @@ -41,95 +41,95 @@ def flatten(input: Union[jax.Array, Array], start_dim: Optional[int] = None, end_dim: Optional[int] = None) -> jax.Array: - """Flattens input by reshaping it into a one-dimensional tensor. - If ``start_dim`` or ``end_dim`` are passed, only dimensions starting - with ``start_dim`` and ending with ``end_dim`` are flattened. - The order of elements in input is unchanged. - - .. note:: - Flattening a zero-dimensional tensor will return a one-dimensional view. - - Parameters:: - - input: Array - The input array. 
- start_dim: int - the first dim to flatten - end_dim: int - the last dim to flatten - - Returns:: - - out: Array - """ - input = _as_jax_array_(input) - shape = input.shape - ndim = input.ndim - if ndim == 0: - ndim = 1 - if start_dim is None: - start_dim = 0 - elif start_dim < 0: - start_dim = ndim + start_dim - if end_dim is None: - end_dim = ndim - 1 - elif end_dim < 0: - end_dim = ndim + end_dim - end_dim += 1 - if start_dim < 0 or start_dim > ndim: - raise ValueError(f'start_dim {start_dim} is out of size.') - if end_dim < 0 or end_dim > ndim: - raise ValueError(f'end_dim {end_dim} is out of size.') - new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int),) + shape[end_dim:] - return jnp.reshape(input, new_shape) + """Flattens input by reshaping it into a one-dimensional tensor. + If ``start_dim`` or ``end_dim`` are passed, only dimensions starting + with ``start_dim`` and ending with ``end_dim`` are flattened. + The order of elements in input is unchanged. + + .. note:: + Flattening a zero-dimensional tensor will return a one-dimensional view. + + Parameters:: + + input: Array + The input array. 
+ start_dim: int + the first dim to flatten + end_dim: int + the last dim to flatten + + Returns:: + + out: Array + """ + input = _as_jax_array_(input) + shape = input.shape + ndim = input.ndim + if ndim == 0: + ndim = 1 + if start_dim is None: + start_dim = 0 + elif start_dim < 0: + start_dim = ndim + start_dim + if end_dim is None: + end_dim = ndim - 1 + elif end_dim < 0: + end_dim = ndim + end_dim + end_dim += 1 + if start_dim < 0 or start_dim > ndim: + raise ValueError(f'start_dim {start_dim} is out of size.') + if end_dim < 0 or end_dim > ndim: + raise ValueError(f'end_dim {end_dim} is out of size.') + new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int),) + shape[end_dim:] + return jnp.reshape(input, new_shape) def unflatten(x: Union[jax.Array, Array], dim: int, sizes: Sequence[int]) -> Array: - """ - Expands a dimension of the input tensor over multiple dimensions. - - Args: - x: input tensor. - dim: Dimension to be unflattened, specified as an index into ``x.shape``. - sizes: New shape of the unflattened dimension. One of its elements can be -1 - in which case the corresponding output dimension is inferred. - Otherwise, the product of ``sizes`` must equal ``input.shape[dim]``. - - Returns: - A tensor with the same data as ``input``, but with ``dim`` split into multiple dimensions. - The returned tensor has one more dimension than the input tensor. - The returned tensor shares the same underlying data with this tensor. - """ - assert x.ndim > dim, ('The dimension to be unflattened should be less than the tensor dimension. ' - f'Got {dim} and {x.ndim}.') - x = _as_jax_array_(x) - shape = x.shape - new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:] - r = jnp.reshape(x, new_shape) - return _return(r) + """ + Expands a dimension of the input tensor over multiple dimensions. + + Args: + x: input tensor. + dim: Dimension to be unflattened, specified as an index into ``x.shape``. + sizes: New shape of the unflattened dimension. 
One of its elements can be -1 + in which case the corresponding output dimension is inferred. + Otherwise, the product of ``sizes`` must equal ``input.shape[dim]``. + + Returns: + A tensor with the same data as ``input``, but with ``dim`` split into multiple dimensions. + The returned tensor has one more dimension than the input tensor. + The returned tensor shares the same underlying data with this tensor. + """ + assert x.ndim > dim, ('The dimension to be unflattened should be less than the tensor dimension. ' + f'Got {dim} and {x.ndim}.') + x = _as_jax_array_(x) + shape = x.shape + new_shape = shape[:dim] + tuple(sizes) + shape[dim + 1:] + r = jnp.reshape(x, new_shape) + return _return(r) def unsqueeze(x: Union[jax.Array, Array], dim: int) -> Array: - """Returns a new tensor with a dimension of size one inserted at the specified position. + """Returns a new tensor with a dimension of size one inserted at the specified position. - The returned tensor shares the same underlying data with this tensor. - A dim value within the range ``[-input.dim() - 1, input.dim() + 1)`` can be used. - Negative dim will correspond to unsqueeze() applied at ``dim = dim + input.dim() + 1``. + The returned tensor shares the same underlying data with this tensor. + A dim value within the range ``[-input.dim() - 1, input.dim() + 1)`` can be used. + Negative dim will correspond to unsqueeze() applied at ``dim = dim + input.dim() + 1``. 
- Parameters:: + Parameters:: - x: Array - The input Array - dim: int - The index at which to insert the singleton dimension + x: Array + The input Array + dim: int + The index at which to insert the singleton dimension - Returns:: + Returns:: - out: Array - """ - x = _as_jax_array_(x) - r = jnp.expand_dims(x, dim) - return _return(r) + out: Array + """ + x = _as_jax_array_(x) + r = jnp.expand_dims(x, dim) + return _return(r) # Math operations @@ -138,13 +138,13 @@ def abs( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.abs(x) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.abs(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r absolute = abs @@ -155,13 +155,13 @@ def acos( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.arccos(x) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.arccos(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r arccos = acos @@ -172,13 +172,13 @@ def acosh( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.arccosh(x) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.arccosh(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r arccosh = acosh @@ -191,23 +191,23 @@ def add( alpha: Optional[jnp.number] = 1, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - r""" - Adds ``other``, scaled by ``alpha``, to ``input``. + r""" + Adds ``other``, scaled by ``alpha``, to ``input``. - .. math:: + .. 
math:: - \text { out }_i=\text { input }_i+\text { alpha } \times \text { other }_i + \text { out }_i=\text { input }_i+\text { alpha } \times \text { other }_i - """ - x = _as_jax_array_(x) - y = _as_jax_array_(y) - y = jnp.multiply(alpha, y) - r = jnp.add(x, y) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + """ + x = _as_jax_array_(x) + y = _as_jax_array_(y) + y = jnp.multiply(alpha, y) + r = jnp.add(x, y) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r def addcdiv( @@ -218,10 +218,10 @@ def addcdiv( value: jnp.number = 1, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - tensor1 = _as_jax_array_(tensor1) - tensor2 = _as_jax_array_(tensor2) - other = jnp.divide(tensor1, tensor2) - return add(x, other, alpha=value, out=out) + tensor1 = _as_jax_array_(tensor1) + tensor2 = _as_jax_array_(tensor2) + other = jnp.divide(tensor1, tensor2) + return add(x, other, alpha=value, out=out) def addcmul( @@ -232,10 +232,10 @@ def addcmul( value: jnp.number = 1, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - tensor1 = _as_jax_array_(tensor1) - tensor2 = _as_jax_array_(tensor2) - other = jnp.multiply(tensor1, tensor2) - return add(x, other, alpha=value, out=out) + tensor1 = _as_jax_array_(tensor1) + tensor2 = _as_jax_array_(tensor2) + other = jnp.multiply(tensor1, tensor2) + return add(x, other, alpha=value, out=out) def angle( @@ -243,13 +243,13 @@ def angle( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.angle(x) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.angle(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r def asin( @@ -257,13 +257,13 @@ def asin( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.arcsin(x) 
- if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.arcsin(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r arcsin = asin @@ -274,13 +274,13 @@ def asinh( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.arcsinh(x) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.arcsinh(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r arcsinh = asinh @@ -291,13 +291,13 @@ def atan( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.arctan(x) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.arctan(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r arctan = atan @@ -308,13 +308,13 @@ def atanh( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x = _as_jax_array_(x) - r = jnp.arctanh(x) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x = _as_jax_array_(x) + r = jnp.arctanh(x) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r arctanh = atanh @@ -326,14 +326,14 @@ def atan2( *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None ) -> Optional[Array]: - x1 = _as_jax_array_(x1) - x2 = _as_jax_array_(x2) - r = jnp.arctan2(x1, x2) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + x1 = _as_jax_array_(x1) + x2 = _as_jax_array_(x2) + r = jnp.arctan2(x1, x2) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r arctan2 = atan2 diff --git a/brainpy/_src/math/compat_tensorflow.py b/brainpy/_src/math/compat_tensorflow.py index bd4a1392f..7ba88a37a 100644 --- a/brainpy/_src/math/compat_tensorflow.py +++ 
b/brainpy/_src/math/compat_tensorflow.py @@ -7,19 +7,19 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from .compat_numpy import ( - prod, min, sum, all, any, mean, std, var, concatenate, clip, - asarray, + prod, min, sum, all, any, mean, std, var, concatenate, clip, + asarray, ) from .ndarray import _return, _as_jax_array_ __all__ = [ - 'concat', - 'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all', 'reduce_any', - 'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance', 'reduce_euclidean_norm', - 'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum', 'unsorted_segment_prod', - 'unsorted_segment_max', 'unsorted_segment_min', 'unsorted_segment_mean', - 'segment_sum', 'segment_prod', 'segment_max', 'segment_min', - 'clip_by_value', 'cast', + 'concat', + 'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all', 'reduce_any', + 'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance', 'reduce_euclidean_norm', + 'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum', 'unsorted_segment_prod', + 'unsorted_segment_max', 'unsorted_segment_min', 'unsorted_segment_mean', + 'segment_sum', 'segment_prod', 'segment_max', 'segment_min', + 'clip_by_value', 'cast', ] reduce_prod = prod @@ -35,185 +35,185 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False): - """Computes log(sum(exp(elements across dimensions of a tensor))). + """Computes log(sum(exp(elements across dimensions of a tensor))). - Reduces `input_tensor` along the dimensions given in `axis`. + Reduces `input_tensor` along the dimensions given in `axis`. - Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - of the entries in `axis`, which must be unique. If `keepdims` is true, the - reduced dimensions are retained with length 1. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + of the entries in `axis`, which must be unique. 
If `keepdims` is true, the + reduced dimensions are retained with length 1. - If `axis` has no entries, all dimensions are reduced, and a - tensor with a single element is returned. + If `axis` has no entries, all dimensions are reduced, and a + tensor with a single element is returned. - This function is more numerically stable than log(sum(exp(input))). It avoids - overflows caused by taking the exp of large inputs and underflows caused by - taking the log of small inputs. + This function is more numerically stable than log(sum(exp(input))). It avoids + overflows caused by taking the exp of large inputs and underflows caused by + taking the log of small inputs. - Args: - input_tensor: The tensor to reduce. Should have numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all - dimensions. Must be in the range `[-rank(input_tensor), - rank(input_tensor))`. - keepdims: If true, retains reduced dimensions with length 1. + Args: + input_tensor: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. - Returns: - The reduced tensor. - """ - r = jnp.log(jnp.sum(jnp.exp(_as_jax_array_(input_tensor)), axis=axis, keepdims=keepdims)) - return _return(r) + Returns: + The reduced tensor. + """ + r = jnp.log(jnp.sum(jnp.exp(_as_jax_array_(input_tensor)), axis=axis, keepdims=keepdims)) + return _return(r) def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False): - """Computes the Euclidean norm of elements across dimensions of a tensor. - Reduces `input_tensor` along the dimensions given in `axis`. + """Computes the Euclidean norm of elements across dimensions of a tensor. + Reduces `input_tensor` along the dimensions given in `axis`. 
- Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - of the entries in `axis`, which must be unique. If `keepdims` is true, the - reduced dimensions are retained with length 1. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. - If `axis` is None, all dimensions are reduced, and a - tensor with a single element is returned. + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. - Args: - input_tensor: The tensor to reduce. Should have numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all - dimensions. Must be in the range `[-rank(input_tensor), - rank(input_tensor))`. - keepdims: If true, retains reduced dimensions with length 1. + Args: + input_tensor: The tensor to reduce. Should have numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. - Returns: - The reduced tensor, of the same dtype as the input_tensor. - """ - r = jnp.linalg.norm(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims) - return _return(r) + Returns: + The reduced tensor, of the same dtype as the input_tensor. + """ + r = jnp.linalg.norm(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims) + return _return(r) def reduce_max(input_tensor, axis=None, keepdims=False): - """Computes `maximum` of elements across dimensions of a tensor. - - This is the reduction operation for the elementwise `maximum` op. - Reduces `input_tensor` along the dimensions given in `axis`. + """Computes `maximum` of elements across dimensions of a tensor. - Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each - of the entries in `axis`, which must be unique. 
If `keepdims` is true, the - reduced dimensions are retained with length 1. + This is the reduction operation for the elementwise `maximum` op. + Reduces `input_tensor` along the dimensions given in `axis`. - If `axis` is None, all dimensions are reduced, and a - tensor with a single element is returned. + Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each + of the entries in `axis`, which must be unique. If `keepdims` is true, the + reduced dimensions are retained with length 1. - Args: - input_tensor: The tensor to reduce. Should have real numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all - dimensions. Must be in the range `[-rank(input_tensor), - rank(input_tensor))`. - keepdims: If true, retains reduced dimensions with length 1. + If `axis` is None, all dimensions are reduced, and a + tensor with a single element is returned. - Returns: - The reduced tensor. - """ - return _return(jnp.max(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims)) + Args: + input_tensor: The tensor to reduce. Should have real numeric type. + axis: The dimensions to reduce. If `None` (the default), reduces all + dimensions. Must be in the range `[-rank(input_tensor), + rank(input_tensor))`. + keepdims: If true, retains reduced dimensions with length 1. + + Returns: + The reduced tensor. + """ + return _return(jnp.max(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims)) def segment_mean(data, segment_ids): - """Computes the average along segments of a tensor. + """Computes the average along segments of a tensor. 
- See https://tensorflow.google.cn/api_docs/python/tf/math/segment_mean + See https://tensorflow.google.cn/api_docs/python/tf/math/segment_mean - """ - r = jax.ops.segment_sum(_as_jax_array_(data), - _as_jax_array_(segment_ids), - indices_are_sorted=False) - d = jax.ops.segment_sum(jnp.ones_like(data), - _as_jax_array_(segment_ids), - indices_are_sorted=False) - return _return(jnp.nan_to_num(r / d)) + """ + r = jax.ops.segment_sum(_as_jax_array_(data), + _as_jax_array_(segment_ids), + indices_are_sorted=False) + d = jax.ops.segment_sum(jnp.ones_like(data), + _as_jax_array_(segment_ids), + indices_are_sorted=False) + return _return(jnp.nan_to_num(r / d)) def unsorted_segment_sum(data, segment_ids, num_segments): - """Computes the sum along segments of a tensor. + """Computes the sum along segments of a tensor. - See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sum + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sum - """ - r = jax.ops.segment_sum(_as_jax_array_(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - return _return(r) + """ + r = jax.ops.segment_sum(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + return _return(r) def unsorted_segment_prod(data, segment_ids, num_segments): - """Computes the product along segments of a tensor. + """Computes the product along segments of a tensor. 
- See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_prod + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_prod - """ - r = jax.ops.segment_prod(_as_jax_array_(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - return _return(r) + """ + r = jax.ops.segment_prod(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + return _return(r) def unsorted_segment_max(data, segment_ids, num_segments): - """Computes the maximum along segments of a tensor. + """Computes the maximum along segments of a tensor. - See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_max + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_max - """ - r = jax.ops.segment_max(_as_jax_array_(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - return _return(r) + """ + r = jax.ops.segment_max(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + return _return(r) def unsorted_segment_min(data, segment_ids, num_segments): - """Computes the minimum along segments of a tensor. + """Computes the minimum along segments of a tensor. - See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_min + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_min - """ - r = jax.ops.segment_min(_as_jax_array_(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - return _return(r) + """ + r = jax.ops.segment_min(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + return _return(r) def unsorted_segment_sqrt_n(data, segment_ids, num_segments): - """Computes the sum along segments of a tensor divided by the sqrt(N). + """Computes the sum along segments of a tensor divided by the sqrt(N). 
- See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sqrt_n + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_sqrt_n - """ - r = jax.ops.segment_sum(_as_jax_array_(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - d = jax.ops.segment_sum(jnp.ones_like(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - return _return(jnp.nan_to_num(r / jnp.sqrt(d))) + """ + r = jax.ops.segment_sum(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + d = jax.ops.segment_sum(jnp.ones_like(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + return _return(jnp.nan_to_num(r / jnp.sqrt(d))) def unsorted_segment_mean(data, segment_ids, num_segments): - """Computes the average along segments of a tensor. + """Computes the average along segments of a tensor. - See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_mean + See https://tensorflow.google.cn/api_docs/python/tf/math/unsorted_segment_mean - """ - r = jax.ops.segment_sum(_as_jax_array_(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - d = jax.ops.segment_sum(jnp.ones_like(data), - _as_jax_array_(segment_ids), - num_segments=num_segments, - indices_are_sorted=False) - return _return(jnp.nan_to_num(r / d)) + """ + r = jax.ops.segment_sum(_as_jax_array_(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + d = jax.ops.segment_sum(jnp.ones_like(data), + _as_jax_array_(segment_ids), + num_segments=num_segments, + indices_are_sorted=False) + return _return(jnp.nan_to_num(r / d)) def segment_sum(data: Union[Array, jnp.ndarray], @@ -223,49 +223,49 @@ def segment_sum(data: Union[Array, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: 
Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_sum`` operator for brainpy `Array` and `Variable`. - - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. - """ - return _return(jax.ops.segment_sum(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + """``segment_sum`` operator for brainpy `Array` and `Variable`. + + Parameters:: + + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. 
+ num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns:: + + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return _return(jax.ops.segment_sum(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_prod(data: Union[Array, jnp.ndarray], @@ -275,49 +275,49 @@ def segment_prod(data: Union[Array, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_prod`` operator for brainpy `Array` and `Variable`. - - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. 
The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. - """ - return _return(jax.ops.segment_prod(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + """``segment_prod`` operator for brainpy `Array` and `Variable`. + + Parameters:: + + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. 
+ unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns:: + + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return _return(jax.ops.segment_prod(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_max(data: Union[Array, jnp.ndarray], @@ -327,49 +327,49 @@ def segment_max(data: Union[Array, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_max`` operator for brainpy `Array` and `Variable`. - - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. 
``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. - """ - return _return(jax.ops.segment_max(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + """``segment_max`` operator for brainpy `Array` and `Variable`. + + Parameters:: + + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. 
+ + Returns:: + + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return _return(jax.ops.segment_max(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_min(data: Union[Array, jnp.ndarray], @@ -379,76 +379,76 @@ def segment_min(data: Union[Array, jnp.ndarray], unique_indices: bool = False, bucket_size: Optional[int] = None, mode: Optional[lax.GatherScatterMode] = None) -> Array: - """``segment_min`` operator for brainpy `Array` and `Variable`. - - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. 
- """ - return _return(jax.ops.segment_min(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + """``segment_min`` operator for brainpy `Array` and `Variable`. + + Parameters:: + + data: Array + An array with the values to be reduced. + segment_ids: Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments: Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted: bool + whether ``segment_ids`` is known to be sorted. + unique_indices: bool + whether `segment_ids` is known to be free of duplicates. + bucket_size: int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode: lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns:: + + output: Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. + """ + return _return(jax.ops.segment_min(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def cast(x, dtype): - """Casts a tensor to a new type. - - The operation casts `x` (in case of `Tensor`) or `x.values` - (in case of `SparseTensor` or `IndexedSlices`) to `dtype`. 
- - The operation supports data types (for `x` and `dtype`) of - `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, - `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. - In case of casting from complex types (`complex64`, `complex128`) to real - types, only the real part of `x` is returned. In case of casting from real - types to complex types (`complex64`, `complex128`), the imaginary part of the - returned value is set to `0`. The handling of complex types here matches the - behavior of numpy. - - Note casting nan and inf values to integral types has undefined behavior. - - Args: - x: A `Array`. It could be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, - `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`, - `bfloat16`. - dtype: The destination type. The list of supported dtypes is the same as - `x`. - Returns: - A `Array` with same shape as `x` and same type as `dtype`. - - """ - return asarray(x, dtype=dtype) + """Casts a tensor to a new type. + + The operation casts `x` (in case of `Tensor`) or `x.values` + (in case of `SparseTensor` or `IndexedSlices`) to `dtype`. + + The operation supports data types (for `x` and `dtype`) of + `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, + `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. + In case of casting from complex types (`complex64`, `complex128`) to real + types, only the real part of `x` is returned. In case of casting from real + types to complex types (`complex64`, `complex128`), the imaginary part of the + returned value is set to `0`. The handling of complex types here matches the + behavior of numpy. + + Note casting nan and inf values to integral types has undefined behavior. + + Args: + x: A `Array`. It could be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, + `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`, + `bfloat16`. 
+ dtype: The destination type. The list of supported dtypes is the same as + `x`. + Returns: + A `Array` with same shape as `x` and same type as `dtype`. + + """ + return asarray(x, dtype=dtype) diff --git a/brainpy/_src/math/datatypes.py b/brainpy/_src/math/datatypes.py index 40e922103..efa6004f6 100644 --- a/brainpy/_src/math/datatypes.py +++ b/brainpy/_src/math/datatypes.py @@ -1,13 +1,12 @@ import jax.numpy as jnp __all__ = [ - 'uint8', 'uint16', 'uint32', 'uint64', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', - 'float64', 'complex64', 'complex128', + 'uint8', 'uint16', 'uint32', 'uint64', 'int8', 'int16', 'int32', 'int64', 'float16', 'float32', + 'float64', 'complex64', 'complex128', - 'bfloat16', 'half', 'float', 'double', 'cfloat', 'cdouble', 'short', 'int', 'long', 'bool' + 'bfloat16', 'half', 'float', 'double', 'cfloat', 'cdouble', 'short', 'int', 'long', 'bool' ] - uint8 = jnp.uint8 uint16 = jnp.uint16 uint32 = jnp.uint32 @@ -22,7 +21,6 @@ complex64 = jnp.complex64 complex128 = jnp.complex128 - # data types in PyTorch bfloat16 = jnp.bfloat16 half = jnp.float16 @@ -41,4 +39,3 @@ # qint8 = jnp.qint8 # qint32 = jnp.qint32 # quint4x2 = jnp.quint4x2 - diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 0f53f464e..ba3c46114 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable import numbers +from typing import Union, Callable import jax import jax.numpy as jnp @@ -20,20 +20,20 @@ from .object_transform.variables import Variable __all__ = [ - 'AbstractDelay', - 'TimeDelay', 'LengthDelay', - 'NeuTimeDelay', 'NeuLenDelay', - 'ROTATE_UPDATE', - 'CONCAT_UPDATE', + 'AbstractDelay', + 'TimeDelay', 'LengthDelay', + 'NeuTimeDelay', 'NeuLenDelay', + 'ROTATE_UPDATE', + 'CONCAT_UPDATE', ] def _as_jax_array(arr): - return arr.value if isinstance(arr, BaseArray) else arr + return arr.value if isinstance(arr, BaseArray) else arr 
class AbstractDelay(BrainPyObject): - pass + pass _FUNC_BEFORE = 'function' @@ -43,444 +43,444 @@ class AbstractDelay(BrainPyObject): class TimeDelay(AbstractDelay): - r"""Delay variable which has a fixed delay time length. - - For example, we create a delay variable which has a maximum delay length of 1 ms - - >>> import brainpy.math as bm - >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1) - >>> delay(-0.5) - [-0. -0. -0.] - - This function supports multiple dimensions of the tensor. For example, - - 1. the one-dimensional delay data - - >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) - >>> delay(-0.2) - [-0.2 -0.2 -0.2] - - 2. the two-dimensional delay data - - >>> delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t) - >>> delay(-0.6) - [[-0.6 -0.6] - [-0.6 -0.6] - [-0.6 -0.6]] - - 3. the three-dimensional delay data - - >>> delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t) - >>> delay(-0.8) - [[[-0.8] - [-0.8]] - [[-0.8] - [-0.8]] - [[-0.8] - [-0.8]]] - - Parameters:: - - delay_target: ArrayType - The initial delay data. - t0: float, int - The zero time. - delay_len: float, int - The maximum delay length. - dt: float, int - The time precesion. - before_t0: callable, bm.ndarray, jnp.ndarray, float, int - The delay data before ::math`t_0`. - - when `before_t0` is a function, it should receive a time argument `t` - - when `before_to` is a tensor, it should be a tensor with shape - of :math:`(num_delay, ...)`, where the longest delay data is aranged in - the first index. - name: str - The delay instance name. - interp_method: str - The way to deal with the delay at the time which is not integer times of the time step. 
- For exameple, if the time step ``dt=0.1``, the time delay length ``delay\_len=1.``, - when users require the delay data at ``t-0.53``, we can deal this situation with - the following methods: - - - ``"linear_interp"``: using linear interpolation to get the delay value - at the required time (default). - - ``"round"``: round the time to make it is the integer times of the time step. For - the above situation, we will use the time at ``t-0.5`` to approximate the delay data - at ``t-0.53``. - - .. versionadded:: 2.1.1 - - See Also:: - - LengthDelay - """ - - def __init__( - self, - delay_target: Union[ndarray, jnp.ndarray], - delay_len: Union[float, int], - before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None, - t0: Union[float, int] = 0., - dt: Union[float, int] = None, - name: str = None, - interp_method: str = 'linear_interp', - ): - super(TimeDelay, self).__init__(name=name) - - # shape - if not isinstance(delay_target, (jnp.ndarray, BaseArray)): - raise ValueError(f'Must be an instance of BaseArray or jax.numpy.ndarray. But we got {type(delay_target)}') - - # delay_len - self.t0 = t0 - self.dt = get_dt() if dt is None else dt - is_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.) 
- self.delay_len = delay_len - self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1 - - # interp method - if interp_method not in [_INTERP_LINEAR, _INTERP_ROUND]: - raise UnsupportedError(f'Un-supported interpolation method {interp_method}, ' - f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') - self.interp_method = interp_method - - # time variables - self.idx = Variable(jnp.asarray([0])) - is_float(t0, 't0', allow_none=False, allow_int=True, ) - self.current_time = Variable(jnp.asarray([t0], dtype=get_float())) - - # delay data - batch_axis = None - if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): - batch_axis = delay_target.batch_axis + 1 - self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, - dtype=delay_target.dtype), - batch_axis=batch_axis) - if before_t0 is None: - self._before_type = _DATA_BEFORE - elif callable(before_t0): - self._before_t0 = lambda t: as_jax(broadcast_to(before_t0(t), delay_target.shape), - dtype=delay_target.dtype) - self._before_type = _FUNC_BEFORE - elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): - self._before_type = _DATA_BEFORE - self.data[:-1] = before_t0 - else: - raise ValueError(f'"before_t0" does not support {type(before_t0)}') - # set initial data - self.data[-1] = delay_target - - # interpolation function - self._interp_fun = jnp.interp - for dim in range(1, delay_target.ndim + 1, 1): - self._interp_fun = vmap(self._interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) - - def reset(self, - delay_target, - delay_len, - t0: Union[float, int] = 0., - before_t0=None): - """Reset the delay variable. + r"""Delay variable which has a fixed delay time length. - Parameters:: + For example, we create a delay variable which has a maximum delay length of 1 ms - delay_target: ArrayType - The delay target. - delay_len: float, int - The maximum delay length. The unit is the time. - t0: int, float - The zero time. 
- before_t0: int, float, ArrayType - The data before t0. - """ - self.delay_len = delay_len - self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1 - self.data.value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) - self.data[-1] = delay_target - self.idx = Variable(jnp.asarray([0])) - self.current_time = Variable(jnp.asarray([t0])) - if before_t0 is not None: - if not isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): - raise ValueError('Only support numerical values.') - self.data[:-1] = before_t0 - self._before_type = _DATA_BEFORE - - def _check_time1(self, times): - prev_time, current_time = times - raise ValueError(f'The request time should be less than the ' - f'current time {current_time}. But we ' - f'got {prev_time} > {current_time}') - - def _check_time2(self, times): - prev_time, current_time = times - raise ValueError(f'The request time of the variable should be in ' - f'[{current_time - self.delay_len}, {current_time}], ' - f'but we got {prev_time}') - - def __call__(self, time, indices=None): - # check - if check.is_checking(): - current_time = self.current_time[0] - jit_error(time > current_time + 1e-6, - self._check_time1, - (time, current_time)) - jit_error(time < current_time - self.delay_len - self.dt, - self._check_time2, - (time, current_time)) - if self._before_type == _FUNC_BEFORE: - res = cond(time < self.t0, - self._before_t0, - self._after_t0, - time) - else: - res = self._after_t0(time) - if indices is not None: # TODO: indices is highly inefficient - res = res[indices] - return res - - def _after_t0(self, prev_time): - diff = self.delay_len - (self.current_time[0] - prev_time) - if isinstance(diff, ndarray): - diff = diff.value - if self.interp_method == _INTERP_LINEAR: - req_num_step = jnp.asarray(diff / self.dt, dtype=jnp.int32) - extra = diff - req_num_step * self.dt - return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra)) - elif self.interp_method == 
_INTERP_ROUND: - req_num_step = jnp.asarray(jnp.round(diff / self.dt), dtype=jnp.int32) - return self._true_fn(req_num_step, 0.) - else: - raise UnsupportedError(f'Un-supported interpolation method {self.interp_method}, ' - f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') - - def _true_fn(self, req_num_step, extra): - return self.data[self.idx[0] + req_num_step] - - def _false_fn(self, req_num_step, extra): - idx = jnp.asarray([self.idx[0] + req_num_step, - self.idx[0] + req_num_step + 1]) - idx %= self.num_delay_step - return self._interp_fun(extra, jnp.asarray([0., self.dt]), self.data[idx]) - - def update(self, value): - self.data[self.idx[0]] = value - self.current_time += self.dt - self.idx.value = (self.idx + 1) % self.num_delay_step + >>> import brainpy.math as bm + >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1) + >>> delay(-0.5) + [-0. -0. -0.] + This function supports multiple dimensions of the tensor. For example, -class NeuTimeDelay(TimeDelay): - """Neutral Time Delay. Alias of :py:class:`~.TimeDelay`.""" - pass + 1. the one-dimensional delay data + >>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) + >>> delay(-0.2) + [-0.2 -0.2 -0.2] -ROTATE_UPDATE = 'rotation' -CONCAT_UPDATE = 'concat' + 2. the two-dimensional delay data + >>> delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t) + >>> delay(-0.6) + [[-0.6 -0.6] + [-0.6 -0.6] + [-0.6 -0.6]] -class LengthDelay(AbstractDelay): - """Delay variable which has a fixed delay length. + 3. the three-dimensional delay data - Parameters:: + >>> delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t) + >>> delay(-0.8) + [[[-0.8] + [-0.8]] + [[-0.8] + [-0.8]] + [[-0.8] + [-0.8]]] - delay_target: int, sequence of int - The initial delay data. - delay_len: int - The maximum delay length. - initial_delay_data: Any - The delay data. It can be a Python number, like float, int, boolean values. 
- It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: + Parameters:: - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = delay_len-1 data - delay = delay_len data ] + delay_target: ArrayType + The initial delay data. + t0: float, int + The zero time. + delay_len: float, int + The maximum delay length. + dt: float, int + The time precesion. + before_t0: callable, bm.ndarray, jnp.ndarray, float, int + The delay data before ::math`t_0`. + - when `before_t0` is a function, it should receive a time argument `t` + - when `before_to` is a tensor, it should be a tensor with shape + of :math:`(num_delay, ...)`, where the longest delay data is aranged in + the first index. + name: str + The delay instance name. + interp_method: str + The way to deal with the delay at the time which is not integer times of the time step. + For exameple, if the time step ``dt=0.1``, the time delay length ``delay\_len=1.``, + when users require the delay data at ``t-0.53``, we can deal this situation with + the following methods: + + - ``"linear_interp"``: using linear interpolation to get the delay value + at the required time (default). + - ``"round"``: round the time to make it is the integer times of the time step. For + the above situation, we will use the time at ``t-0.5`` to approximate the delay data + at ``t-0.53``. + + .. versionadded:: 2.1.1 + + See Also:: + + LengthDelay + """ - .. 
versionchanged:: 2.2.3.2 + def __init__( + self, + delay_target: Union[ndarray, jnp.ndarray], + delay_len: Union[float, int], + before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None, + t0: Union[float, int] = 0., + dt: Union[float, int] = None, + name: str = None, + interp_method: str = 'linear_interp', + ): + super(TimeDelay, self).__init__(name=name) + + # shape + if not isinstance(delay_target, (jnp.ndarray, BaseArray)): + raise ValueError(f'Must be an instance of BaseArray or jax.numpy.ndarray. But we got {type(delay_target)}') + + # delay_len + self.t0 = t0 + self.dt = get_dt() if dt is None else dt + is_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.) + self.delay_len = delay_len + self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1 + + # interp method + if interp_method not in [_INTERP_LINEAR, _INTERP_ROUND]: + raise UnsupportedError(f'Un-supported interpolation method {interp_method}, ' + f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') + self.interp_method = interp_method + + # time variables + self.idx = Variable(jnp.asarray([0])) + is_float(t0, 't0', allow_none=False, allow_int=True, ) + self.current_time = Variable(jnp.asarray([t0], dtype=get_float())) + + # delay data + batch_axis = None + if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None): + batch_axis = delay_target.batch_axis + 1 + self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, + dtype=delay_target.dtype), + batch_axis=batch_axis) + if before_t0 is None: + self._before_type = _DATA_BEFORE + elif callable(before_t0): + self._before_t0 = lambda t: as_jax(broadcast_to(before_t0(t), delay_target.shape), + dtype=delay_target.dtype) + self._before_type = _FUNC_BEFORE + elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): + self._before_type = _DATA_BEFORE + self.data[:-1] = before_t0 + else: + raise ValueError(f'"before_t0" does not support {type(before_t0)}') + # set 
initial data + self.data[-1] = delay_target + + # interpolation function + self._interp_fun = jnp.interp + for dim in range(1, delay_target.ndim + 1, 1): + self._interp_fun = vmap(self._interp_fun, in_axes=(None, None, dim), out_axes=dim - 1) + + def reset(self, + delay_target, + delay_len, + t0: Union[float, int] = 0., + before_t0=None): + """Reset the delay variable. + + Parameters:: + + delay_target: ArrayType + The delay target. + delay_len: float, int + The maximum delay length. The unit is the time. + t0: int, float + The zero time. + before_t0: int, float, ArrayType + The data before t0. + """ + self.delay_len = delay_len + self.num_delay_step = int(jnp.ceil(self.delay_len / self.dt)) + 1 + self.data.value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) + self.data[-1] = delay_target + self.idx = Variable(jnp.asarray([0])) + self.current_time = Variable(jnp.asarray([t0])) + if before_t0 is not None: + if not isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): + raise ValueError('Only support numerical values.') + self.data[:-1] = before_t0 + self._before_type = _DATA_BEFORE + + def _check_time1(self, times): + prev_time, current_time = times + raise ValueError(f'The request time should be less than the ' + f'current time {current_time}. 
But we ' + f'got {prev_time} > {current_time}') + + def _check_time2(self, times): + prev_time, current_time = times + raise ValueError(f'The request time of the variable should be in ' + f'[{current_time - self.delay_len}, {current_time}], ' + f'but we got {prev_time}') + + def __call__(self, time, indices=None): + # check + if check.is_checking(): + current_time = self.current_time[0] + jit_error(time > current_time + 1e-6, + self._check_time1, + (time, current_time)) + jit_error(time < current_time - self.delay_len - self.dt, + self._check_time2, + (time, current_time)) + if self._before_type == _FUNC_BEFORE: + res = cond(time < self.t0, + self._before_t0, + self._after_t0, + time) + else: + res = self._after_t0(time) + if indices is not None: # TODO: indices is highly inefficient + res = res[indices] + return res + + def _after_t0(self, prev_time): + diff = self.delay_len - (self.current_time[0] - prev_time) + if isinstance(diff, ndarray): + diff = diff.value + if self.interp_method == _INTERP_LINEAR: + req_num_step = jnp.asarray(diff / self.dt, dtype=jnp.int32) + extra = diff - req_num_step * self.dt + return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra)) + elif self.interp_method == _INTERP_ROUND: + req_num_step = jnp.asarray(jnp.round(diff / self.dt), dtype=jnp.int32) + return self._true_fn(req_num_step, 0.) 
+ else: + raise UnsupportedError(f'Un-supported interpolation method {self.interp_method}, ' + f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') + + def _true_fn(self, req_num_step, extra): + return self.data[self.idx[0] + req_num_step] + + def _false_fn(self, req_num_step, extra): + idx = jnp.asarray([self.idx[0] + req_num_step, + self.idx[0] + req_num_step + 1]) + idx %= self.num_delay_step + return self._interp_fun(extra, jnp.asarray([0., self.dt]), self.data[idx]) + + def update(self, value): + self.data[self.idx[0]] = value + self.current_time += self.dt + self.idx.value = (self.idx + 1) % self.num_delay_step - The data in the previous version of ``LengthDelay`` is:: - delay = delay_len [ data - delay = delay_len-1 data - ... .... - ... .... - delay = 2 data - delay = 1 data ] - - - name: str - The delay object name. - batch_axis: int - The batch axis. If not provided, it will be inferred from the `delay_target`. - update_method: str - The method used for updating delay. - - See Also:: - - TimeDelay - """ - - def __init__( - self, - delay_target: Union[ndarray, jax.Array], - delay_len: int, - initial_delay_data: Union[float, int, bool, ndarray, jax.Array, Callable] = None, - name: str = None, - batch_axis: int = None, - update_method: str = ROTATE_UPDATE - ): - super(LengthDelay, self).__init__(name=name) - - assert update_method in [ROTATE_UPDATE, CONCAT_UPDATE] - self.update_method = update_method - # attributes and variables - self.data: Variable = None - self.num_delay_step: int = 0 - self.idx: Variable = None - - self.delay_target = None - if isinstance(delay_target, Variable): - self.delay_target = delay_target - - # initialization - self.reset(delay_target, delay_len, initial_delay_data, batch_axis) - - @property - def delay_shape(self): - """The data shape of this delay variable.""" - return self.data.shape - - @property - def delay_target_shape(self): - """The data shape of the delay target.""" - return self.data.shape[1:] - - def __repr__(self): 
- name = self.__class__.__name__ - return (f'{name}(num_delay_step={self.num_delay_step}, ' - f'delay_target_shape={self.delay_target_shape}, ' - f'update_method={self.update_method})') - - def reset( - self, - delay_target, - delay_len: int = None, - initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None, - batch_axis: int = None - ): - if not isinstance(delay_target, (ndarray, jnp.ndarray)): - raise ValueError(f'Must be an instance of brainpy.math.ndarray ' - f'or jax.numpy.ndarray. But we got {type(delay_target)}') - - # delay_len - is_integer(delay_len, 'delay_len', allow_none=True, min_bound=0) - if delay_len is None: - if self.num_delay_step is None: - raise ValueError('"delay_len" cannot be None.') - delay_len = self.num_delay_step - 1 - self.num_delay_step = delay_len + 1 - - # initialize delay data - if self.data is None: - if batch_axis is None: - if isinstance(delay_target, Variable) and (delay_target.batch_axis is not None): - batch_axis = delay_target.batch_axis + 1 - self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, - dtype=delay_target.dtype), - batch_axis=batch_axis) - else: - self.data.value - self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, - dtype=delay_target.dtype) - - # update delay data - self.data[0] = delay_target - if initial_delay_data is None: - pass - elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)): - self.data[1:] = initial_delay_data - elif callable(initial_delay_data): - self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape, - dtype=delay_target.dtype) - else: - raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') +class NeuTimeDelay(TimeDelay): + """Neutral Time Delay. 
Alias of :py:class:`~.TimeDelay`.""" + pass - # time variables - if self.update_method == ROTATE_UPDATE: - if self.idx is None: - self.idx = Variable(stop_gradient(jnp.asarray([0], dtype=jnp.int32))) - else: - self.idx.value = stop_gradient(jnp.asarray([0], dtype=jnp.int32)) - def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.num_delay_step}. ' - f'But we got {delay_len}') +ROTATE_UPDATE = 'rotation' +CONCAT_UPDATE = 'concat' - def __call__(self, delay_len, *indices): - return self.retrieve(delay_len, *indices) - def retrieve(self, delay_len, *indices): - """Retrieve the delay data acoording to the delay length. +class LengthDelay(AbstractDelay): + """Delay variable which has a fixed delay length. Parameters:: - delay_len: int, ArrayType - The delay length used to retrieve the data. - """ - if check.is_checking(): - jit_error(jnp.any(as_jax(delay_len >= self.num_delay_step)), self._check_delay, delay_len) + delay_target: int, sequence of int + The initial delay data. + delay_len: int + The maximum delay length. + initial_delay_data: Any + The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: - if self.update_method == ROTATE_UPDATE: - delay_idx = (self.idx[0] + delay_len) % self.num_delay_step - delay_idx = stop_gradient(delay_idx) + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = delay_len-1 data + delay = delay_len data ] - elif self.update_method == CONCAT_UPDATE: - delay_idx = delay_len + .. 
versionchanged:: 2.2.3.2 - else: - raise ValueError(f'Unknown updating method "{self.update_method}"') + The data in the previous version of ``LengthDelay`` is:: - # the delay index - if isinstance(delay_idx, int): - pass - elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') - indices = (delay_idx,) + tuple(indices) - # the delay data - return self.data[indices] + delay = delay_len [ data + delay = delay_len-1 data + ... .... + ... .... + delay = 2 data + delay = 1 data ] - def update(self, value: Union[numbers.Number, BaseArray, jax.Array] = None): - """Update delay variable with the new data. - Parameters:: + name: str + The delay object name. + batch_axis: int + The batch axis. If not provided, it will be inferred from the `delay_target`. + update_method: str + The method used for updating delay. - value: Any - The value of the latest data, used to update this delay variable. 
- """ - if value is None: - if self.delay_target is None: - raise ValueError('Must provide value.') - else: - value = self.delay_target.value + See Also:: - if self.update_method == ROTATE_UPDATE: - self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.num_delay_step)) - self.data[self.idx[0]] = value + TimeDelay + """ - elif self.update_method == CONCAT_UPDATE: - if self.num_delay_step >= 2: - self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) - else: - self.data[:] = value + def __init__( + self, + delay_target: Union[ndarray, jax.Array], + delay_len: int, + initial_delay_data: Union[float, int, bool, ndarray, jax.Array, Callable] = None, + name: str = None, + batch_axis: int = None, + update_method: str = ROTATE_UPDATE + ): + super(LengthDelay, self).__init__(name=name) + + assert update_method in [ROTATE_UPDATE, CONCAT_UPDATE] + self.update_method = update_method + # attributes and variables + self.data: Variable = None + self.num_delay_step: int = 0 + self.idx: Variable = None + + self.delay_target = None + if isinstance(delay_target, Variable): + self.delay_target = delay_target + + # initialization + self.reset(delay_target, delay_len, initial_delay_data, batch_axis) + + @property + def delay_shape(self): + """The data shape of this delay variable.""" + return self.data.shape + + @property + def delay_target_shape(self): + """The data shape of the delay target.""" + return self.data.shape[1:] + + def __repr__(self): + name = self.__class__.__name__ + return (f'{name}(num_delay_step={self.num_delay_step}, ' + f'delay_target_shape={self.delay_target_shape}, ' + f'update_method={self.update_method})') + + def reset( + self, + delay_target, + delay_len: int = None, + initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None, + batch_axis: int = None + ): + if not isinstance(delay_target, (ndarray, jnp.ndarray)): + raise ValueError(f'Must be an instance of brainpy.math.ndarray ' + f'or jax.numpy.ndarray. 
But we got {type(delay_target)}') + + # delay_len + is_integer(delay_len, 'delay_len', allow_none=True, min_bound=0) + if delay_len is None: + if self.num_delay_step is None: + raise ValueError('"delay_len" cannot be None.') + delay_len = self.num_delay_step - 1 + self.num_delay_step = delay_len + 1 + + # initialize delay data + if self.data is None: + if batch_axis is None: + if isinstance(delay_target, Variable) and (delay_target.batch_axis is not None): + batch_axis = delay_target.batch_axis + 1 + self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape, + dtype=delay_target.dtype), + batch_axis=batch_axis) + else: + self.data.value + self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, + dtype=delay_target.dtype) - else: - raise ValueError(f'Unknown updating method "{self.update_method}"') + # update delay data + self.data[0] = delay_target + if initial_delay_data is None: + pass + elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)): + self.data[1:] = initial_delay_data + elif callable(initial_delay_data): + self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape, + dtype=delay_target.dtype) + else: + raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') + + # time variables + if self.update_method == ROTATE_UPDATE: + if self.idx is None: + self.idx = Variable(stop_gradient(jnp.asarray([0], dtype=jnp.int32))) + else: + self.idx.value = stop_gradient(jnp.asarray([0], dtype=jnp.int32)) + + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.num_delay_step}. ' + f'But we got {delay_len}') + + def __call__(self, delay_len, *indices): + return self.retrieve(delay_len, *indices) + + def retrieve(self, delay_len, *indices): + """Retrieve the delay data acoording to the delay length. + + Parameters:: + + delay_len: int, ArrayType + The delay length used to retrieve the data. 
+ """ + if check.is_checking(): + jit_error(jnp.any(as_jax(delay_len >= self.num_delay_step)), self._check_delay, delay_len) + + if self.update_method == ROTATE_UPDATE: + delay_idx = (self.idx[0] + delay_len) % self.num_delay_step + delay_idx = stop_gradient(delay_idx) + + elif self.update_method == CONCAT_UPDATE: + delay_idx = delay_len + + else: + raise ValueError(f'Unknown updating method "{self.update_method}"') + + # the delay index + if isinstance(delay_idx, int): + pass + elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + tuple(indices) + # the delay data + return self.data[indices] + + def update(self, value: Union[numbers.Number, BaseArray, jax.Array] = None): + """Update delay variable with the new data. + + Parameters:: + + value: Any + The value of the latest data, used to update this delay variable. + """ + if value is None: + if self.delay_target is None: + raise ValueError('Must provide value.') + else: + value = self.delay_target.value + + if self.update_method == ROTATE_UPDATE: + self.idx.value = stop_gradient(as_jax((self.idx - 1) % self.num_delay_step)) + self.data[self.idx[0]] = value + + elif self.update_method == CONCAT_UPDATE: + if self.num_delay_step >= 2: + self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) + else: + self.data[:] = value + + else: + raise ValueError(f'Unknown updating method "{self.update_method}"') class NeuLenDelay(LengthDelay): - """Neutral Length Delay. Alias of :py:class:`~.LengthDelay`.""" - pass + """Neutral Length Delay. 
Alias of :py:class:`~.LengthDelay`.""" + pass diff --git a/brainpy/_src/math/einops.py b/brainpy/_src/math/einops.py index 1dc617021..f6d054485 100644 --- a/brainpy/_src/math/einops.py +++ b/brainpy/_src/math/einops.py @@ -12,7 +12,7 @@ from .ndarray import BaseArray __all__ = [ - 'ein_reduce', 'ein_rearrange', 'ein_repeat', 'ein_shape', + 'ein_reduce', 'ein_rearrange', 'ein_repeat', 'ein_shape', ] Tensor = Union[BaseArray, jax.Array] @@ -28,98 +28,98 @@ def _product(sequence: List[int]) -> int: - """minimalistic product that works both with numbers and symbols. Supports empty lists""" - result = 1 - for element in sequence: - result *= element - return result + """minimalistic product that works both with numbers and symbols. Supports empty lists""" + result = 1 + for element in sequence: + result *= element + return result def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int]): - if callable(reduction_type): - # custom callable - return reduction_type(tensor, tuple(reduced_axes)) - else: - # one of built-in operations - assert reduction_type in _reductions - if reduction_type == "mean": - if not bnp2.is_float_type(tensor): - raise NotImplementedError("reduce_mean is not available for non-floating tensors") - return __reduce(tensor, reduction_type, tuple(reduced_axes)) + if callable(reduction_type): + # custom callable + return reduction_type(tensor, tuple(reduced_axes)) + else: + # one of built-in operations + assert reduction_type in _reductions + if reduction_type == "mean": + if not bnp2.is_float_type(tensor): + raise NotImplementedError("reduce_mean is not available for non-floating tensors") + return __reduce(tensor, reduction_type, tuple(reduced_axes)) def __reduce(x: Union[BaseArray, jax.Array], operation: str, reduced_axes): - if operation == "min": - return x.min(axis=reduced_axes) - elif operation == "max": - return x.max(axis=reduced_axes) - elif operation == "sum": - return x.sum(axis=reduced_axes) - elif operation == "mean": - 
return x.mean(axis=reduced_axes) - elif operation == "prod": - return x.prod(axis=reduced_axes) - elif operation == "any": - return x.any(axis=reduced_axes) - elif operation == "all": - return x.all(axis=reduced_axes) - else: - raise NotImplementedError("Unknown reduction ", operation) + if operation == "min": + return x.min(axis=reduced_axes) + elif operation == "max": + return x.max(axis=reduced_axes) + elif operation == "sum": + return x.sum(axis=reduced_axes) + elif operation == "mean": + return x.mean(axis=reduced_axes) + elif operation == "prod": + return x.prod(axis=reduced_axes) + elif operation == "any": + return x.any(axis=reduced_axes) + elif operation == "all": + return x.all(axis=reduced_axes) + else: + raise NotImplementedError("Unknown reduction ", operation) def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): - # 'collapses' neighboring axes if those participate in the result pattern in the same order - # TODO add support for added_axes - assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) - # joining consecutive axes that will be reduced - # possibly we can skip this if all backends can optimize this (not sure) - reduced_axes = tuple(sorted(reduced_axes)) - for i in range(len(reduced_axes) - 1)[::-1]: - if reduced_axes[i] + 1 == reduced_axes[i + 1]: - removed_axis = reduced_axes[i + 1] - removed_length = init_shapes[removed_axis] - init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] - init_shapes[removed_axis - 1] *= removed_length - reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:]) - - # removing axes that are moved together during reshape - def build_mapping(): - init_to_final = {} - for axis in range(len(init_shapes)): - if axis in reduced_axes: - init_to_final[axis] = None - else: - after_reduction = sum(x is not None for x in init_to_final.values()) - init_to_final[axis] = list(axes_reordering).index(after_reduction) - return 
init_to_final - - init_axis_to_final_axis = build_mapping() - - for init_axis in range(len(init_shapes) - 1)[::-1]: - if init_axis_to_final_axis[init_axis] is None: - continue - if init_axis_to_final_axis[init_axis + 1] is None: - continue - if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: - removed_axis = init_axis + 1 - removed_length = init_shapes[removed_axis] - removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) - - reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) - init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] - init_shapes[removed_axis - 1] *= removed_length - old_reordering = axes_reordering - axes_reordering = [] - for axis in old_reordering: - if axis == removed_axis_after_reduction: - pass - elif axis < removed_axis_after_reduction: - axes_reordering.append(axis) - else: - axes_reordering.append(axis - 1) - init_axis_to_final_axis = build_mapping() - - return init_shapes, reduced_axes, axes_reordering, final_shapes + # 'collapses' neighboring axes if those participate in the result pattern in the same order + # TODO add support for added_axes + assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) + # joining consecutive axes that will be reduced + # possibly we can skip this if all backends can optimize this (not sure) + reduced_axes = tuple(sorted(reduced_axes)) + for i in range(len(reduced_axes) - 1)[::-1]: + if reduced_axes[i] + 1 == reduced_axes[i + 1]: + removed_axis = reduced_axes[i + 1] + removed_length = init_shapes[removed_axis] + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + reduced_axes = reduced_axes[: i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:]) + + # removing axes that are moved together during reshape + def build_mapping(): + init_to_final = {} + for axis in range(len(init_shapes)): + if axis in 
reduced_axes: + init_to_final[axis] = None + else: + after_reduction = sum(x is not None for x in init_to_final.values()) + init_to_final[axis] = list(axes_reordering).index(after_reduction) + return init_to_final + + init_axis_to_final_axis = build_mapping() + + for init_axis in range(len(init_shapes) - 1)[::-1]: + if init_axis_to_final_axis[init_axis] is None: + continue + if init_axis_to_final_axis[init_axis + 1] is None: + continue + if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: + removed_axis = init_axis + 1 + removed_length = init_shapes[removed_axis] + removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) + + reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + old_reordering = axes_reordering + axes_reordering = [] + for axis in old_reordering: + if axis == removed_axis_after_reduction: + pass + elif axis < removed_axis_after_reduction: + axes_reordering.append(axis) + else: + axes_reordering.append(axis - 1) + init_axis_to_final_axis = build_mapping() + + return init_shapes, reduced_axes, axes_reordering, final_shapes CookedRecipe = Tuple[Optional[List[int]], Optional[List[int]], List[int], Dict[int, int], Optional[List[int]], int] @@ -132,116 +132,116 @@ def build_mapping(): class TransformRecipe: - """ - Recipe describes actual computation pathway. - Recipe can be applied to a tensor or variable. - """ - - # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+) - # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided - - def __init__( - self, - # list of sizes (or just sizes) for elementary axes as they appear in left expression. - # this is what (after computing unknown parts) will be a shape after first transposition. 
- # This does not include any ellipsis dimensions. - elementary_axes_lengths: List[int], - # if additional axes are provided, they should be set in prev array - # This shows mapping from name to position - axis_name2elementary_axis: Dict[str, int], - # each dimension in input can help to reconstruct length of one elementary axis - # or verify one of dimensions. Each element points to element of elementary_axes_lengths. - input_composition_known_unknown: List[Tuple[List[int], List[int]]], - # permutation applied to elementary axes, if ellipsis is absent - axes_permutation: List[int], - # permutation puts reduced axes in the end, we only need to know the first position. - first_reduced_axis: int, - # at which positions which of elementary axes should appear. Axis position -> axis index. - added_axes: Dict[int, int], - # ids of axes as they appear in result, again pointers to elementary_axes_lengths, - # only used to infer result dimensions - output_composite_axes: List[List[int]], - ): - self.elementary_axes_lengths: List[int] = elementary_axes_lengths - self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis - self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = input_composition_known_unknown - self.axes_permutation: List[int] = axes_permutation - - self.first_reduced_axis: int = first_reduced_axis - self.added_axes: Dict[int, int] = added_axes - self.output_composite_axes: List[List[int]] = output_composite_axes + """ + Recipe describes actual computation pathway. + Recipe can be applied to a tensor or variable. + """ + + # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+) + # update: pytorch 2.0 torch.jit.script seems to have problems with dataclasses unless they were explicitly provided + + def __init__( + self, + # list of sizes (or just sizes) for elementary axes as they appear in left expression. 
+ # this is what (after computing unknown parts) will be a shape after first transposition. + # This does not include any ellipsis dimensions. + elementary_axes_lengths: List[int], + # if additional axes are provided, they should be set in prev array + # This shows mapping from name to position + axis_name2elementary_axis: Dict[str, int], + # each dimension in input can help to reconstruct length of one elementary axis + # or verify one of dimensions. Each element points to element of elementary_axes_lengths. + input_composition_known_unknown: List[Tuple[List[int], List[int]]], + # permutation applied to elementary axes, if ellipsis is absent + axes_permutation: List[int], + # permutation puts reduced axes in the end, we only need to know the first position. + first_reduced_axis: int, + # at which positions which of elementary axes should appear. Axis position -> axis index. + added_axes: Dict[int, int], + # ids of axes as they appear in result, again pointers to elementary_axes_lengths, + # only used to infer result dimensions + output_composite_axes: List[List[int]], + ): + self.elementary_axes_lengths: List[int] = elementary_axes_lengths + self.axis_name2elementary_axis: Dict[str, int] = axis_name2elementary_axis + self.input_composition_known_unknown: List[Tuple[List[int], List[int]]] = input_composition_known_unknown + self.axes_permutation: List[int] = axes_permutation + + self.first_reduced_axis: int = first_reduced_axis + self.added_axes: Dict[int, int] = added_axes + self.output_composite_axes: List[List[int]] = output_composite_axes def _reconstruct_from_shape_uncached( self: TransformRecipe, shape: List[int], axes_dims: FakeHashableAxesLengths ) -> CookedRecipe: - """ - Reconstruct all actual parameters using shape. - Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet) - known axes can be integers or symbols, but not Nones. 
- """ - # magic number - need_init_reshape = False - - # last axis is allocated for collapsed ellipsis - axes_lengths: List[int] = list(self.elementary_axes_lengths) - for axis, dim in axes_dims: - axes_lengths[self.axis_name2elementary_axis[axis]] = dim - - for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown): - length = shape[input_axis] - if len(known_axes) == 0 and len(unknown_axes) == 1: - # shortcut for the most common case - axes_lengths[unknown_axes[0]] = length - continue - - known_product = 1 - for axis in known_axes: - known_product *= axes_lengths[axis] - - if len(unknown_axes) == 0: - if isinstance(length, int) and isinstance(known_product, int) and length != known_product: - raise EinopsError(f"Shape mismatch, {length} != {known_product}") - else: - # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out' - if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: - raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}") + """ + Reconstruct all actual parameters using shape. + Shape is a tuple that may contain integers, shape symbols (tf, theano) and UnknownSize (tf, previously mxnet) + known axes can be integers or symbols, but not Nones. 
+ """ + # magic number + need_init_reshape = False + + # last axis is allocated for collapsed ellipsis + axes_lengths: List[int] = list(self.elementary_axes_lengths) + for axis, dim in axes_dims: + axes_lengths[self.axis_name2elementary_axis[axis]] = dim + + for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composition_known_unknown): + length = shape[input_axis] + if len(known_axes) == 0 and len(unknown_axes) == 1: + # shortcut for the most common case + axes_lengths[unknown_axes[0]] = length + continue + + known_product = 1 + for axis in known_axes: + known_product *= axes_lengths[axis] + + if len(unknown_axes) == 0: + if isinstance(length, int) and isinstance(known_product, int) and length != known_product: + raise EinopsError(f"Shape mismatch, {length} != {known_product}") + else: + # assert len(unknown_axes) == 1, 'this is enforced when recipe is created, so commented out' + if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: + raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}") - unknown_axis = unknown_axes[0] - inferred_length: int = length // known_product - axes_lengths[unknown_axis] = inferred_length + unknown_axis = unknown_axes[0] + inferred_length: int = length // known_product + axes_lengths[unknown_axis] = inferred_length - if len(known_axes) + len(unknown_axes) != 1: - need_init_reshape = True + if len(known_axes) + len(unknown_axes) != 1: + need_init_reshape = True - # at this point all axes_lengths are computed (either have values or variables, but not Nones) + # at this point all axes_lengths are computed (either have values or variables, but not Nones) - # elementary axes are ordered as they appear in input, then all added axes - init_shapes: Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None + # elementary axes are ordered as they appear in input, then all added axes + init_shapes: 
Optional[List[int]] = axes_lengths[: len(self.axes_permutation)] if need_init_reshape else None - need_final_reshape = False - final_shapes: List[int] = [] - for grouping in self.output_composite_axes: - lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] - final_shapes.append(_product(lengths)) - if len(lengths) != 1: - need_final_reshape = True + need_final_reshape = False + final_shapes: List[int] = [] + for grouping in self.output_composite_axes: + lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] + final_shapes.append(_product(lengths)) + if len(lengths) != 1: + need_final_reshape = True - added_axes: Dict[int, int] = { - pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items() - } + added_axes: Dict[int, int] = { + pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items() + } - # this list can be empty - reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation))) + # this list can be empty + reduced_axes = list(range(self.first_reduced_axis, len(self.axes_permutation))) - n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation) + n_axes_after_adding_axes = len(added_axes) + len(self.axes_permutation) - axes_reordering: Optional[List[int]] = self.axes_permutation - if self.axes_permutation == list(range(len(self.axes_permutation))): - axes_reordering = None + axes_reordering: Optional[List[int]] = self.axes_permutation + if self.axes_permutation == list(range(len(self.axes_permutation))): + axes_reordering = None - _final_shapes = final_shapes if need_final_reshape else None - return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, n_axes_after_adding_axes + _final_shapes = final_shapes if need_final_reshape else None + return init_shapes, axes_reordering, reduced_axes, added_axes, _final_shapes, n_axes_after_adding_axes _reconstruct_from_shape = 
functools.lru_cache(1024)(_reconstruct_from_shape_uncached) @@ -250,59 +250,59 @@ def _reconstruct_from_shape_uncached( def _apply_recipe( recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths ) -> Tensor: - # this method implements actual work for all backends for 3 operations - try: - init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = ( - _reconstruct_from_shape(recipe, bnp.shape(tensor), axes_lengths)) - except TypeError: - # shape or one of passed axes lengths is not hashable (i.e. they are symbols) - _result = _reconstruct_from_shape_uncached(recipe, bnp.shape(tensor), axes_lengths) - (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result - if init_shapes is not None: - tensor = bnp.reshape(bnp.as_jax(tensor), init_shapes) - if axes_reordering is not None: - tensor = bnp.transpose(bnp.as_jax(tensor), axes_reordering) - if len(reduced_axes) > 0: - tensor = _reduce_axes(bnp.as_jax(tensor), reduction_type=reduction_type, reduced_axes=reduced_axes) - if len(added_axes) > 0: - tensor = bnp2.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes) - if final_shapes is not None: - tensor = bnp.reshape(bnp.as_jax(tensor), final_shapes) - return tensor + # this method implements actual work for all backends for 3 operations + try: + init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = ( + _reconstruct_from_shape(recipe, bnp.shape(tensor), axes_lengths)) + except TypeError: + # shape or one of passed axes lengths is not hashable (i.e. 
they are symbols) + _result = _reconstruct_from_shape_uncached(recipe, bnp.shape(tensor), axes_lengths) + (init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added) = _result + if init_shapes is not None: + tensor = bnp.reshape(bnp.as_jax(tensor), init_shapes) + if axes_reordering is not None: + tensor = bnp.transpose(bnp.as_jax(tensor), axes_reordering) + if len(reduced_axes) > 0: + tensor = _reduce_axes(bnp.as_jax(tensor), reduction_type=reduction_type, reduced_axes=reduced_axes) + if len(added_axes) > 0: + tensor = bnp2.add_axes(tensor, n_axes=n_axes_w_added, pos2len=added_axes) + if final_shapes is not None: + tensor = bnp.reshape(bnp.as_jax(tensor), final_shapes) + return tensor def _apply_recipe_array_api( xp, recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction, axes_lengths: HashableAxesLengths ) -> Tensor: - # completely-inline implementation - init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape( - recipe, tensor.shape, axes_lengths - ) - if init_shapes is not None: - tensor = xp.reshape(tensor, init_shapes) - if axes_reordering is not None: - tensor = xp.permute_dims(tensor, axes_reordering) - if len(reduced_axes) > 0: - if callable(reduction_type): - # custom callable - tensor = reduction_type(tensor, tuple(reduced_axes)) - else: - # one of built-in operations - assert reduction_type in _reductions - tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes)) - if len(added_axes) > 0: - # we use broadcasting - for axis_position, axis_length in added_axes.items(): - tensor = xp.expand_dims(tensor, axis=axis_position) + # completely-inline implementation + init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape( + recipe, tensor.shape, axes_lengths + ) + if init_shapes is not None: + tensor = xp.reshape(tensor, init_shapes) + if axes_reordering is not None: + tensor = xp.permute_dims(tensor, 
axes_reordering) + if len(reduced_axes) > 0: + if callable(reduction_type): + # custom callable + tensor = reduction_type(tensor, tuple(reduced_axes)) + else: + # one of built-in operations + assert reduction_type in _reductions + tensor = getattr(xp, reduction_type)(tensor, axis=tuple(reduced_axes)) + if len(added_axes) > 0: + # we use broadcasting + for axis_position, axis_length in added_axes.items(): + tensor = xp.expand_dims(tensor, axis=axis_position) - final_shape = list(tensor.shape) - for axis_position, axis_length in added_axes.items(): - final_shape[axis_position] = axis_length + final_shape = list(tensor.shape) + for axis_position, axis_length in added_axes.items(): + final_shape[axis_position] = axis_length - tensor = xp.broadcast_to(tensor, final_shape) - if final_shapes is not None: - tensor = xp.reshape(tensor, final_shapes) - return tensor + tensor = xp.broadcast_to(tensor, final_shape) + if final_shapes is not None: + tensor = xp.reshape(tensor, final_shapes) + return tensor @functools.lru_cache(256) @@ -312,417 +312,417 @@ def _prepare_transformation_recipe( axes_names: Tuple[str, ...], ndim: int, ) -> TransformRecipe: - """Perform initial parsing of pattern and provided supplementary info - axes_lengths is a tuple of tuples (axis_name, axis_length) - """ - left_str, rght_str = pattern.split("->") - left = ParsedExpression(left_str) - rght = ParsedExpression(rght_str) - - # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction - if not left.has_ellipsis and rght.has_ellipsis: - raise EinopsError("Ellipsis found in right side, but not left side of a pattern {}".format(pattern)) - if left.has_ellipsis and left.has_ellipsis_parenthesized: - raise EinopsError("Ellipsis inside parenthesis in the left side is not allowed: {}".format(pattern)) - if operation == "rearrange": - if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: - raise EinopsError("Non-unitary anonymous axes are 
not supported in rearrange (exception is length 1)") - difference = set.symmetric_difference(left.identifiers, rght.identifiers) - if len(difference) > 0: - raise EinopsError("Identifiers only on one side of expression (should be on both): {}".format(difference)) - elif operation == "repeat": - difference = set.difference(left.identifiers, rght.identifiers) - if len(difference) > 0: - raise EinopsError("Unexpected identifiers on the left side of repeat: {}".format(difference)) - axes_without_size = set.difference( - {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, - {*left.identifiers, *axes_names}, - ) - if len(axes_without_size) > 0: - raise EinopsError("Specify sizes for new axes in repeat: {}".format(axes_without_size)) - elif operation in _reductions or callable(operation): - difference = set.difference(rght.identifiers, left.identifiers) - if len(difference) > 0: - raise EinopsError("Unexpected identifiers on the right side of reduce {}: {}".format(operation, difference)) - else: - raise EinopsError("Unknown reduction {}. Expect one of {}.".format(operation, _reductions)) - - if left.has_ellipsis: - n_other_dims = len(left.composition) - 1 - if ndim < n_other_dims: - raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. 
Received {ndim}-dim tensor.") - ellipsis_ndim = ndim - n_other_dims - ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)] - left_composition = [] - for composite_axis in left.composition: - if composite_axis == _ellipsis: - for axis in ell_axes: - left_composition.append([axis]) - else: - left_composition.append(composite_axis) - - rght_composition = [] - for composite_axis in rght.composition: - if composite_axis == _ellipsis: - for axis in ell_axes: - rght_composition.append([axis]) - else: - group = [] - for axis in composite_axis: - if axis == _ellipsis: - group.extend(ell_axes) - else: - group.append(axis) - rght_composition.append(group) - - left.identifiers.update(ell_axes) - left.identifiers.remove(_ellipsis) - if rght.has_ellipsis: - rght.identifiers.update(ell_axes) - rght.identifiers.remove(_ellipsis) - else: - if ndim != len(left.composition): - raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. Received {ndim}-dim tensor.") - left_composition = left.composition - rght_composition = rght.composition - - # parsing all dimensions to find out lengths - axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict() - for composite_axis in left_composition: - for axis_name in composite_axis: - if isinstance(axis_name, AnonymousAxis): - axis_name2known_length[axis_name] = axis_name.value - else: - axis_name2known_length[axis_name] = _unknown_axis_length - - # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point - - repeat_axes_names = [] - for axis_name in rght.identifiers: - if axis_name not in axis_name2known_length: - if isinstance(axis_name, AnonymousAxis): - axis_name2known_length[axis_name] = axis_name.value - else: - axis_name2known_length[axis_name] = _unknown_axis_length - repeat_axes_names.append(axis_name) - - axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} - - # axes provided as kwargs - for elementary_axis in axes_names: - if not 
ParsedExpression.check_axis_name(elementary_axis): - raise EinopsError("Invalid name for an axis", elementary_axis) - if elementary_axis not in axis_name2known_length: - raise EinopsError("Axis {} is not used in transform".format(elementary_axis)) - axis_name2known_length[elementary_axis] = _expected_axis_length - - input_axes_known_unknown = [] - # some shapes are inferred later - all information is prepared for faster inference - for i, composite_axis in enumerate(left_composition): - known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} - unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} - if len(unknown) > 1: - raise EinopsError("Could not infer sizes for {}".format(unknown)) - assert len(unknown) + len(known) == len(composite_axis) - input_axes_known_unknown.append( - ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown]) + """Perform initial parsing of pattern and provided supplementary info + axes_lengths is a tuple of tuples (axis_name, axis_length) + """ + left_str, rght_str = pattern.split("->") + left = ParsedExpression(left_str) + rght = ParsedExpression(rght_str) + + # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction + if not left.has_ellipsis and rght.has_ellipsis: + raise EinopsError("Ellipsis found in right side, but not left side of a pattern {}".format(pattern)) + if left.has_ellipsis and left.has_ellipsis_parenthesized: + raise EinopsError("Ellipsis inside parenthesis in the left side is not allowed: {}".format(pattern)) + if operation == "rearrange": + if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: + raise EinopsError("Non-unitary anonymous axes are not supported in rearrange (exception is length 1)") + difference = set.symmetric_difference(left.identifiers, rght.identifiers) + if len(difference) > 0: + raise 
EinopsError("Identifiers only on one side of expression (should be on both): {}".format(difference)) + elif operation == "repeat": + difference = set.difference(left.identifiers, rght.identifiers) + if len(difference) > 0: + raise EinopsError("Unexpected identifiers on the left side of repeat: {}".format(difference)) + axes_without_size = set.difference( + {ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, + {*left.identifiers, *axes_names}, + ) + if len(axes_without_size) > 0: + raise EinopsError("Specify sizes for new axes in repeat: {}".format(axes_without_size)) + elif operation in _reductions or callable(operation): + difference = set.difference(rght.identifiers, left.identifiers) + if len(difference) > 0: + raise EinopsError("Unexpected identifiers on the right side of reduce {}: {}".format(operation, difference)) + else: + raise EinopsError("Unknown reduction {}. Expect one of {}.".format(operation, _reductions)) + + if left.has_ellipsis: + n_other_dims = len(left.composition) - 1 + if ndim < n_other_dims: + raise EinopsError(f"Wrong shape: expected >={n_other_dims} dims. 
Received {ndim}-dim tensor.") + ellipsis_ndim = ndim - n_other_dims + ell_axes = [_ellipsis + str(i) for i in range(ellipsis_ndim)] + left_composition = [] + for composite_axis in left.composition: + if composite_axis == _ellipsis: + for axis in ell_axes: + left_composition.append([axis]) + else: + left_composition.append(composite_axis) + + rght_composition = [] + for composite_axis in rght.composition: + if composite_axis == _ellipsis: + for axis in ell_axes: + rght_composition.append([axis]) + else: + group = [] + for axis in composite_axis: + if axis == _ellipsis: + group.extend(ell_axes) + else: + group.append(axis) + rght_composition.append(group) + + left.identifiers.update(ell_axes) + left.identifiers.remove(_ellipsis) + if rght.has_ellipsis: + rght.identifiers.update(ell_axes) + rght.identifiers.remove(_ellipsis) + else: + if ndim != len(left.composition): + raise EinopsError(f"Wrong shape: expected {len(left.composition)} dims. Received {ndim}-dim tensor.") + left_composition = left.composition + rght_composition = rght.composition + + # parsing all dimensions to find out lengths + axis_name2known_length: Dict[Union[str, AnonymousAxis], int] = OrderedDict() + for composite_axis in left_composition: + for axis_name in composite_axis: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + + # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point + + repeat_axes_names = [] + for axis_name in rght.identifiers: + if axis_name not in axis_name2known_length: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + repeat_axes_names.append(axis_name) + + axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} + + # axes provided as kwargs + for elementary_axis in axes_names: + if not 
ParsedExpression.check_axis_name(elementary_axis): + raise EinopsError("Invalid name for an axis", elementary_axis) + if elementary_axis not in axis_name2known_length: + raise EinopsError("Axis {} is not used in transform".format(elementary_axis)) + axis_name2known_length[elementary_axis] = _expected_axis_length + + input_axes_known_unknown = [] + # some shapes are inferred later - all information is prepared for faster inference + for i, composite_axis in enumerate(left_composition): + known: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} + unknown: Set[str] = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} + if len(unknown) > 1: + raise EinopsError("Could not infer sizes for {}".format(unknown)) + assert len(unknown) + len(known) == len(composite_axis) + input_axes_known_unknown.append( + ([axis_name2position[axis] for axis in known], [axis_name2position[axis] for axis in unknown]) + ) + + axis_position_after_reduction: Dict[str, int] = {} + for axis_name in itertools.chain(*left_composition): + if axis_name in rght.identifiers: + axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) + + result_axes_grouping: List[List[int]] = [ + [axis_name2position[axis] for axis in composite_axis] for i, composite_axis in enumerate(rght_composition) + ] + + ordered_axis_left = list(itertools.chain(*left_composition)) + ordered_axis_rght = list(itertools.chain(*rght_composition)) + reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers] + order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes + axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition] + added_axes = { + i: axis_name2position[axis_name] + for i, axis_name in enumerate(ordered_axis_rght) + if axis_name not in left.identifiers + } + + first_reduced_axis = len(order_after_transposition) 
- len(reduced_axes) + + return TransformRecipe( + elementary_axes_lengths=list(axis_name2known_length.values()), + axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names}, + input_composition_known_unknown=input_axes_known_unknown, + axes_permutation=axes_permutation, + first_reduced_axis=first_reduced_axis, + added_axes=added_axes, + output_composite_axes=result_axes_grouping, ) - axis_position_after_reduction: Dict[str, int] = {} - for axis_name in itertools.chain(*left_composition): - if axis_name in rght.identifiers: - axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) - - result_axes_grouping: List[List[int]] = [ - [axis_name2position[axis] for axis in composite_axis] for i, composite_axis in enumerate(rght_composition) - ] - - ordered_axis_left = list(itertools.chain(*left_composition)) - ordered_axis_rght = list(itertools.chain(*rght_composition)) - reduced_axes = [axis for axis in ordered_axis_left if axis not in rght.identifiers] - order_after_transposition = [axis for axis in ordered_axis_rght if axis in left.identifiers] + reduced_axes - axes_permutation = [ordered_axis_left.index(axis) for axis in order_after_transposition] - added_axes = { - i: axis_name2position[axis_name] - for i, axis_name in enumerate(ordered_axis_rght) - if axis_name not in left.identifiers - } - - first_reduced_axis = len(order_after_transposition) - len(reduced_axes) - - return TransformRecipe( - elementary_axes_lengths=list(axis_name2known_length.values()), - axis_name2elementary_axis={axis: axis_name2position[axis] for axis in axes_names}, - input_composition_known_unknown=input_axes_known_unknown, - axes_permutation=axes_permutation, - first_reduced_axis=first_reduced_axis, - added_axes=added_axes, - output_composite_axes=result_axes_grouping, - ) - def _prepare_recipes_for_all_dims( pattern: str, operation: Reduction, axes_names: Tuple[str, ...] ) -> Dict[int, TransformRecipe]: - """ - Internal function, used in layers. 
- Layer makes all recipe creation when it is initialized, thus to keep recipes simple we pre-compute for all dims - """ - left_str, rght_str = pattern.split("->") - left = ParsedExpression(left_str) - dims = [len(left.composition)] - if left.has_ellipsis: - dims = [len(left.composition) - 1 + ellipsis_dims for ellipsis_dims in range(8)] - return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims} + """ + Internal function, used in layers. + Layer makes all recipe creation when it is initialized, thus to keep recipes simple we pre-compute for all dims + """ + left_str, rght_str = pattern.split("->") + left = ParsedExpression(left_str) + dims = [len(left.composition)] + if left.has_ellipsis: + dims = [len(left.composition) - 1 + ellipsis_dims for ellipsis_dims in range(8)] + return {ndim: _prepare_transformation_recipe(pattern, operation, axes_names, ndim=ndim) for ndim in dims} def ein_reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor: - """ - ``ein_reduce`` provides combination of reordering and reduction using reader-friendly notation. 
- - Examples for reduce operation: - - ```python - >>> x = np.random.randn(100, 32, 64) - - # perform max-reduction on the first axis - >>> y = ein_reduce(x, 't b c -> b c', 'max') - - # same as previous, but with clearer axes meaning - >>> y = ein_reduce(x, 'time batch channel -> batch channel', 'max') - - >>> x = np.random.randn(10, 20, 30, 40) - - # 2d max-pooling with kernel size = 2 * 2 for image processing - >>> y1 = ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) - - # if one wants to go back to the original height and width, depth-to-space trick can be applied - >>> y2 = ein_rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2) - >>> assert ein_shape(x, 'b _ h w') == ein_shape(y2, 'b _ h w') - - # Adaptive 2d max-pooling to 3 * 4 grid - >>> ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape - (10, 20, 3, 4) - - # Global average pooling - >>> ein_reduce(x, 'b c h w -> b c', 'mean').shape - (10, 20) - - # Subtracting mean over batch for each channel - >>> y = x - ein_reduce(x, 'b c h w -> () c () ()', 'mean') - - # Subtracting per-image mean for each channel - >>> y = x - ein_reduce(x, 'b c h w -> b c () ()', 'mean') - - ``` - - Parameters: - tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, reduction pattern - reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive - alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. - This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc. 
- axes_lengths: any additional specifications for dimensions - - Returns: - tensor of the same type as input - """ - try: - hashable_axes_lengths = tuple(axes_lengths.items()) - shape = bnp.shape(tensor) - recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape)) - return _apply_recipe(recipe, - cast(Tensor, tensor), - reduction_type=reduction, - axes_lengths=hashable_axes_lengths) - except EinopsError as e: - message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) - if not isinstance(tensor, list): - message += "\n Input tensor shape: {}. ".format(shape) - else: - message += "\n Input is list. " - message += "Additional info: {}.".format(axes_lengths) - raise EinopsError(message + "\n {}".format(e)) + """ + ``ein_reduce`` provides combination of reordering and reduction using reader-friendly notation. + + Examples for reduce operation: + + ```python + >>> x = np.random.randn(100, 32, 64) + + # perform max-reduction on the first axis + >>> y = ein_reduce(x, 't b c -> b c', 'max') + + # same as previous, but with clearer axes meaning + >>> y = ein_reduce(x, 'time batch channel -> batch channel', 'max') + + >>> x = np.random.randn(10, 20, 30, 40) + + # 2d max-pooling with kernel size = 2 * 2 for image processing + >>> y1 = ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) + + # if one wants to go back to the original height and width, depth-to-space trick can be applied + >>> y2 = ein_rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2) + >>> assert ein_shape(x, 'b _ h w') == ein_shape(y2, 'b _ h w') + + # Adaptive 2d max-pooling to 3 * 4 grid + >>> ein_reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape + (10, 20, 3, 4) + + # Global average pooling + >>> ein_reduce(x, 'b c h w -> b c', 'mean').shape + (10, 20) + + # Subtracting mean over batch for each channel + >>> y = x - ein_reduce(x, 'b c h w -> () c () ()', 'mean') + + # 
Subtracting per-image mean for each channel + >>> y = x - ein_reduce(x, 'b c h w -> b c () ()', 'mean') + + ``` + + Parameters: + tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, reduction pattern + reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive + alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. + This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc. + axes_lengths: any additional specifications for dimensions + + Returns: + tensor of the same type as input + """ + try: + hashable_axes_lengths = tuple(axes_lengths.items()) + shape = bnp.shape(tensor) + recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape)) + return _apply_recipe(recipe, + cast(Tensor, tensor), + reduction_type=reduction, + axes_lengths=hashable_axes_lengths) + except EinopsError as e: + message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) + if not isinstance(tensor, list): + message += "\n Input tensor shape: {}. ".format(shape) + else: + message += "\n Input is list. " + message += "Additional info: {}.".format(axes_lengths) + raise EinopsError(message + "\n {}".format(e)) def ein_rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: - """ - ``ein_rearrange`` is a reader-friendly smart element reordering for multidimensional tensors. - This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, - stack, concatenate and other operations. + """ + ``ein_rearrange`` is a reader-friendly smart element reordering for multidimensional tensors. 
+ This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, + stack, concatenate and other operations. - Examples for rearrange operation: + Examples for rearrange operation: - ```python - # suppose we have a set of 32 images in "h w c" format (height-width-channel) - >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] + ```python + # suppose we have a set of 32 images in "h w c" format (height-width-channel) + >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] - # stack along first (batch) axis, output is a single array - >>> ein_rearrange(images, 'b h w c -> b h w c').shape - (32, 30, 40, 3) + # stack along first (batch) axis, output is a single array + >>> ein_rearrange(images, 'b h w c -> b h w c').shape + (32, 30, 40, 3) - # concatenate images along height (vertical axis), 960 = 32 * 30 - >>> ein_rearrange(images, 'b h w c -> (b h) w c').shape - (960, 40, 3) + # concatenate images along height (vertical axis), 960 = 32 * 30 + >>> ein_rearrange(images, 'b h w c -> (b h) w c').shape + (960, 40, 3) - # concatenated images along horizontal axis, 1280 = 32 * 40 - >>> ein_rearrange(images, 'b h w c -> h (b w) c').shape - (30, 1280, 3) + # concatenated images along horizontal axis, 1280 = 32 * 40 + >>> ein_rearrange(images, 'b h w c -> h (b w) c').shape + (30, 1280, 3) - # reordered axes to "b c h w" format for deep learning - >>> ein_rearrange(images, 'b h w c -> b c h w').shape - (32, 3, 30, 40) + # reordered axes to "b c h w" format for deep learning + >>> ein_rearrange(images, 'b h w c -> b c h w').shape + (32, 3, 30, 40) - # flattened each image into a vector, 3600 = 30 * 40 * 3 - >>> ein_rearrange(images, 'b h w c -> b (c h w)').shape - (32, 3600) + # flattened each image into a vector, 3600 = 30 * 40 * 3 + >>> ein_rearrange(images, 'b h w c -> b (c h w)').shape + (32, 3600) - # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 - >>> 
ein_rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape - (128, 15, 20, 3) + # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 + >>> ein_rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape + (128, 15, 20, 3) - # space-to-depth operation - >>> ein_rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape - (32, 15, 20, 12) + # space-to-depth operation + >>> ein_rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape + (32, 15, 20, 12) - ``` + ``` - When composing axes, C-order enumeration used (consecutive elements have different last axis) - Find more examples in einops tutorial. + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. - Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, rearrangement pattern - axes_lengths: any additional specifications for dimensions + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions - Returns: - tensor of the same type as input. If possible, a view to the original tensor is returned. + Returns: + tensor of the same type as input. If possible, a view to the original tensor is returned. - """ - return ein_reduce(tensor, pattern, reduction="rearrange", **axes_lengths) + """ + return ein_reduce(tensor, pattern, reduction="rearrange", **axes_lengths) def ein_repeat(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: - """ - ``ein_repeat`` allows reordering elements and repeating them in arbitrary combinations. 
- This operation includes functionality of repeat, tile, broadcast functions. + """ + ``ein_repeat`` allows reordering elements and repeating them in arbitrary combinations. + This operation includes functionality of repeat, tile, broadcast functions. - Examples for repeat operation: + Examples for repeat operation: - ```python - # a grayscale image (of shape height x width) - >>> image = np.random.randn(30, 40) + ```python + # a grayscale image (of shape height x width) + >>> image = np.random.randn(30, 40) - # change it to RGB format by repeating in each channel - >>> ein_repeat(image, 'h w -> h w c', c=3).shape - (30, 40, 3) + # change it to RGB format by repeating in each channel + >>> ein_repeat(image, 'h w -> h w c', c=3).shape + (30, 40, 3) - # repeat image 2 times along height (vertical axis) - >>> ein_repeat(image, 'h w -> (repeat h) w', repeat=2).shape - (60, 40) + # repeat image 2 times along height (vertical axis) + >>> ein_repeat(image, 'h w -> (repeat h) w', repeat=2).shape + (60, 40) - # repeat image 2 time along height and 3 times along width - >>> ein_repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape - (60, 120) + # repeat image 2 time along height and 3 times along width + >>> ein_repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape + (60, 120) - # convert each pixel to a small square 2x2. Upsample image by 2x - >>> ein_repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape - (60, 80) + # convert each pixel to a small square 2x2. 
Upsample image by 2x + >>> ein_repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (60, 80) - # pixelate image first by downsampling by 2x, then upsampling - >>> downsampled = ein_reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) - >>> ein_repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape - (30, 40) + # pixelate image first by downsampling by 2x, then upsampling + >>> downsampled = ein_reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) + >>> ein_repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (30, 40) - ``` + ``` - When composing axes, C-order enumeration used (consecutive elements have different last axis) - Find more examples in einops tutorial. + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. - Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). - list of tensors is also accepted, those should be of the same type and shape - pattern: string, rearrangement pattern - axes_lengths: any additional specifications for dimensions + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions - Returns: - Tensor of the same type as input. If possible, a view to the original tensor is returned. + Returns: + Tensor of the same type as input. If possible, a view to the original tensor is returned. - """ - return ein_reduce(tensor, pattern, reduction="repeat", **axes_lengths) + """ + return ein_reduce(tensor, pattern, reduction="repeat", **axes_lengths) def ein_shape(x, pattern: str) -> dict: - """ - Parse a tensor shape to dictionary mapping axes names to their lengths. - - ```python - # Use underscore to skip the dimension in parsing. 
- >>> x = np.zeros([2, 3, 5, 7]) - >>> ein_shape(x, 'batch _ h w') - {'batch': 2, 'h': 5, 'w': 7} - - # `parse_shape` output can be used to specify axes_lengths for other operations: - >>> y = np.zeros([700]) - >>> ein_rearrange(y, '(b c h w) -> b c h w', **ein_shape(x, 'b _ h w')).shape - (2, 10, 5, 7) - - ``` - - For symbolic frameworks may return symbols, not integers. - - Parameters: - x: tensor of any supported framework - pattern: str, space separated names for axes, underscore means skip axis - - Returns: - dict, maps axes names to their lengths - """ - exp = ParsedExpression(pattern, allow_underscore=True) - shape = bnp.shape(x) - if exp.has_composed_axes(): - raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}") - if len(shape) != len(exp.composition): + """ + Parse a tensor shape to dictionary mapping axes names to their lengths. + + ```python + # Use underscore to skip the dimension in parsing. + >>> x = np.zeros([2, 3, 5, 7]) + >>> ein_shape(x, 'batch _ h w') + {'batch': 2, 'h': 5, 'w': 7} + + # `parse_shape` output can be used to specify axes_lengths for other operations: + >>> y = np.zeros([700]) + >>> ein_rearrange(y, '(b c h w) -> b c h w', **ein_shape(x, 'b _ h w')).shape + (2, 10, 5, 7) + + ``` + + For symbolic frameworks may return symbols, not integers. 
+ + Parameters: + x: tensor of any supported framework + pattern: str, space separated names for axes, underscore means skip axis + + Returns: + dict, maps axes names to their lengths + """ + exp = ParsedExpression(pattern, allow_underscore=True) + shape = bnp.shape(x) + if exp.has_composed_axes(): + raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}") + if len(shape) != len(exp.composition): + if exp.has_ellipsis: + if len(shape) < len(exp.composition) - 1: + raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}") + else: + raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}") if exp.has_ellipsis: - if len(shape) < len(exp.composition) - 1: - raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}") + ellipsis_idx = exp.composition.index(_ellipsis) + composition = ( + exp.composition[:ellipsis_idx] + + ["_"] * (len(shape) - len(exp.composition) + 1) + + exp.composition[ellipsis_idx + 1:] + ) else: - raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}") - if exp.has_ellipsis: - ellipsis_idx = exp.composition.index(_ellipsis) - composition = ( - exp.composition[:ellipsis_idx] - + ["_"] * (len(shape) - len(exp.composition) + 1) - + exp.composition[ellipsis_idx + 1:] - ) - else: - composition = exp.composition - result = {} - for (axis_name,), axis_length in zip(composition, shape): # type: ignore - if axis_name != "_": - result[axis_name] = axis_length - return result + composition = exp.composition + result = {} + for (axis_name,), axis_length in zip(composition, shape): # type: ignore + if axis_name != "_": + result[axis_name] = axis_length + return result # _enumerate_directions is not exposed in the public API def _enumerate_directions(x): - """ - For an n-dimensional tensor, returns tensors to enumerate each axis. 
- ```python - x = np.zeros([2, 3, 4]) # or any other tensor - i, j, k = _enumerate_directions(x) - result = i + 2*j + 3*k - ``` - - `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result - Works very similarly to numpy.ogrid (open indexing grid) - """ - shape = bnp.shape(x) - result = [] - for axis_id, axis_length in enumerate(shape): - shape = [1] * len(shape) - shape[axis_id] = axis_length - result.append(bnp.reshape(bnp.arange(0, axis_length), shape)) - return result + """ + For an n-dimensional tensor, returns tensors to enumerate each axis. + ```python + x = np.zeros([2, 3, 4]) # or any other tensor + i, j, k = _enumerate_directions(x) + result = i + 2*j + 3*k + ``` + + `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result + Works very similarly to numpy.ogrid (open indexing grid) + """ + shape = bnp.shape(x) + result = [] + for axis_id, axis_length in enumerate(shape): + shape = [1] * len(shape) + shape[axis_id] = axis_length + result.append(bnp.reshape(bnp.arange(0, axis_length), shape)) + return result diff --git a/brainpy/_src/math/einops_parsing.py b/brainpy/_src/math/einops_parsing.py index 6ce055bdb..20a358f81 100644 --- a/brainpy/_src/math/einops_parsing.py +++ b/brainpy/_src/math/einops_parsing.py @@ -6,148 +6,148 @@ class EinopsError(Exception): - pass + pass class AnonymousAxis(object): - """Important thing: all instances of this class are not equal to each other """ + """Important thing: all instances of this class are not equal to each other """ - def __init__(self, value: str): - self.value = int(value) - if self.value <= 1: - if self.value == 1: - raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue') - else: - raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value)) + def __init__(self, value: str): + self.value = int(value) + if self.value <= 1: + if self.value == 1: + raise EinopsError('No need to create anonymous axis of length 1. 
Report this as an issue') + else: + raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value)) - def __repr__(self): - return "{}-axis".format(str(self.value)) + def __repr__(self): + return "{}-axis".format(str(self.value)) class ParsedExpression: - """ - non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') - and keeps some information important for downstream - """ - - def __init__(self, expression: str, *, allow_underscore: bool = False, - allow_duplicates: bool = False): - self.has_ellipsis: bool = False - self.has_ellipsis_parenthesized: Optional[bool] = None - self.identifiers: Set[str] = set() - # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition - self.has_non_unitary_anonymous_axes: bool = False - # composition keeps structure of composite axes, see how different corner cases are handled in tests - self.composition: List[Union[List[str], str]] = [] - if '.' in expression: - if '...' 
not in expression: - raise EinopsError('Expression may contain dots only inside ellipsis (...)') - if str.count(expression, '...') != 1 or str.count(expression, '.') != 3: - raise EinopsError( - 'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ') - expression = expression.replace('...', _ellipsis) - self.has_ellipsis = True - - bracket_group: Optional[List[str]] = None - - def add_axis_name(x): - if x in self.identifiers: - if not (allow_underscore and x == "_") and not allow_duplicates: - raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) - if x == _ellipsis: - self.identifiers.add(_ellipsis) - if bracket_group is None: - self.composition.append(_ellipsis) - self.has_ellipsis_parenthesized = False - else: - bracket_group.append(_ellipsis) - self.has_ellipsis_parenthesized = True - else: - is_number = str.isdecimal(x) - if is_number and int(x) == 1: - # handling the case of anonymous axis of length 1 - if bracket_group is None: - self.composition.append([]) - else: - pass # no need to think about 1s inside parenthesis - return - is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) - if not (is_number or is_axis_name): - raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason)) - if is_number: - x = AnonymousAxis(x) - self.identifiers.add(x) - if is_number: - self.has_non_unitary_anonymous_axes = True - if bracket_group is None: - self.composition.append([x]) - else: - bracket_group.append(x) + """ + non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') + and keeps some information important for downstream + """ + + def __init__(self, expression: str, *, allow_underscore: bool = False, + allow_duplicates: bool = False): + self.has_ellipsis: bool = False + self.has_ellipsis_parenthesized: Optional[bool] = None + self.identifiers: Set[str] = set() + # that's axes like 2, 3, 4 or 5. 
Axes with size 1 are exceptional and replaced with empty composition + self.has_non_unitary_anonymous_axes: bool = False + # composition keeps structure of composite axes, see how different corner cases are handled in tests + self.composition: List[Union[List[str], str]] = [] + if '.' in expression: + if '...' not in expression: + raise EinopsError('Expression may contain dots only inside ellipsis (...)') + if str.count(expression, '...') != 1 or str.count(expression, '.') != 3: + raise EinopsError( + 'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ') + expression = expression.replace('...', _ellipsis) + self.has_ellipsis = True + + bracket_group: Optional[List[str]] = None + + def add_axis_name(x): + if x in self.identifiers: + if not (allow_underscore and x == "_") and not allow_duplicates: + raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) + if x == _ellipsis: + self.identifiers.add(_ellipsis) + if bracket_group is None: + self.composition.append(_ellipsis) + self.has_ellipsis_parenthesized = False + else: + bracket_group.append(_ellipsis) + self.has_ellipsis_parenthesized = True + else: + is_number = str.isdecimal(x) + if is_number and int(x) == 1: + # handling the case of anonymous axis of length 1 + if bracket_group is None: + self.composition.append([]) + else: + pass # no need to think about 1s inside parenthesis + return + is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) + if not (is_number or is_axis_name): + raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason)) + if is_number: + x = AnonymousAxis(x) + self.identifiers.add(x) + if is_number: + self.has_non_unitary_anonymous_axes = True + if bracket_group is None: + self.composition.append([x]) + else: + bracket_group.append(x) - current_identifier = None - for char in expression: - if char in '() ': - if current_identifier is not None: - 
add_axis_name(current_identifier) current_identifier = None - if char == '(': - if bracket_group is not None: - raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)") - bracket_group = [] - elif char == ')': - if bracket_group is None: - raise EinopsError('Brackets are not balanced') - self.composition.append(bracket_group) - bracket_group = None - elif str.isalnum(char) or char in ['_', _ellipsis]: - if current_identifier is None: - current_identifier = char + for char in expression: + if char in '() ': + if current_identifier is not None: + add_axis_name(current_identifier) + current_identifier = None + if char == '(': + if bracket_group is not None: + raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)") + bracket_group = [] + elif char == ')': + if bracket_group is None: + raise EinopsError('Brackets are not balanced') + self.composition.append(bracket_group) + bracket_group = None + elif str.isalnum(char) or char in ['_', _ellipsis]: + if current_identifier is None: + current_identifier = char + else: + current_identifier += char + else: + raise EinopsError("Unknown character '{}'".format(char)) + + if bracket_group is not None: + raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression)) + if current_identifier is not None: + add_axis_name(current_identifier) + + def flat_axes_order(self) -> List: + result = [] + for composed_axis in self.composition: + assert isinstance(composed_axis, list), 'does not work with ellipsis' + for axis in composed_axis: + result.append(axis) + return result + + def has_composed_axes(self) -> bool: + # this will ignore 1 inside brackets + for axes in self.composition: + if isinstance(axes, list) and len(axes) > 1: + return True + return False + + @staticmethod + def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: + if not str.isidentifier(name): + return False, 'not a valid python 
identifier' + elif name[0] == '_' or name[-1] == '_': + if name == '_' and allow_underscore: + return True, '' + return False, 'axis name should should not start or end with underscore' else: - current_identifier += char - else: - raise EinopsError("Unknown character '{}'".format(char)) - - if bracket_group is not None: - raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression)) - if current_identifier is not None: - add_axis_name(current_identifier) - - def flat_axes_order(self) -> List: - result = [] - for composed_axis in self.composition: - assert isinstance(composed_axis, list), 'does not work with ellipsis' - for axis in composed_axis: - result.append(axis) - return result - - def has_composed_axes(self) -> bool: - # this will ignore 1 inside brackets - for axes in self.composition: - if isinstance(axes, list) and len(axes) > 1: - return True - return False - - @staticmethod - def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: - if not str.isidentifier(name): - return False, 'not a valid python identifier' - elif name[0] == '_' or name[-1] == '_': - if name == '_' and allow_underscore: - return True, '' - return False, 'axis name should should not start or end with underscore' - else: - if keyword.iskeyword(name): - warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning) - if name in ['axis']: - warnings.warn("It is discouraged to use 'axis' as an axis name " - "and will raise an error in future", FutureWarning) - return True, '' - - @staticmethod - def check_axis_name(name: str) -> bool: - """ - Valid axes names are python identifiers except keywords, - and additionally should not start or end with underscore - """ - is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name) - return is_valid + if keyword.iskeyword(name): + warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning) 
+ if name in ['axis']: + warnings.warn("It is discouraged to use 'axis' as an axis name " + "and will raise an error in future", FutureWarning) + return True, '' + + @staticmethod + def check_axis_name(name: str) -> bool: + """ + Valid axes names are python identifiers except keywords, + and additionally should not start or end with underscore + """ + is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name) + return is_valid diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 91b479b62..6b1d7e1c1 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,3 +1,2 @@ -from .csr_matvec import * from .csr_matmat import * - +from .csr_matvec import * diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index 308d194b2..e04e7033c 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -3,12 +3,11 @@ from typing import Union, Tuple +import brainevent from jax import numpy as jnp from brainpy._src.math.ndarray import BaseArray as Array -import brainevent - __all__ = [ 'csrmm', ] @@ -53,4 +52,4 @@ def csrmm( if transpose: return matrix @ csr else: - return csr @ matrix \ No newline at end of file + return csr @ matrix diff --git a/brainpy/_src/math/fft.py b/brainpy/_src/math/fft.py index fd745eadc..2f02b596e 100644 --- a/brainpy/_src/math/fft.py +++ b/brainpy/_src/math/fft.py @@ -5,9 +5,9 @@ from ._utils import _compatible_with_brainpy_array __all__ = [ - "fft", "fft2", "fftfreq", "fftn", "fftshift", "hfft", - "ifft", "ifft2", "ifftn", "ifftshift", "ihfft", "irfft", - "irfft2", "irfftn", "rfft", "rfft2", "rfftfreq", "rfftn" + "fft", "fft2", "fftfreq", "fftn", "fftshift", "hfft", + "ifft", "ifft2", "ifftn", "ifftshift", "ihfft", "irfft", + "irfft2", "irfftn", "rfft", "rfft2", "rfftfreq", "rfftn" ] fft = _compatible_with_brainpy_array(jfft.fft, module='fft.') diff --git 
a/brainpy/_src/math/interoperability.py b/brainpy/_src/math/interoperability.py index 577a9bca8..60b29df55 100644 --- a/brainpy/_src/math/interoperability.py +++ b/brainpy/_src/math/interoperability.py @@ -3,107 +3,105 @@ import jax.numpy as jnp import numpy as np -from .ndarray import Array, BaseArray - +from .ndarray import BaseArray __all__ = [ - 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable', - 'from_numpy', + 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy', 'as_variable', + 'from_numpy', - 'is_bp_array' + 'is_bp_array' ] def _as_jax_array_(obj): - return obj.value if isinstance(obj, BaseArray) else obj + return obj.value if isinstance(obj, BaseArray) else obj def is_bp_array(x): - """Check if the input is a ``brainpy.math.Array``. - """ - return isinstance(x, BaseArray) + """Check if the input is a ``brainpy.math.Array``. + """ + return isinstance(x, BaseArray) def as_device_array(tensor, dtype=None): - """Convert the input to a ``jax.numpy.DeviceArray``. + """Convert the input to a ``jax.numpy.DeviceArray``. - Parameters:: + Parameters:: - tensor: array_like - Input data, in any form that can be converted to an array. This - includes lists, lists of tuples, tuples, tuples of tuples, tuples - of lists, ArrayType. - dtype: data-type, optional - By default, the data-type is inferred from the input data. + tensor: array_like + Input data, in any form that can be converted to an array. This + includes lists, lists of tuples, tuples, tuples of tuples, tuples + of lists, ArrayType. + dtype: data-type, optional + By default, the data-type is inferred from the input data. - Returns:: + Returns:: - out : ArrayType - Array interpretation of `tensor`. No copy is performed if the input - is already an ndarray with matching dtype. 
- """ - if isinstance(tensor, BaseArray): - return tensor.to_jax(dtype) - elif isinstance(tensor, jnp.ndarray): - return tensor if (dtype is None) else jnp.asarray(tensor, dtype=dtype) - elif isinstance(tensor, np.ndarray): - return jnp.asarray(tensor, dtype=dtype) - else: - return jnp.asarray(tensor, dtype=dtype) + out : ArrayType + Array interpretation of `tensor`. No copy is performed if the input + is already an ndarray with matching dtype. + """ + if isinstance(tensor, BaseArray): + return tensor.to_jax(dtype) + elif isinstance(tensor, jnp.ndarray): + return tensor if (dtype is None) else jnp.asarray(tensor, dtype=dtype) + elif isinstance(tensor, np.ndarray): + return jnp.asarray(tensor, dtype=dtype) + else: + return jnp.asarray(tensor, dtype=dtype) as_jax = as_device_array def as_ndarray(tensor, dtype=None): - """Convert the input to a ``numpy.ndarray``. + """Convert the input to a ``numpy.ndarray``. - Parameters:: + Parameters:: - tensor: array_like - Input data, in any form that can be converted to an array. This - includes lists, lists of tuples, tuples, tuples of tuples, tuples - of lists, ArrayType. - dtype: data-type, optional - By default, the data-type is inferred from the input data. + tensor: array_like + Input data, in any form that can be converted to an array. This + includes lists, lists of tuples, tuples, tuples of tuples, tuples + of lists, ArrayType. + dtype: data-type, optional + By default, the data-type is inferred from the input data. - Returns:: + Returns:: - out : ndarray - Array interpretation of `tensor`. No copy is performed if the input - is already an ndarray with matching dtype. - """ - if isinstance(tensor, BaseArray): - return tensor.to_numpy(dtype=dtype) - else: - return np.asarray(tensor, dtype=dtype) + out : ndarray + Array interpretation of `tensor`. No copy is performed if the input + is already an ndarray with matching dtype. 
+ """ + if isinstance(tensor, BaseArray): + return tensor.to_numpy(dtype=dtype) + else: + return np.asarray(tensor, dtype=dtype) as_numpy = as_ndarray def as_variable(tensor, dtype=None): - """Convert the input to a ``brainpy.math.Variable``. + """Convert the input to a ``brainpy.math.Variable``. - Parameters:: + Parameters:: - tensor: array_like - Input data, in any form that can be converted to an array. This - includes lists, lists of tuples, tuples, tuples of tuples, tuples - of lists, ArrayType. - dtype: data-type, optional - By default, the data-type is inferred from the input data. + tensor: array_like + Input data, in any form that can be converted to an array. This + includes lists, lists of tuples, tuples, tuples of tuples, tuples + of lists, ArrayType. + dtype: data-type, optional + By default, the data-type is inferred from the input data. - Returns:: + Returns:: - out : ndarray - Array interpretation of `tensor`. No copy is performed if the input - is already an ndarray with matching dtype. - """ - from .object_transform.variables import Variable - return Variable(tensor, dtype=dtype) + out : ndarray + Array interpretation of `tensor`. No copy is performed if the input + is already an ndarray with matching dtype. 
+ """ + from .object_transform.variables import Variable + return Variable(tensor, dtype=dtype) def from_numpy(arr, dtype=None): - return as_ndarray(arr, dtype=dtype) - + return as_ndarray(arr, dtype=dtype) diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index bb6a3c1f4..386defd85 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,2 +1,2 @@ -from .matvec import * from .event_matvec import * +from .matvec import * diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py index a1404c7a3..1e4d9c256 100644 --- a/brainpy/_src/math/jitconn/event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -2,14 +2,14 @@ from typing import Tuple, Optional +import brainevent import jax - import numpy as np + from brainpy._src.math.jitconn.matvec import (mv_prob_homo, mv_prob_uniform, mv_prob_normal) from brainpy._src.math.ndarray import BaseArray as Array -import brainevent __all__ = [ 'event_mv_prob_homo', diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index e1edfb949..9cbb3c39e 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- from typing import Tuple, Optional, Union +import brainevent import jax - -from brainpy._src.math.ndarray import BaseArray as Array import numpy as np -import brainevent +from brainpy._src.math.ndarray import BaseArray as Array __all__ = [ 'mv_prob_homo', diff --git a/brainpy/_src/math/linalg.py b/brainpy/_src/math/linalg.py index a47207d3c..7257bb1b4 100644 --- a/brainpy/_src/math/linalg.py +++ b/brainpy/_src/math/linalg.py @@ -5,9 +5,9 @@ from ._utils import _compatible_with_brainpy_array __all__ = [ - 'cholesky', 'cond', 'det', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'inv', 'svd', - 'lstsq', 'matrix_power', 'matrix_rank', 'norm', 'pinv', 'qr', 'solve', 'slogdet', - 'tensorinv', 'tensorsolve', 
'multi_dot' + 'cholesky', 'cond', 'det', 'eig', 'eigh', 'eigvals', 'eigvalsh', 'inv', 'svd', + 'lstsq', 'matrix_power', 'matrix_rank', 'norm', 'pinv', 'qr', 'solve', 'slogdet', + 'tensorinv', 'tensorsolve', 'multi_dot' ] cholesky = _compatible_with_brainpy_array(linalg.cholesky, module='linalg.') @@ -29,4 +29,4 @@ slogdet = _compatible_with_brainpy_array(linalg.slogdet, module='linalg.') tensorinv = _compatible_with_brainpy_array(linalg.tensorinv, module='linalg.') tensorsolve = _compatible_with_brainpy_array(linalg.tensorsolve, module='linalg.') -multi_dot = _compatible_with_brainpy_array(linalg.multi_dot, module='linalg.') \ No newline at end of file +multi_dot = _compatible_with_brainpy_array(linalg.multi_dot, module='linalg.') diff --git a/brainpy/_src/math/modes.py b/brainpy/_src/math/modes.py index 575c27526..118af788a 100644 --- a/brainpy/_src/math/modes.py +++ b/brainpy/_src/math/modes.py @@ -4,97 +4,98 @@ import brainstate __all__ = [ - 'Mode', - 'NonBatchingMode', - 'BatchingMode', - 'TrainingMode', - 'nonbatching_mode', - 'batching_mode', - 'training_mode', + 'Mode', + 'NonBatchingMode', + 'BatchingMode', + 'TrainingMode', + 'nonbatching_mode', + 'batching_mode', + 'training_mode', ] class Mode(brainstate.mixin.Mode): - """Base class for computation Mode - """ - - def __repr__(self): - return self.__class__.__name__ - - def __eq__(self, other: 'Mode'): - if not isinstance(other, Mode): - return False - return other.__class__ == self.__class__ - - def is_one_of(self, *modes): - for m_ in modes: - if not isinstance(m_, type): - raise TypeError(f'The supported type must be a tuple/list of type. But we got {m_}') - return self.__class__ in modes - - def is_a(self, mode: type): - """Check whether the mode is exactly the desired mode.""" - assert isinstance(mode, type), 'Must be a type.' 
- return self.__class__ == mode - - def is_parent_of(self, *modes): - """Check whether the mode is a parent of the given modes.""" - cls = self.__class__ - for m_ in modes: - if not isinstance(m_, type): - raise TypeError(f'The supported type must be a tuple/list of type. But we got {m_}') - if all([not issubclass(m_, cls) for m_ in modes]): - return False - else: - return True - - def is_child_of(self, *modes): - """Check whether the mode is a children of one of the given modes.""" - for m_ in modes: - if not isinstance(m_, type): - raise TypeError(f'The supported type must be a tuple/list of type. But we got {m_}') - return isinstance(self, modes) - - def is_batch_mode(self): - return isinstance(self, BatchingMode) - - def is_train_mode(self): - return isinstance(self, TrainingMode) - - def is_nonbatch_mode(self): - return isinstance(self, NonBatchingMode) + """Base class for computation Mode + """ + + def __repr__(self): + return self.__class__.__name__ + + def __eq__(self, other: 'Mode'): + if not isinstance(other, Mode): + return False + return other.__class__ == self.__class__ + + def is_one_of(self, *modes): + for m_ in modes: + if not isinstance(m_, type): + raise TypeError(f'The supported type must be a tuple/list of type. But we got {m_}') + return self.__class__ in modes + + def is_a(self, mode: type): + """Check whether the mode is exactly the desired mode.""" + assert isinstance(mode, type), 'Must be a type.' + return self.__class__ == mode + + def is_parent_of(self, *modes): + """Check whether the mode is a parent of the given modes.""" + cls = self.__class__ + for m_ in modes: + if not isinstance(m_, type): + raise TypeError(f'The supported type must be a tuple/list of type. 
But we got {m_}') + if all([not issubclass(m_, cls) for m_ in modes]): + return False + else: + return True + + def is_child_of(self, *modes): + """Check whether the mode is a children of one of the given modes.""" + for m_ in modes: + if not isinstance(m_, type): + raise TypeError(f'The supported type must be a tuple/list of type. But we got {m_}') + return isinstance(self, modes) + + def is_batch_mode(self): + return isinstance(self, BatchingMode) + + def is_train_mode(self): + return isinstance(self, TrainingMode) + + def is_nonbatch_mode(self): + return isinstance(self, NonBatchingMode) class NonBatchingMode(Mode): - """Normal non-batching mode. + """Normal non-batching mode. - :py:class:`~.NonBatchingMode` is usually used in models of traditional - computational neuroscience. - """ - pass + :py:class:`~.NonBatchingMode` is usually used in models of traditional + computational neuroscience. + """ + pass - @property - def batch_size(self): - return tuple() + @property + def batch_size(self): + return tuple() class BatchingMode(Mode): - """Batching mode. + """Batching mode. - :py:class:`~.NonBatchingMode` is usually used in models of model trainings. - """ + :py:class:`~.NonBatchingMode` is usually used in models of model trainings. 
+ """ - def __init__(self, batch_size: int = 1): - self.batch_size = batch_size + def __init__(self, batch_size: int = 1): + self.batch_size = batch_size - def __repr__(self): - return f'{self.__class__.__name__}(batch_size={self.batch_size})' + def __repr__(self): + return f'{self.__class__.__name__}(batch_size={self.batch_size})' class TrainingMode(BatchingMode): - """Training mode requires data batching.""" - def to_batch_mode(self): - return BatchingMode(self.batch_size) + """Training mode requires data batching.""" + + def to_batch_mode(self): + return BatchingMode(self.batch_size) nonbatching_mode = NonBatchingMode() diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 53ea8781e..be1f43f32 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -14,10 +14,9 @@ bm = None - __all__ = [ - 'BaseArray', 'Array', 'ndarray', 'JaxArray', # alias of Array - 'ShardedArray', + 'BaseArray', 'Array', 'ndarray', 'JaxArray', # alias of Array + 'ShardedArray', ] # Ways to change values in a zero-dimensional array @@ -33,1511 +32,1523 @@ def _check_input_array(array): - if isinstance(array, BaseArray): - return array.value - elif isinstance(array, np.ndarray): - return jnp.asarray(array) - else: - return array + if isinstance(array, BaseArray): + return array.value + elif isinstance(array, np.ndarray): + return jnp.asarray(array) + else: + return array def _return(a): - if defaults.numpy_func_return == 'bp_array' and isinstance(a, jax.Array) and a.ndim > 0: - return Array(a) - return a + if defaults.numpy_func_return == 'bp_array' and isinstance(a, jax.Array) and a.ndim > 0: + return Array(a) + return a def _as_jax_array_(obj): - return obj.value if isinstance(obj, BaseArray) else obj + return obj.value if isinstance(obj, BaseArray) else obj def _check_out(out): - if not isinstance(out, BaseArray): - raise TypeError(f'out must be an instance of brainpy Array. 
But got {type(out)}') + if not isinstance(out, BaseArray): + raise TypeError(f'out must be an instance of brainpy Array. But got {type(out)}') def _get_dtype(v): - if hasattr(v, 'dtype'): - dtype = v.dtype - else: - dtype = canonicalize_dtype(type(v)) - return dtype - -class BaseArray: - @property - def sharding(self): - return self._value.sharding - - @property - def addressable_shards(self): - return self._value.addressable_shards - - @property - def value(self): - # return the value - return self._value - - @value.setter - def value(self, value): - self_value = self._check_tracer() - - if isinstance(value, BaseArray): - value = value.value - elif isinstance(value, np.ndarray): - value = jnp.asarray(value) - elif isinstance(value, jax.Array): - pass - else: - value = jnp.asarray(value) - # check - if value.shape != self_value.shape: - raise MathError(f"The shape of the original data is {self_value.shape}, " - f"while we got {value.shape}.") - if value.dtype != self_value.dtype: - raise MathError(f"The dtype of the original data is {self_value.dtype}, " - f"while we got {value.dtype}.") - self._value = value - - def update(self, value): - """Update the value of this Array. 
- """ - self.value = value - - def __hash__(self): - return hash(self.value) - - @property - def dtype(self): - """Variable dtype.""" - return _get_dtype(self._value) - - @property - def shape(self): - """Variable shape.""" - return self.value.shape - - @property - def ndim(self): - return self.value.ndim - - @property - def imag(self): - return _return(self.value.image) - - @property - def real(self): - return _return(self.value.real) - - @property - def size(self): - return self.value.size - - @property - def T(self): - return _return(self.value.T) - - # ----------------------- # - # Python inherent methods # - # ----------------------- # - - def __repr__(self) -> str: - print_code = repr(self.value) - if ', dtype' in print_code: - print_code = print_code.split(', dtype')[0] + ')' - prefix = f'{self.__class__.__name__}' - prefix2 = f'{self.__class__.__name__}(value=' - if '\n' in print_code: - lines = print_code.split("\n") - blank1 = " " * len(prefix2) - lines[0] = prefix2 + lines[0] - for i in range(1, len(lines)): - lines[i] = blank1 + lines[i] - lines[-1] += "," - blank2 = " " * (len(prefix) + 1) - lines.append(f'{blank2}dtype={self.dtype})') - print_code = "\n".join(lines) - else: - print_code = prefix2 + print_code + f', dtype={self.dtype})' - return print_code - - def __format__(self, format_spec: str) -> str: - return format(self.value) - - def __iter__(self): - """Solve the issue of DeviceArray.__iter__. 
- - Details please see JAX issues: - - - https://github.com/google/jax/issues/7713 - - https://github.com/google/jax/pull/3821 - """ - for i in range(self.value.shape[0]): - yield self.value[i] - - def __getitem__(self, index): - if isinstance(index, slice) and (index == _all_slice): - return self.value - elif isinstance(index, tuple): - index = tuple((x.value if isinstance(x, BaseArray) else x) for x in index) - elif isinstance(index, BaseArray): - index = index.value - return self.value[index] - - def __setitem__(self, index, value): - # value is Array - if isinstance(value, BaseArray): - value = value.value - # value is numpy.ndarray - elif isinstance(value, np.ndarray): - value = jnp.asarray(value) - - # index is a tuple - if isinstance(index, tuple): - index = tuple(_check_input_array(x) for x in index) - # index is Array - elif isinstance(index, BaseArray): - index = index.value - # index is numpy.ndarray - elif isinstance(index, np.ndarray): - index = jnp.asarray(index) - - # update - self_value = self.value - self.value = self_value.at[index].set(value) - - # ---------- # - # operations # - # ---------- # - - def __len__(self) -> int: - return len(self.value) - - def __neg__(self): - return _return(self.value.__neg__()) - - def __pos__(self): - return _return(self.value.__pos__()) - - def __abs__(self): - return _return(self.value.__abs__()) - - def __invert__(self): - return _return(self.value.__invert__()) - - def __eq__(self, oc): - return _return(self.value == _check_input_array(oc)) - - def __ne__(self, oc): - return _return(self.value != _check_input_array(oc)) - - def __lt__(self, oc): - return _return(self.value < _check_input_array(oc)) - - def __le__(self, oc): - return _return(self.value <= _check_input_array(oc)) - - def __gt__(self, oc): - return _return(self.value > _check_input_array(oc)) - - def __ge__(self, oc): - return _return(self.value >= _check_input_array(oc)) - - def __add__(self, oc): - return _return(self.value + 
_check_input_array(oc)) - - def __radd__(self, oc): - return _return(self.value + _check_input_array(oc)) - - def __iadd__(self, oc): - # a += b - self.value = self.value + _check_input_array(oc) - return self - - def __sub__(self, oc): - return _return(self.value - _check_input_array(oc)) - - def __rsub__(self, oc): - return _return(_check_input_array(oc) - self.value) - - def __isub__(self, oc): - # a -= b - self.value = self.value - _check_input_array(oc) - return self - - def __mul__(self, oc): - return _return(self.value * _check_input_array(oc)) - - def __rmul__(self, oc): - return _return(_check_input_array(oc) * self.value) - - def __imul__(self, oc): - # a *= b - self.value = self.value * _check_input_array(oc) - return self - - def __rdiv__(self, oc): - return _return(_check_input_array(oc) / self.value) - - def __truediv__(self, oc): - return _return(self.value / _check_input_array(oc)) - - def __rtruediv__(self, oc): - return _return(_check_input_array(oc) / self.value) - - def __itruediv__(self, oc): - # a /= b - self.value = self.value / _check_input_array(oc) - return self - - def __floordiv__(self, oc): - return _return(self.value // _check_input_array(oc)) - - def __rfloordiv__(self, oc): - return _return(_check_input_array(oc) // self.value) - - def __ifloordiv__(self, oc): - # a //= b - self.value = self.value // _check_input_array(oc) - return self - - def __divmod__(self, oc): - return _return(self.value.__divmod__(_check_input_array(oc))) - - def __rdivmod__(self, oc): - return _return(self.value.__rdivmod__(_check_input_array(oc))) - - def __mod__(self, oc): - return _return(self.value % _check_input_array(oc)) - - def __rmod__(self, oc): - return _return(_check_input_array(oc) % self.value) - - def __imod__(self, oc): - # a %= b - self.value = self.value % _check_input_array(oc) - return self - - def __pow__(self, oc): - return _return(self.value ** _check_input_array(oc)) - - def __rpow__(self, oc): - return _return(_check_input_array(oc) 
** self.value) - - def __ipow__(self, oc): - # a **= b - self.value = self.value ** _check_input_array(oc) - return self - - def __matmul__(self, oc): - return _return(self.value @ _check_input_array(oc)) - - def __rmatmul__(self, oc): - return _return(_check_input_array(oc) @ self.value) - - def __imatmul__(self, oc): - # a @= b - self.value = self.value @ _check_input_array(oc) - return self - - def __and__(self, oc): - return _return(self.value & _check_input_array(oc)) - - def __rand__(self, oc): - return _return(_check_input_array(oc) & self.value) - - def __iand__(self, oc): - # a &= b - self.value = self.value & _check_input_array(oc) - return self - - def __or__(self, oc): - return _return(self.value | _check_input_array(oc)) - - def __ror__(self, oc): - return _return(_check_input_array(oc) | self.value) - - def __ior__(self, oc): - # a |= b - self.value = self.value | _check_input_array(oc) - return self - - def __xor__(self, oc): - return _return(self.value ^ _check_input_array(oc)) - - def __rxor__(self, oc): - return _return(_check_input_array(oc) ^ self.value) - - def __ixor__(self, oc): - # a ^= b - self.value = self.value ^ _check_input_array(oc) - return self - - def __lshift__(self, oc): - return _return(self.value << _check_input_array(oc)) - - def __rlshift__(self, oc): - return _return(_check_input_array(oc) << self.value) - - def __ilshift__(self, oc): - # a <<= b - self.value = self.value << _check_input_array(oc) - return self - - def __rshift__(self, oc): - return _return(self.value >> _check_input_array(oc)) - - def __rrshift__(self, oc): - return _return(_check_input_array(oc) >> self.value) - - def __irshift__(self, oc): - # a >>= b - self.value = self.value >> _check_input_array(oc) - return self - - def __round__(self, ndigits=None): - return _return(self.value.__round__(ndigits)) - - # ----------------------- # - # JAX methods # - # ----------------------- # - - @property - def at(self): - return self.value.at - - def 
block_host_until_ready(self, *args): - return self.value.block_host_until_ready(*args) - - def block_until_ready(self, *args): - return self.value.block_until_ready(*args) - - def device(self): - return self.value.device() - - @property - def device_buffer(self): - return self.value.device_buffer - - # ----------------------- # - # NumPy methods # - # ----------------------- # - - def all(self, axis=None, keepdims=False): - """Returns True if all elements evaluate to True.""" - r = self.value.all(axis=axis, keepdims=keepdims) - return _return(r) - - def any(self, axis=None, keepdims=False): - """Returns True if any of the elements of a evaluate to True.""" - r = self.value.any(axis=axis, keepdims=keepdims) - return _return(r) - - def argmax(self, axis=None): - """Return indices of the maximum values along the given axis.""" - return _return(self.value.argmax(axis=axis)) - - def argmin(self, axis=None): - """Return indices of the minimum values along the given axis.""" - return _return(self.value.argmin(axis=axis)) - - def argpartition(self, kth, axis=-1, kind='introselect', order=None): - """Returns the indices that would partition this array.""" - return _return(self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order)) - - def argsort(self, axis=-1, kind=None, order=None): - """Returns the indices that would sort this array.""" - return _return(self.value.argsort(axis=axis, kind=kind, order=order)) - - def astype(self, dtype): - """Copy of the array, cast to a specified type. - - Parameters:: - - dtype: str, dtype - Typecode or data-type to which the array is cast. - """ - if dtype is None: - return _return(self.value) - else: - return _return(self.value.astype(dtype)) - - def byteswap(self, inplace=False): - """Swap the bytes of the array elements - - Toggle between low-endian and big-endian data representation by - returning a byteswapped array, optionally swapped in-place. - Arrays of byte-strings are not swapped. 
The real and imaginary - parts of a complex number are swapped individually.""" - return _return(self.value.byteswap(inplace=inplace)) - - def choose(self, choices, mode='raise'): - """Use an index array to construct a new array from a set of choices.""" - return _return(self.value.choose(choices=_as_jax_array_(choices), mode=mode)) - - def clip(self, min=None, max=None, out=None, ): - """Return an array whose values are limited to [min, max]. One of max or min must be given.""" - min = _as_jax_array_(min) - max = _as_jax_array_(max) - r = self.value.clip(min=min, max=max) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r - - def compress(self, condition, axis=None): - """Return selected slices of this array along given axis.""" - return _return(self.value.compress(condition=_as_jax_array_(condition), axis=axis)) - - def conj(self): - """Complex-conjugate all elements.""" - return _return(self.value.conj()) - - def conjugate(self): - """Return the complex conjugate, element-wise.""" - return _return(self.value.conjugate()) - - def copy(self): - """Return a copy of the array.""" - return _return(self.value.copy()) - - def cumprod(self, axis=None, dtype=None): - """Return the cumulative product of the elements along the given axis.""" - return _return(self.value.cumprod(axis=axis, dtype=dtype)) - - def cumsum(self, axis=None, dtype=None): - """Return the cumulative sum of the elements along the given axis.""" - return _return(self.value.cumsum(axis=axis, dtype=dtype)) - - def diagonal(self, offset=0, axis1=0, axis2=1): - """Return specified diagonals.""" - return _return(self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2)) - - def dot(self, b): - """Dot product of two arrays.""" - return _return(self.value.dot(_as_jax_array_(b))) - - def fill(self, value): - """Fill the array with a scalar value.""" - self.value = jnp.ones_like(self.value) * value - - def flatten(self): - return _return(self.value.flatten()) - - def item(self, 
*args): - """Copy an element of an array to a standard Python scalar and return it.""" - return self.value.item(*args) - - def max(self, axis=None, keepdims=False, *args, **kwargs): - """Return the maximum along a given axis.""" - res = self.value.max(axis=axis, keepdims=keepdims, *args, **kwargs) - return _return(res) - - def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs): - """Returns the average of the array elements along given axis.""" - res = self.value.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs) - return _return(res) - - def min(self, axis=None, keepdims=False, *args, **kwargs): - """Return the minimum along a given axis.""" - res = self.value.min(axis=axis, keepdims=keepdims, *args, **kwargs) - return _return(res) - - def nonzero(self): - """Return the indices of the elements that are non-zero.""" - return tuple(_return(a) for a in self.value.nonzero()) - - def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True): - """Return the product of the array elements over the given axis.""" - res = self.value.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - return _return(res) - - def ptp(self, axis=None, keepdims=False): - """Peak to peak (maximum - minimum) value along a given axis.""" - r = self.value.ptp(axis=axis, keepdims=keepdims) - return _return(r) - - def put(self, indices, values): - """Replaces specified elements of an array with given values. - - Parameters:: - - indices: array_like - Target indices, interpreted as integers. - values: array_like - Values to place in the array at target indices. 
- """ - self.__setitem__(indices, values) - - def ravel(self, order=None): - """Return a flattened array.""" - return _return(self.value.ravel(order=order)) - - def repeat(self, repeats, axis=None): - """Repeat elements of an array.""" - return _return(self.value.repeat(repeats=repeats, axis=axis)) - - def reshape(self, *shape, order='C'): - """Returns an array containing the same data with a new shape.""" - return _return(self.value.reshape(*shape, order=order)) - - def resize(self, new_shape): - """Change shape and size of array in-place.""" - self.value = self.value.reshape(new_shape) - - def round(self, decimals=0): - """Return ``a`` with each element rounded to the given number of decimals.""" - return _return(self.value.round(decimals=decimals)) - - def searchsorted(self, v, side='left', sorter=None): - """Find indices where elements should be inserted to maintain order. - - Find the indices into a sorted array `a` such that, if the - corresponding elements in `v` were inserted before the indices, the - order of `a` would be preserved. - - Assuming that `a` is sorted: - - ====== ============================ - `side` returned index `i` satisfies - ====== ============================ - left ``a[i-1] < v <= a[i]`` - right ``a[i-1] <= v < a[i]`` - ====== ============================ - - Parameters:: - - v : array_like - Values to insert into `a`. - side : {'left', 'right'}, optional - If 'left', the index of the first suitable location found is given. - If 'right', return the last such index. If there is no suitable - index, return either 0 or N (where N is the length of `a`). - sorter : 1-D array_like, optional - Optional array of integer indices that sort array a into ascending - order. They are typically the result of argsort. - - Returns:: - - indices : array of ints - Array of insertion points with the same shape as `v`. 
- """ - return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter)) - - def sort(self, axis=-1, stable=True, order=None): - """Sort an array in-place. - - Parameters:: - - axis : int, optional - Axis along which to sort. Default is -1, which means sort along the - last axis. - stable : bool, optional - Whether to use a stable sorting algorithm. The default is True. - order : str or list of str, optional - When `a` is an array with fields defined, this argument specifies - which fields to compare first, second, etc. A single field can - be specified as a string, and not all fields need be specified, - but unspecified fields will still be used, in the order in which - they come up in the dtype, to break ties. - """ - self.value = self.value.sort(axis=axis, stable=stable, order=order) - - - def squeeze(self, axis=None): - """Remove axes of length one from ``a``.""" - return _return(self.value.squeeze(axis=axis)) - - def std(self, axis=None, dtype=None, ddof=0, keepdims=False): - """Compute the standard deviation along the specified axis. - - Returns the standard deviation, a measure of the spread of a distribution, - of the array elements. The standard deviation is computed for the - flattened array by default, otherwise over the specified axis. - - Parameters:: - - axis : None or int or tuple of ints, optional - Axis or axes along which the standard deviation is computed. The - default is to compute the standard deviation of the flattened array. - If this is a tuple of ints, a standard deviation is performed over - multiple axes, instead of a single axis or all the axes as before. - dtype : dtype, optional - Type to use in computing the standard deviation. For arrays of - integer type the default is float64, for arrays of float types it is - the same as the array type. - ddof : int, optional - Means Delta Degrees of Freedom. The divisor used in calculations - is ``N - ddof``, where ``N`` represents the number of elements. 
- By default `ddof` is zero. - keepdims : bool, optional - If this is set to True, the axes which are reduced are left - in the result as dimensions with size one. With this option, - the result will broadcast correctly against the input array. - - If the default value is passed, then `keepdims` will not be - passed through to the `std` method of sub-classes of - `ndarray`, however any non-default value will be. If the - sub-class' method does not implement `keepdims` any - exceptions will be raised. - - Returns:: - - standard_deviation : ndarray, see dtype parameter above. - If `out` is None, return a new array containing the standard deviation, - otherwise return a reference to the output array. - """ - r = self.value.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) - return _return(r) - - def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True): - """Return the sum of the array elements over the given axis.""" - res = self.value.sum(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - return _return(res) - - def swapaxes(self, axis1, axis2): - """Return a view of the array with `axis1` and `axis2` interchanged.""" - return _return(self.value.swapaxes(axis1, axis2)) - - def split(self, indices_or_sections, axis=0): - """Split an array into multiple sub-arrays as views into ``ary``. - - Parameters:: - - indices_or_sections : int, 1-D array - If `indices_or_sections` is an integer, N, the array will be divided - into N equal arrays along `axis`. If such a split is not possible, - an error is raised. - - If `indices_or_sections` is a 1-D array of sorted integers, the entries - indicate where along `axis` the array is split. For example, - ``[2, 3]`` would, for ``axis=0``, result in - - - ary[:2] - - ary[2:3] - - ary[3:] - - If an index exceeds the dimension of the array along `axis`, - an empty sub-array is returned correspondingly. - axis : int, optional - The axis along which to split, default is 0. 
- - Returns:: - - sub-arrays : list of ndarrays - A list of sub-arrays as views into `ary`. - """ - return [_return(a) for a in jnp.split(self.value, indices_or_sections, axis=axis)] - - def take(self, indices, axis=None, mode=None): - """Return an array formed from the elements of a at the given indices.""" - return _return(self.value.take(indices=_as_jax_array_(indices), axis=axis, mode=mode)) - - def tobytes(self): - """Construct Python bytes containing the raw data bytes in the array. - - Constructs Python bytes showing a copy of the raw contents of data memory. - The bytes object is produced in C-order by default. This behavior is - controlled by the ``order`` parameter.""" - return self.value.tobytes() - - def tolist(self): - """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. - - Return a copy of the array data as a (nested) Python list. - Data items are converted to the nearest compatible builtin Python type, via - the `~numpy.ndarray.item` function. - - If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will - not be a list at all, but a simple Python scalar. - """ - return self.value.tolist() - - def trace(self, offset=0, axis1=0, axis2=1, dtype=None): - """Return the sum along diagonals of the array.""" - return _return(self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)) - - def transpose(self, *axes): - """Returns a view of the array with axes transposed. - - For a 1-D array this has no effect, as a transposed vector is simply the - same vector. To convert a 1-D array into a 2D column vector, an additional - dimension must be added. `np.atleast2d(a).T` achieves this, as does - `a[:, np.newaxis]`. - For a 2-D array, this is a standard matrix transpose. - For an n-D array, if axes are given, their order indicates how the - axes are permuted (see Examples). If axes are not provided and - ``a.shape = (i[0], i[1], ... i[n-2], i[n-1])``, then - ``a.transpose().shape = (i[n-1], i[n-2], ... 
i[1], i[0])``. - - Parameters:: - - axes : None, tuple of ints, or `n` ints - - * None or no argument: reverses the order of the axes. - - * tuple of ints: `i` in the `j`-th place in the tuple means `a`'s - `i`-th axis becomes `a.transpose()`'s `j`-th axis. - - * `n` ints: same as an n-tuple of the same ints (this form is - intended simply as a "convenience" alternative to the tuple form) - - Returns:: - - out : ndarray - View of `a`, with axes suitably permuted. - """ - return _return(self.value.transpose(*axes)) - - def tile(self, reps): - """Construct an array by repeating A the number of times given by reps. - - If `reps` has length ``d``, the result will have dimension of - ``max(d, A.ndim)``. - - If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new - axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, - or shape (1, 1, 3) for 3-D replication. If this is not the desired - behavior, promote `A` to d-dimensions manually before calling this - function. - - If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it. - Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as - (1, 1, 2, 2). - - Note : Although tile may be used for broadcasting, it is strongly - recommended to use numpy's broadcasting operations and functions. - - Parameters:: - - reps : array_like - The number of repetitions of `A` along each axis. - - Returns:: - - c : ndarray - The tiled output array. - """ - return _return(self.value.tile(_as_jax_array_(reps))) - - def var(self, axis=None, dtype=None, ddof=0, keepdims=False): - """Returns the variance of the array elements, along given axis.""" - r = self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) - return _return(r) - - def view(self, *args, dtype=None): - r"""New view of array with the same data. - - This function is compatible with pytorch syntax. - - Returns a new tensor with the same data as the :attr:`self` tensor but of a - different :attr:`shape`. 
- - The returned tensor shares the same data and must have the same number - of elements, but may have a different size. For a tensor to be viewed, the new - view size must be compatible with its original size and stride, i.e., each new - view dimension must either be a subspace of an original dimension, or only span - across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following - contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, - - .. math:: - - \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] - - Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` - without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a - :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which - returns a view if the shapes are compatible, and copies (equivalent to calling - :meth:`contiguous`) otherwise. - - Args: - shape (int...): the desired size - - Example:: - - >>> x = brainpy.math.random.randn(4, 4) - >>> x.size - [4, 4] - >>> y = x.view(16) - >>> y.size - [16] - >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions - >>> z.size - [2, 8] - - >>> a = brainpy.math.random.randn(1, 2, 3, 4) - >>> a.size - [1, 2, 3, 4] - >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension - >>> b.size - [1, 3, 2, 4] - >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory - >>> c.size - [1, 3, 2, 4] - >>> brainpy.math.equal(b, c) - False - - - .. method:: view(dtype) -> Tensor - :noindex: - - Returns a new tensor with the same data as the :attr:`self` tensor but of a - different :attr:`dtype`. - - If the element size of :attr:`dtype` is different than that of ``self.dtype``, - then the size of the last dimension of the output will be scaled - proportionally. 
For instance, if :attr:`dtype` element size is twice that of - ``self.dtype``, then each pair of elements in the last dimension of - :attr:`self` will be combined, and the size of the last dimension of the output - will be half that of :attr:`self`. If :attr:`dtype` element size is half that - of ``self.dtype``, then each element in the last dimension of :attr:`self` will - be split in two, and the size of the last dimension of the output will be - double that of :attr:`self`. For this to be possible, the following conditions - must be true: - - * ``self.dim()`` must be greater than 0. - * ``self.stride(-1)`` must be 1. - - Additionally, if the element size of :attr:`dtype` is greater than that of - ``self.dtype``, the following conditions must be true as well: - - * ``self.size(-1)`` must be divisible by the ratio between the element - sizes of the dtypes. - * ``self.storage_offset()`` must be divisible by the ratio between the - element sizes of the dtypes. - * The strides of all dimensions, except the last dimension, must be - divisible by the ratio between the element sizes of the dtypes. - - If any of the above conditions are not met, an error is thrown. 
- - - Args: - dtype (:class:`dtype`): the desired dtype - - Example:: - - >>> x = brainpy.math.random.randn(4, 4) - >>> x - Array([[ 0.9482, -0.0310, 1.4999, -0.5316], - [-0.1520, 0.7472, 0.5617, -0.8649], - [-2.4724, -0.0334, -0.2976, -0.8499], - [-0.2109, 1.9913, -0.9607, -0.6123]]) - >>> x.dtype - brainpy.math.float32 - - >>> y = x.view(brainpy.math.int32) - >>> y - tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], - [-1105482831, 1061112040, 1057999968, -1084397505], - [-1071760287, -1123489973, -1097310419, -1084649136], - [-1101533110, 1073668768, -1082790149, -1088634448]], - dtype=brainpy.math.int32) - >>> y[0, 0] = 1000000000 - >>> x - tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], - [-0.1520, 0.7472, 0.5617, -0.8649], - [-2.4724, -0.0334, -0.2976, -0.8499], - [-0.2109, 1.9913, -0.9607, -0.6123]]) - - >>> x.view(brainpy.math.cfloat) - tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], - [-0.1520+0.7472j, 0.5617-0.8649j], - [-2.4724-0.0334j, -0.2976-0.8499j], - [-0.2109+1.9913j, -0.9607-0.6123j]]) - >>> x.view(brainpy.math.cfloat).size - [4, 2] - - >>> x.view(brainpy.math.uint8) - tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, - 8, 191], - [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, - 93, 191], - [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, - 89, 191], - [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, - 28, 191]], dtype=brainpy.math.uint8) - >>> x.view(brainpy.math.uint8).size - [4, 16] - - """ - if len(args) == 0: - if dtype is None: - raise ValueError('Provide dtype or shape.') - else: - return _return(self.value.view(dtype)) + if hasattr(v, 'dtype'): + dtype = v.dtype else: - if isinstance(args[0], int): # shape - if dtype is not None: - raise ValueError('Provide one of dtype or shape. 
Not both.') - return _return(self.value.reshape(*args)) - else: # dtype - assert not isinstance(args[0], int) - assert dtype is None - return _return(self.value.view(args[0])) - - # ------------------ - # NumPy support - # ------------------ - - def numpy(self, dtype=None): - """Convert to numpy.ndarray.""" - # warnings.warn('Deprecated since 2.1.12. Please use ".to_numpy()" instead.', DeprecationWarning) - return np.asarray(self.value, dtype=dtype) - - def to_numpy(self, dtype=None): - """Convert to numpy.ndarray.""" - return np.asarray(self.value, dtype=dtype) - - def to_jax(self, dtype=None): - """Convert to jax.numpy.ndarray.""" - if dtype is None: - return self.value - else: - return jnp.asarray(self.value, dtype=dtype) - - def __array__(self, dtype=None): - """Support ``numpy.array()`` and ``numpy.asarray()`` functions.""" - return np.asarray(self.value, dtype=dtype) - - def __jax_array__(self): - return self.value - - def as_variable(self): - """As an instance of Variable.""" - global bm - if bm is None: from brainpy import math as bm - return bm.Variable(self) - - def __format__(self, specification): - return self.value.__format__(specification) - - def __bool__(self) -> bool: - return self.value.__bool__() + dtype = canonicalize_dtype(type(v)) + return dtype - def __float__(self): - return self.value.__float__() - def __int__(self): - return self.value.__int__() +class BaseArray: + @property + def sharding(self): + return self._value.sharding + + @property + def addressable_shards(self): + return self._value.addressable_shards + + @property + def value(self): + # return the value + return self._value + + @value.setter + def value(self, value): + self_value = self._check_tracer() + + if isinstance(value, BaseArray): + value = value.value + elif isinstance(value, np.ndarray): + value = jnp.asarray(value) + elif isinstance(value, jax.Array): + pass + else: + value = jnp.asarray(value) + # check + if value.shape != self_value.shape: + raise MathError(f"The 
shape of the original data is {self_value.shape}, " + f"while we got {value.shape}.") + if value.dtype != self_value.dtype: + raise MathError(f"The dtype of the original data is {self_value.dtype}, " + f"while we got {value.dtype}.") + self._value = value + + def update(self, value): + """Update the value of this Array. + """ + self.value = value + + def __hash__(self): + return hash(self.value) + + @property + def dtype(self): + """Variable dtype.""" + return _get_dtype(self._value) + + @property + def shape(self): + """Variable shape.""" + return self.value.shape + + @property + def ndim(self): + return self.value.ndim + + @property + def imag(self): + return _return(self.value.image) + + @property + def real(self): + return _return(self.value.real) + + @property + def size(self): + return self.value.size + + @property + def T(self): + return _return(self.value.T) + + # ----------------------- # + # Python inherent methods # + # ----------------------- # + + def __repr__(self) -> str: + print_code = repr(self.value) + if ', dtype' in print_code: + print_code = print_code.split(', dtype')[0] + ')' + prefix = f'{self.__class__.__name__}' + prefix2 = f'{self.__class__.__name__}(value=' + if '\n' in print_code: + lines = print_code.split("\n") + blank1 = " " * len(prefix2) + lines[0] = prefix2 + lines[0] + for i in range(1, len(lines)): + lines[i] = blank1 + lines[i] + lines[-1] += "," + blank2 = " " * (len(prefix) + 1) + lines.append(f'{blank2}dtype={self.dtype})') + print_code = "\n".join(lines) + else: + print_code = prefix2 + print_code + f', dtype={self.dtype})' + return print_code + + def __format__(self, format_spec: str) -> str: + return format(self.value) + + def __iter__(self): + """Solve the issue of DeviceArray.__iter__. 
+ + Details please see JAX issues: + + - https://github.com/google/jax/issues/7713 + - https://github.com/google/jax/pull/3821 + """ + for i in range(self.value.shape[0]): + yield self.value[i] + + def __getitem__(self, index): + if isinstance(index, slice) and (index == _all_slice): + return self.value + elif isinstance(index, tuple): + index = tuple((x.value if isinstance(x, BaseArray) else x) for x in index) + elif isinstance(index, BaseArray): + index = index.value + return self.value[index] + + def __setitem__(self, index, value): + # value is Array + if isinstance(value, BaseArray): + value = value.value + # value is numpy.ndarray + elif isinstance(value, np.ndarray): + value = jnp.asarray(value) + + # index is a tuple + if isinstance(index, tuple): + index = tuple(_check_input_array(x) for x in index) + # index is Array + elif isinstance(index, BaseArray): + index = index.value + # index is numpy.ndarray + elif isinstance(index, np.ndarray): + index = jnp.asarray(index) + + # update + self_value = self.value + self.value = self_value.at[index].set(value) + + # ---------- # + # operations # + # ---------- # + + def __len__(self) -> int: + return len(self.value) + + def __neg__(self): + return _return(self.value.__neg__()) + + def __pos__(self): + return _return(self.value.__pos__()) + + def __abs__(self): + return _return(self.value.__abs__()) + + def __invert__(self): + return _return(self.value.__invert__()) + + def __eq__(self, oc): + return _return(self.value == _check_input_array(oc)) + + def __ne__(self, oc): + return _return(self.value != _check_input_array(oc)) + + def __lt__(self, oc): + return _return(self.value < _check_input_array(oc)) + + def __le__(self, oc): + return _return(self.value <= _check_input_array(oc)) + + def __gt__(self, oc): + return _return(self.value > _check_input_array(oc)) + + def __ge__(self, oc): + return _return(self.value >= _check_input_array(oc)) + + def __add__(self, oc): + return _return(self.value + 
_check_input_array(oc)) - def __complex__(self): - return self.value.__complex__() + def __radd__(self, oc): + return _return(self.value + _check_input_array(oc)) - def __hex__(self): - assert self.ndim == 0, 'hex only works on scalar values' - return hex(self.value) # type: ignore + def __iadd__(self, oc): + # a += b + self.value = self.value + _check_input_array(oc) + return self - def __oct__(self): - assert self.ndim == 0, 'oct only works on scalar values' - return oct(self.value) # type: ignore + def __sub__(self, oc): + return _return(self.value - _check_input_array(oc)) - def __index__(self): - return operator.index(self.value) + def __rsub__(self, oc): + return _return(_check_input_array(oc) - self.value) - def __dlpack__(self): - from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top - return to_dlpack(self.value) + def __isub__(self, oc): + # a -= b + self.value = self.value - _check_input_array(oc) + return self - # ---------------------- - # PyTorch compatibility - # ---------------------- + def __mul__(self, oc): + return _return(self.value * _check_input_array(oc)) - def unsqueeze(self, dim: int) -> 'Array': - """ - Array.unsqueeze(dim) -> Array, or so called Tensor - equals - Array.expand_dims(dim) + def __rmul__(self, oc): + return _return(_check_input_array(oc) * self.value) - See :func:`brainpy.math.unsqueeze` - """ - return _return(jnp.expand_dims(self.value, dim)) + def __imul__(self, oc): + # a *= b + self.value = self.value * _check_input_array(oc) + return self - def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': - """ - self.expand_dims(axis: int|Sequence[int]) + def __rdiv__(self, oc): + return _return(_check_input_array(oc) / self.value) - 1. 如果axis类型为int: - 返回一个在self基础上的第axis维度前插入一个维度Array, - axis<0表示倒数第|axis|维度, - 令n=len(self._value.shape),则axis的范围为[-(n+1),n] + def __truediv__(self, oc): + return _return(self.value / _check_input_array(oc)) - 2. 
如果axis类型为Sequence[int]: - 则返回依次扩展axis[i]的结果, - 即self.expand_dims(axis)==self.expand_dims(axis[0]).expand_dims(axis[1])...expand_dims(axis[len(axis)-1]) + def __rtruediv__(self, oc): + return _return(_check_input_array(oc) / self.value) + def __itruediv__(self, oc): + # a /= b + self.value = self.value / _check_input_array(oc) + return self - 1. If the type of axis is int: + def __floordiv__(self, oc): + return _return(self.value // _check_input_array(oc)) - Returns an Array of dimensions inserted before the axis dimension based on self, + def __rfloordiv__(self, oc): + return _return(_check_input_array(oc) // self.value) - The first | axis < 0 indicates the bottom axis | dimensions, + def __ifloordiv__(self, oc): + # a //= b + self.value = self.value // _check_input_array(oc) + return self - Set n=len(self._value.shape), then axis has the range [-(n+1),n] + def __divmod__(self, oc): + return _return(self.value.__divmod__(_check_input_array(oc))) + def __rdivmod__(self, oc): + return _return(self.value.__rdivmod__(_check_input_array(oc))) - 2. If the type of axis is Sequence[int] : + def __mod__(self, oc): + return _return(self.value % _check_input_array(oc)) - Returns the result of extending axis[i] in sequence, - - self.expand_dims(axis)==self.expand_dims(axis[0]).expand_dims(axis[1])... expand_dims(axis[len(axis)-1]) - - """ - return _return(jnp.expand_dims(self.value, axis)) + def __rmod__(self, oc): + return _return(_check_input_array(oc) % self.value) - def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': - """ - Expand an array to a shape of another array. 
+ def __imod__(self, oc): + # a %= b + self.value = self.value % _check_input_array(oc) + return self - Parameters:: + def __pow__(self, oc): + return _return(self.value ** _check_input_array(oc)) - array : Array + def __rpow__(self, oc): + return _return(_check_input_array(oc) ** self.value) - Returns:: + def __ipow__(self, oc): + # a **= b + self.value = self.value ** _check_input_array(oc) + return self - expanded : Array - A readonly view on the original array with the given shape of array. It is - typically not contiguous. Furthermore, more than one element of a - expanded array may refer to a single memory location. - """ - return _return(jnp.broadcast_to(self.value, array)) + def __matmul__(self, oc): + return _return(self.value @ _check_input_array(oc)) - def pow(self, index: int): - return _return(self.value ** index) + def __rmatmul__(self, oc): + return _return(_check_input_array(oc) @ self.value) - def addr( - self, - vec1: Union['Array', jax.Array, np.ndarray], - vec2: Union['Array', jax.Array, np.ndarray], - *, - beta: float = 1.0, - alpha: float = 1.0, - out: Optional[Union['Array', jax.Array, np.ndarray]] = None - ) -> Optional['Array']: - r"""Performs the outer-product of vectors ``vec1`` and ``vec2`` and adds it to the matrix ``input``. + def __imatmul__(self, oc): + # a @= b + self.value = self.value @ _check_input_array(oc) + return self - Optional values beta and alpha are scaling factors on the outer product - between vec1 and vec2 and the added matrix input respectively. + def __and__(self, oc): + return _return(self.value & _check_input_array(oc)) - .. 
math:: + def __rand__(self, oc): + return _return(_check_input_array(oc) & self.value) - out = \beta \mathrm{input} + \alpha (\text{vec1} \bigtimes \text{vec2}) + def __iand__(self, oc): + # a &= b + self.value = self.value & _check_input_array(oc) + return self - Args: - vec1: the first vector of the outer product - vec2: the second vector of the outer product - beta: multiplier for input - alpha: multiplier - out: the output tensor. + def __or__(self, oc): + return _return(self.value | _check_input_array(oc)) - """ - vec1 = _as_jax_array_(vec1) - vec2 = _as_jax_array_(vec2) - r = alpha * jnp.outer(vec1, vec2) + beta * self.value - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r - - def addr_( - self, - vec1: Union['Array', jax.Array, np.ndarray], - vec2: Union['Array', jax.Array, np.ndarray], - *, - beta: float = 1.0, - alpha: float = 1.0 - ): - vec1 = _as_jax_array_(vec1) - vec2 = _as_jax_array_(vec2) - r = alpha * jnp.outer(vec1, vec2) + beta * self.value - self.value = r - return self - - def outer(self, other: Union['Array', jax.Array, np.ndarray]) -> 'Array': - other = _as_jax_array_(other) - return _return(jnp.outer(self.value, other.value)) - - def abs(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.abs(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + def __ror__(self, oc): + return _return(_check_input_array(oc) | self.value) - def abs_(self): - """ - in-place version of Array.abs() - """ - self.value = jnp.abs(self.value) - return self + def __ior__(self, oc): + # a |= b + self.value = self.value | _check_input_array(oc) + return self - def add_(self, value): - self.value += value - return self + def __xor__(self, oc): + return _return(self.value ^ _check_input_array(oc)) - def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - """ - alias of Array.abs - """ - return 
self.abs(out=out) + def __rxor__(self, oc): + return _return(_check_input_array(oc) ^ self.value) - def absolute_(self): - """ - alias of Array.abs_() - """ - return self.abs_() + def __ixor__(self, oc): + # a ^= b + self.value = self.value ^ _check_input_array(oc) + return self - def mul(self, value): - return _return(self.value * value) + def __lshift__(self, oc): + return _return(self.value << _check_input_array(oc)) - def mul_(self, value): - """ - In-place version of :meth:`~Array.mul`. - """ - self.value *= value - return self + def __rlshift__(self, oc): + return _return(_check_input_array(oc) << self.value) - def multiply(self, value): # real signature unknown; restored from __doc__ - """ - multiply(value) -> Tensor + def __ilshift__(self, oc): + # a <<= b + self.value = self.value << _check_input_array(oc) + return self - See :func:`torch.multiply`. - """ - return self.value * value + def __rshift__(self, oc): + return _return(self.value >> _check_input_array(oc)) - def multiply_(self, value): # real signature unknown; restored from __doc__ - """ - multiply_(value) -> Tensor + def __rrshift__(self, oc): + return _return(_check_input_array(oc) >> self.value) - In-place version of :meth:`~Tensor.multiply`. 
- """ - self.value *= value - return self + def __irshift__(self, oc): + # a >>= b + self.value = self.value >> _check_input_array(oc) + return self - def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.sin(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + def __round__(self, ndigits=None): + return _return(self.value.__round__(ndigits)) - def sin_(self): - self.value = jnp.sin(self.value) - return self + # ----------------------- # + # JAX methods # + # ----------------------- # - def cos_(self): - self.value = jnp.cos(self.value) - return self + @property + def at(self): + return self.value.at - def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.cos(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + def block_host_until_ready(self, *args): + return self.value.block_host_until_ready(*args) - def tan_(self): - self.value = jnp.tan(self.value) - return self + def block_until_ready(self, *args): + return self.value.block_until_ready(*args) - def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.tan(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + def device(self): + return self.value.device() - def sinh_(self): - self.value = jnp.tanh(self.value) - return self + @property + def device_buffer(self): + return self.value.device_buffer - def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.tanh(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + # ----------------------- # + # NumPy methods # + # ----------------------- # - def cosh_(self): - self.value = jnp.cosh(self.value) - return self + def all(self, axis=None, keepdims=False): + """Returns True if all elements 
evaluate to True.""" + r = self.value.all(axis=axis, keepdims=keepdims) + return _return(r) - def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.cosh(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + def any(self, axis=None, keepdims=False): + """Returns True if any of the elements of a evaluate to True.""" + r = self.value.any(axis=axis, keepdims=keepdims) + return _return(r) + + def argmax(self, axis=None): + """Return indices of the maximum values along the given axis.""" + return _return(self.value.argmax(axis=axis)) + + def argmin(self, axis=None): + """Return indices of the minimum values along the given axis.""" + return _return(self.value.argmin(axis=axis)) + + def argpartition(self, kth, axis=-1, kind='introselect', order=None): + """Returns the indices that would partition this array.""" + return _return(self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order)) + + def argsort(self, axis=-1, kind=None, order=None): + """Returns the indices that would sort this array.""" + return _return(self.value.argsort(axis=axis, kind=kind, order=order)) + + def astype(self, dtype): + """Copy of the array, cast to a specified type. + + Parameters:: + + dtype: str, dtype + Typecode or data-type to which the array is cast. + """ + if dtype is None: + return _return(self.value) + else: + return _return(self.value.astype(dtype)) + + def byteswap(self, inplace=False): + """Swap the bytes of the array elements + + Toggle between low-endian and big-endian data representation by + returning a byteswapped array, optionally swapped in-place. + Arrays of byte-strings are not swapped. 
The real and imaginary + parts of a complex number are swapped individually.""" + return _return(self.value.byteswap(inplace=inplace)) + + def choose(self, choices, mode='raise'): + """Use an index array to construct a new array from a set of choices.""" + return _return(self.value.choose(choices=_as_jax_array_(choices), mode=mode)) + + def clip(self, min=None, max=None, out=None, ): + """Return an array whose values are limited to [min, max]. One of max or min must be given.""" + min = _as_jax_array_(min) + max = _as_jax_array_(max) + r = self.value.clip(min=min, max=max) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def compress(self, condition, axis=None): + """Return selected slices of this array along given axis.""" + return _return(self.value.compress(condition=_as_jax_array_(condition), axis=axis)) + + def conj(self): + """Complex-conjugate all elements.""" + return _return(self.value.conj()) + + def conjugate(self): + """Return the complex conjugate, element-wise.""" + return _return(self.value.conjugate()) + + def copy(self): + """Return a copy of the array.""" + return _return(self.value.copy()) + + def cumprod(self, axis=None, dtype=None): + """Return the cumulative product of the elements along the given axis.""" + return _return(self.value.cumprod(axis=axis, dtype=dtype)) + + def cumsum(self, axis=None, dtype=None): + """Return the cumulative sum of the elements along the given axis.""" + return _return(self.value.cumsum(axis=axis, dtype=dtype)) + + def diagonal(self, offset=0, axis1=0, axis2=1): + """Return specified diagonals.""" + return _return(self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2)) + + def dot(self, b): + """Dot product of two arrays.""" + return _return(self.value.dot(_as_jax_array_(b))) + + def fill(self, value): + """Fill the array with a scalar value.""" + self.value = jnp.ones_like(self.value) * value + + def flatten(self): + return _return(self.value.flatten()) + + def item(self, 
*args): + """Copy an element of an array to a standard Python scalar and return it.""" + return self.value.item(*args) + + def max(self, axis=None, keepdims=False, *args, **kwargs): + """Return the maximum along a given axis.""" + res = self.value.max(axis=axis, keepdims=keepdims, *args, **kwargs) + return _return(res) + + def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs): + """Returns the average of the array elements along given axis.""" + res = self.value.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs) + return _return(res) + + def min(self, axis=None, keepdims=False, *args, **kwargs): + """Return the minimum along a given axis.""" + res = self.value.min(axis=axis, keepdims=keepdims, *args, **kwargs) + return _return(res) + + def nonzero(self): + """Return the indices of the elements that are non-zero.""" + return tuple(_return(a) for a in self.value.nonzero()) + + def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True): + """Return the product of the array elements over the given axis.""" + res = self.value.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + return _return(res) + + def ptp(self, axis=None, keepdims=False): + """Peak to peak (maximum - minimum) value along a given axis.""" + r = self.value.ptp(axis=axis, keepdims=keepdims) + return _return(r) + + def put(self, indices, values): + """Replaces specified elements of an array with given values. + + Parameters:: + + indices: array_like + Target indices, interpreted as integers. + values: array_like + Values to place in the array at target indices. 
+ """ + self.__setitem__(indices, values) + + def ravel(self, order=None): + """Return a flattened array.""" + return _return(self.value.ravel(order=order)) + + def repeat(self, repeats, axis=None): + """Repeat elements of an array.""" + return _return(self.value.repeat(repeats=repeats, axis=axis)) + + def reshape(self, *shape, order='C'): + """Returns an array containing the same data with a new shape.""" + return _return(self.value.reshape(*shape, order=order)) + + def resize(self, new_shape): + """Change shape and size of array in-place.""" + self.value = self.value.reshape(new_shape) + + def round(self, decimals=0): + """Return ``a`` with each element rounded to the given number of decimals.""" + return _return(self.value.round(decimals=decimals)) + + def searchsorted(self, v, side='left', sorter=None): + """Find indices where elements should be inserted to maintain order. + + Find the indices into a sorted array `a` such that, if the + corresponding elements in `v` were inserted before the indices, the + order of `a` would be preserved. + + Assuming that `a` is sorted: + + ====== ============================ + `side` returned index `i` satisfies + ====== ============================ + left ``a[i-1] < v <= a[i]`` + right ``a[i-1] <= v < a[i]`` + ====== ============================ + + Parameters:: + + v : array_like + Values to insert into `a`. + side : {'left', 'right'}, optional + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `a`). + sorter : 1-D array_like, optional + Optional array of integer indices that sort array a into ascending + order. They are typically the result of argsort. + + Returns:: + + indices : array of ints + Array of insertion points with the same shape as `v`. 
+ """ + return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter)) + + def sort(self, axis=-1, stable=True, order=None): + """Sort an array in-place. + + Parameters:: + + axis : int, optional + Axis along which to sort. Default is -1, which means sort along the + last axis. + stable : bool, optional + Whether to use a stable sorting algorithm. The default is True. + order : str or list of str, optional + When `a` is an array with fields defined, this argument specifies + which fields to compare first, second, etc. A single field can + be specified as a string, and not all fields need be specified, + but unspecified fields will still be used, in the order in which + they come up in the dtype, to break ties. + """ + self.value = self.value.sort(axis=axis, stable=stable, order=order) + + def squeeze(self, axis=None): + """Remove axes of length one from ``a``.""" + return _return(self.value.squeeze(axis=axis)) + + def std(self, axis=None, dtype=None, ddof=0, keepdims=False): + """Compute the standard deviation along the specified axis. + + Returns the standard deviation, a measure of the spread of a distribution, + of the array elements. The standard deviation is computed for the + flattened array by default, otherwise over the specified axis. + + Parameters:: + + axis : None or int or tuple of ints, optional + Axis or axes along which the standard deviation is computed. The + default is to compute the standard deviation of the flattened array. + If this is a tuple of ints, a standard deviation is performed over + multiple axes, instead of a single axis or all the axes as before. + dtype : dtype, optional + Type to use in computing the standard deviation. For arrays of + integer type the default is float64, for arrays of float types it is + the same as the array type. + ddof : int, optional + Means Delta Degrees of Freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. 
+ By default `ddof` is zero. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + If the default value is passed, then `keepdims` will not be + passed through to the `std` method of sub-classes of + `ndarray`, however any non-default value will be. If the + sub-class' method does not implement `keepdims` any + exceptions will be raised. + + Returns:: + + standard_deviation : ndarray, see dtype parameter above. + If `out` is None, return a new array containing the standard deviation, + otherwise return a reference to the output array. + """ + r = self.value.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) + return _return(r) + + def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True): + """Return the sum of the array elements over the given axis.""" + res = self.value.sum(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + return _return(res) + + def swapaxes(self, axis1, axis2): + """Return a view of the array with `axis1` and `axis2` interchanged.""" + return _return(self.value.swapaxes(axis1, axis2)) + + def split(self, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays as views into ``ary``. + + Parameters:: + + indices_or_sections : int, 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. 
+ + Returns:: + + sub-arrays : list of ndarrays + A list of sub-arrays as views into `ary`. + """ + return [_return(a) for a in jnp.split(self.value, indices_or_sections, axis=axis)] + + def take(self, indices, axis=None, mode=None): + """Return an array formed from the elements of a at the given indices.""" + return _return(self.value.take(indices=_as_jax_array_(indices), axis=axis, mode=mode)) + + def tobytes(self): + """Construct Python bytes containing the raw data bytes in the array. + + Constructs Python bytes showing a copy of the raw contents of data memory. + The bytes object is produced in C-order by default. This behavior is + controlled by the ``order`` parameter.""" + return self.value.tobytes() + + def tolist(self): + """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. - def tanh_(self): - self.value = jnp.tanh(self.value) - return self + Return a copy of the array data as a (nested) Python list. + Data items are converted to the nearest compatible builtin Python type, via + the `~numpy.ndarray.item` function. - def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.tanh(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will + not be a list at all, but a simple Python scalar. 
+ """ + return self.value.tolist() - def arcsin_(self): - self.value = jnp.arcsin(self.value) - return self + def trace(self, offset=0, axis1=0, axis2=1, dtype=None): + """Return the sum along diagonals of the array.""" + return _return(self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)) - def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.arcsin(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + def transpose(self, *axes): + """Returns a view of the array with axes transposed. + + For a 1-D array this has no effect, as a transposed vector is simply the + same vector. To convert a 1-D array into a 2D column vector, an additional + dimension must be added. `np.atleast2d(a).T` achieves this, as does + `a[:, np.newaxis]`. + For a 2-D array, this is a standard matrix transpose. + For an n-D array, if axes are given, their order indicates how the + axes are permuted (see Examples). If axes are not provided and + ``a.shape = (i[0], i[1], ... i[n-2], i[n-1])``, then + ``a.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0])``. - def arccos_(self): - self.value = jnp.arccos(self.value) - return self + Parameters:: - def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.arccos(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + axes : None, tuple of ints, or `n` ints - def arctan_(self): - self.value = jnp.arctan(self.value) - return self + * None or no argument: reverses the order of the axes. + + * tuple of ints: `i` in the `j`-th place in the tuple means `a`'s + `i`-th axis becomes `a.transpose()`'s `j`-th axis. 
- def arctan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: - r = jnp.arctan(self.value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r - - def clamp( - self, - min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - *, - out: Optional[Union['Array', jax.Array, np.ndarray]] = None - ) -> Optional['Array']: - """ - return the value between min_value and max_value, - if min_value is None, then no lower bound, - if max_value is None, then no upper bound. - """ - min_value = _as_jax_array_(min_value) - max_value = _as_jax_array_(max_value) - r = jnp.clip(self.value, max_value, max_value) - if out is None: - return _return(r) - else: - _check_out(out) - out.value = r + * `n` ints: same as an n-tuple of the same ints (this form is + intended simply as a "convenience" alternative to the tuple form) - def clamp_(self, - min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None): - """ - return the value between min_value and max_value, - if min_value is None, then no lower bound, - if max_value is None, then no upper bound. 
- """ - self.clamp(min_value, max_value, out=self) - return self + Returns:: - def clip_(self, - min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, - max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None): - """ - alias for clamp_ - """ - self.value = self.clip(min_value, max_value, out=self) - return self - - def clone(self) -> 'Array': - return _return(self.value.copy()) - - def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array': - self.value = jnp.copy(_as_jax_array_(src)) - return self - - def cov_with( - self, - y: Optional[Union['Array', jax.Array, np.ndarray]] = None, - rowvar: bool = True, - bias: bool = False, - ddof: Optional[int] = None, - fweights: Union['Array', jax.Array, np.ndarray] = None, - aweights: Union['Array', jax.Array, np.ndarray] = None - ) -> 'Array': - y = _as_jax_array_(y) - fweights = _as_jax_array_(fweights) - aweights = _as_jax_array_(aweights) - r = jnp.cov(self.value, y, rowvar, bias, fweights, aweights) - return _return(r) - - def expand(self, *sizes) -> 'Array': - """ - Expand an array to a new shape. + out : ndarray + View of `a`, with axes suitably permuted. + """ + return _return(self.value.transpose(*axes)) - Parameters:: + def tile(self, reps): + """Construct an array by repeating A the number of times given by reps. - sizes : tuple or int - The shape of the desired array. A single integer ``i`` is interpreted - as ``(i,)``. + If `reps` has length ``d``, the result will have dimension of + ``max(d, A.ndim)``. - Returns:: + If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new + axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, + or shape (1, 1, 3) for 3-D replication. If this is not the desired + behavior, promote `A` to d-dimensions manually before calling this + function. - expanded : Array - A readonly view on the original array with the given shape. It is - typically not contiguous. 
Furthermore, more than one element of a - expanded array may refer to a single memory location. - """ - l_ori = len(self.shape) - l_tar = len(sizes) - base = l_tar - l_ori - sizes_list = list(sizes) - if base < 0: - raise ValueError(f'the number of sizes provided ({len(sizes)}) must be greater or equal to the number of ' - f'dimensions in the tensor ({len(self.shape)})') - for i, v in enumerate(sizes[:base]): - if v < 0: - raise ValueError( - f'The expanded size of the tensor ({v}) isn\'t allowed in a leading, non-existing dimension {i + 1}') - for i, v in enumerate(self.shape): - sizes_list[base + i] = v if sizes_list[base + i] == -1 else sizes_list[base + i] - if v != 1 and sizes_list[base + i] != v: - raise ValueError( - f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton ' - f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}') - return _return(jnp.broadcast_to(self.value, sizes_list)) - - def tree_flatten(self): - return (self.value,), None - - @classmethod - def tree_unflatten(cls, aux_data, flat_contents): - return cls(*flat_contents) - - def zero_(self): - self.value = jnp.zeros_like(self.value) - return self - - def fill_(self, value): - self.fill(value) - return self - - def uniform_(self, low=0., high=1.): - global bm - if bm is None: from brainpy import math as bm - self.value = bm.random.uniform(low, high, self.shape) - return self - - def log_normal_(self, mean=1, std=2): - r"""Fills self tensor with numbers samples from the log-normal distribution parameterized by the given mean - :math:`\mu` and standard deviation :math:`\sigma`. Note that mean and std are the mean and standard - deviation of the underlying normal distribution, and not of the returned distribution: - - .. math:: - - f(x)=\frac{1}{x \sigma \sqrt{2 \pi}} e^{-\frac{(\ln x-\mu)^2}{2 \sigma^2}} + If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it. 
+ Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as + (1, 1, 2, 2). + + Note : Although tile may be used for broadcasting, it is strongly + recommended to use numpy's broadcasting operations and functions. - Args: - mean: the mean value. - std: the standard deviation. - """ - global bm - if bm is None: from brainpy import math as bm - self.value = bm.random.lognormal(mean, std, self.shape) - return self + Parameters:: - def normal_(self, ): - """ - Fills self tensor with elements samples from the normal distribution parameterized by mean and std. - """ - global bm - if bm is None: from brainpy import math as bm - self.value = bm.random.randn(*self.shape) - return self + reps : array_like + The number of repetitions of `A` along each axis. - def cuda(self): - self.value = jax.device_put(self.value, jax.devices('cuda')[0]) - return self + Returns:: - def cpu(self): - self.value = jax.device_put(self.value, jax.devices('cpu')[0]) - return self + c : ndarray + The tiled output array. + """ + return _return(self.value.tile(_as_jax_array_(reps))) - # dtype exchanging # - # ---------------- # + def var(self, axis=None, dtype=None, ddof=0, keepdims=False): + """Returns the variance of the array elements, along given axis.""" + r = self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) + return _return(r) - def bool(self): return jnp.asarray(self.value, dtype=jnp.bool_) - def int(self): return jnp.asarray(self.value, dtype=jnp.int32) - def long(self): return jnp.asarray(self.value, dtype=jnp.int64) - def half(self): return jnp.asarray(self.value, dtype=jnp.float16) - def float(self): return jnp.asarray(self.value, dtype=jnp.float32) - def double(self): return jnp.asarray(self.value, dtype=jnp.float64) + def view(self, *args, dtype=None): + r"""New view of array with the same data. + This function is compatible with pytorch syntax. + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`shape`. 
+ + The returned tensor shares the same data and must have the same number + of elements, but may have a different size. For a tensor to be viewed, the new + view size must be compatible with its original size and stride, i.e., each new + view dimension must either be a subspace of an original dimension, or only span + across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following + contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`, + + .. math:: + + \text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1] + + Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape` + without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a + :meth:`view` can be performed, it is advisable to use :meth:`reshape`, which + returns a view if the shapes are compatible, and copies (equivalent to calling + :meth:`contiguous`) otherwise. + + Args: + shape (int...): the desired size + + Example:: + + >>> x = brainpy.math.random.randn(4, 4) + >>> x.size + [4, 4] + >>> y = x.view(16) + >>> y.size + [16] + >>> z = x.view(-1, 8) # the size -1 is inferred from other dimensions + >>> z.size + [2, 8] + + >>> a = brainpy.math.random.randn(1, 2, 3, 4) + >>> a.size + [1, 2, 3, 4] + >>> b = a.transpose(1, 2) # Swaps 2nd and 3rd dimension + >>> b.size + [1, 3, 2, 4] + >>> c = a.view(1, 3, 2, 4) # Does not change tensor layout in memory + >>> c.size + [1, 3, 2, 4] + >>> brainpy.math.equal(b, c) + False + + + .. method:: view(dtype) -> Tensor + :noindex: + + Returns a new tensor with the same data as the :attr:`self` tensor but of a + different :attr:`dtype`. + + If the element size of :attr:`dtype` is different than that of ``self.dtype``, + then the size of the last dimension of the output will be scaled + proportionally. 
For instance, if :attr:`dtype` element size is twice that of + ``self.dtype``, then each pair of elements in the last dimension of + :attr:`self` will be combined, and the size of the last dimension of the output + will be half that of :attr:`self`. If :attr:`dtype` element size is half that + of ``self.dtype``, then each element in the last dimension of :attr:`self` will + be split in two, and the size of the last dimension of the output will be + double that of :attr:`self`. For this to be possible, the following conditions + must be true: + + * ``self.dim()`` must be greater than 0. + * ``self.stride(-1)`` must be 1. + + Additionally, if the element size of :attr:`dtype` is greater than that of + ``self.dtype``, the following conditions must be true as well: + + * ``self.size(-1)`` must be divisible by the ratio between the element + sizes of the dtypes. + * ``self.storage_offset()`` must be divisible by the ratio between the + element sizes of the dtypes. + * The strides of all dimensions, except the last dimension, must be + divisible by the ratio between the element sizes of the dtypes. + + If any of the above conditions are not met, an error is thrown. 
+ + + Args: + dtype (:class:`dtype`): the desired dtype + + Example:: + + >>> x = brainpy.math.random.randn(4, 4) + >>> x + Array([[ 0.9482, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + >>> x.dtype + brainpy.math.float32 + + >>> y = x.view(brainpy.math.int32) + >>> y + tensor([[ 1064483442, -1124191867, 1069546515, -1089989247], + [-1105482831, 1061112040, 1057999968, -1084397505], + [-1071760287, -1123489973, -1097310419, -1084649136], + [-1101533110, 1073668768, -1082790149, -1088634448]], + dtype=brainpy.math.int32) + >>> y[0, 0] = 1000000000 + >>> x + tensor([[ 0.0047, -0.0310, 1.4999, -0.5316], + [-0.1520, 0.7472, 0.5617, -0.8649], + [-2.4724, -0.0334, -0.2976, -0.8499], + [-0.2109, 1.9913, -0.9607, -0.6123]]) + + >>> x.view(brainpy.math.cfloat) + tensor([[ 0.0047-0.0310j, 1.4999-0.5316j], + [-0.1520+0.7472j, 0.5617-0.8649j], + [-2.4724-0.0334j, -0.2976-0.8499j], + [-0.2109+1.9913j, -0.9607-0.6123j]]) + >>> x.view(brainpy.math.cfloat).size + [4, 2] + + >>> x.view(brainpy.math.uint8) + tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22, + 8, 191], + [227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106, + 93, 191], + [205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147, + 89, 191], + [ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191, + 28, 191]], dtype=brainpy.math.uint8) + >>> x.view(brainpy.math.uint8).size + [4, 16] + + """ + if len(args) == 0: + if dtype is None: + raise ValueError('Provide dtype or shape.') + else: + return _return(self.value.view(dtype)) + else: + if isinstance(args[0], int): # shape + if dtype is not None: + raise ValueError('Provide one of dtype or shape. 
Not both.') + return _return(self.value.reshape(*args)) + else: # dtype + assert not isinstance(args[0], int) + assert dtype is None + return _return(self.value.view(args[0])) + + # ------------------ + # NumPy support + # ------------------ + + def numpy(self, dtype=None): + """Convert to numpy.ndarray.""" + # warnings.warn('Deprecated since 2.1.12. Please use ".to_numpy()" instead.', DeprecationWarning) + return np.asarray(self.value, dtype=dtype) + + def to_numpy(self, dtype=None): + """Convert to numpy.ndarray.""" + return np.asarray(self.value, dtype=dtype) + + def to_jax(self, dtype=None): + """Convert to jax.numpy.ndarray.""" + if dtype is None: + return self.value + else: + return jnp.asarray(self.value, dtype=dtype) + + def __array__(self, dtype=None): + """Support ``numpy.array()`` and ``numpy.asarray()`` functions.""" + return np.asarray(self.value, dtype=dtype) + + def __jax_array__(self): + return self.value + + def as_variable(self): + """As an instance of Variable.""" + global bm + if bm is None: from brainpy import math as bm + return bm.Variable(self) + + def __format__(self, specification): + return self.value.__format__(specification) + + def __bool__(self) -> bool: + return self.value.__bool__() + + def __float__(self): + return self.value.__float__() + + def __int__(self): + return self.value.__int__() + + def __complex__(self): + return self.value.__complex__() + + def __hex__(self): + assert self.ndim == 0, 'hex only works on scalar values' + return hex(self.value) # type: ignore + + def __oct__(self): + assert self.ndim == 0, 'oct only works on scalar values' + return oct(self.value) # type: ignore + + def __index__(self): + return operator.index(self.value) + + def __dlpack__(self): + from jax.dlpack import to_dlpack # pylint: disable=g-import-not-at-top + return to_dlpack(self.value) + + # ---------------------- + # PyTorch compatibility + # ---------------------- + + def unsqueeze(self, dim: int) -> 'Array': + """ + Array.unsqueeze(dim) 
-> Array, or so called Tensor + equals + Array.expand_dims(dim) + + See :func:`brainpy.math.unsqueeze` + """ + return _return(jnp.expand_dims(self.value, dim)) + + def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': + """ + self.expand_dims(axis: int|Sequence[int]) + + 1. 如果axis类型为int: + 返回一个在self基础上的第axis维度前插入一个维度Array, + axis<0表示倒数第|axis|维度, + 令n=len(self._value.shape),则axis的范围为[-(n+1),n] + + 2. 如果axis类型为Sequence[int]: + 则返回依次扩展axis[i]的结果, + 即self.expand_dims(axis)==self.expand_dims(axis[0]).expand_dims(axis[1])...expand_dims(axis[len(axis)-1]) + + + 1. If the type of axis is int: + + Returns an Array of dimensions inserted before the axis dimension based on self, + + The first | axis < 0 indicates the bottom axis | dimensions, + + Set n=len(self._value.shape), then axis has the range [-(n+1),n] + + + 2. If the type of axis is Sequence[int] : + + Returns the result of extending axis[i] in sequence, + + self.expand_dims(axis)==self.expand_dims(axis[0]).expand_dims(axis[1])... expand_dims(axis[len(axis)-1]) + + """ + return _return(jnp.expand_dims(self.value, axis)) + + def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': + """ + Expand an array to a shape of another array. + + Parameters:: + + array : Array + + Returns:: + + expanded : Array + A readonly view on the original array with the given shape of array. It is + typically not contiguous. Furthermore, more than one element of a + expanded array may refer to a single memory location. + """ + return _return(jnp.broadcast_to(self.value, array)) + + def pow(self, index: int): + return _return(self.value ** index) + + def addr( + self, + vec1: Union['Array', jax.Array, np.ndarray], + vec2: Union['Array', jax.Array, np.ndarray], + *, + beta: float = 1.0, + alpha: float = 1.0, + out: Optional[Union['Array', jax.Array, np.ndarray]] = None + ) -> Optional['Array']: + r"""Performs the outer-product of vectors ``vec1`` and ``vec2`` and adds it to the matrix ``input``. 
+ + Optional values beta and alpha are scaling factors on the outer product + between vec1 and vec2 and the added matrix input respectively. + + .. math:: + + out = \beta \mathrm{input} + \alpha (\text{vec1} \bigtimes \text{vec2}) + + Args: + vec1: the first vector of the outer product + vec2: the second vector of the outer product + beta: multiplier for input + alpha: multiplier + out: the output tensor. + + """ + vec1 = _as_jax_array_(vec1) + vec2 = _as_jax_array_(vec2) + r = alpha * jnp.outer(vec1, vec2) + beta * self.value + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def addr_( + self, + vec1: Union['Array', jax.Array, np.ndarray], + vec2: Union['Array', jax.Array, np.ndarray], + *, + beta: float = 1.0, + alpha: float = 1.0 + ): + vec1 = _as_jax_array_(vec1) + vec2 = _as_jax_array_(vec2) + r = alpha * jnp.outer(vec1, vec2) + beta * self.value + self.value = r + return self + + def outer(self, other: Union['Array', jax.Array, np.ndarray]) -> 'Array': + other = _as_jax_array_(other) + return _return(jnp.outer(self.value, other.value)) + + def abs(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.abs(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def abs_(self): + """ + in-place version of Array.abs() + """ + self.value = jnp.abs(self.value) + return self + + def add_(self, value): + self.value += value + return self + + def absolute(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + """ + alias of Array.abs + """ + return self.abs(out=out) + + def absolute_(self): + """ + alias of Array.abs_() + """ + return self.abs_() + + def mul(self, value): + return _return(self.value * value) + + def mul_(self, value): + """ + In-place version of :meth:`~Array.mul`. 
+ """ + self.value *= value + return self + + def multiply(self, value): # real signature unknown; restored from __doc__ + """ + multiply(value) -> Tensor + + See :func:`torch.multiply`. + """ + return self.value * value + + def multiply_(self, value): # real signature unknown; restored from __doc__ + """ + multiply_(value) -> Tensor + + In-place version of :meth:`~Tensor.multiply`. + """ + self.value *= value + return self + + def sin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.sin(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def sin_(self): + self.value = jnp.sin(self.value) + return self + + def cos_(self): + self.value = jnp.cos(self.value) + return self + + def cos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.cos(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def tan_(self): + self.value = jnp.tan(self.value) + return self + + def tan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.tan(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def sinh_(self): + """ + In-place version of Array.sinh() + """ + self.value = jnp.sinh(self.value) + return self + + def sinh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + """ + Element-wise hyperbolic sine. The result is written to ``out`` when it is given, otherwise it is returned. + """ + r = jnp.sinh(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def cosh_(self): + self.value = jnp.cosh(self.value) + return self + + def cosh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.cosh(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def tanh_(self): + self.value = jnp.tanh(self.value) + return self + + def tanh(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) ->
Optional['Array']: + r = jnp.tanh(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def arcsin_(self): + self.value = jnp.arcsin(self.value) + return self + + def arcsin(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arcsin(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def arccos_(self): + self.value = jnp.arccos(self.value) + return self + + def arccos(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arccos(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def arctan_(self): + self.value = jnp.arctan(self.value) + return self + + def arctan(self, *, out: Optional[Union['Array', jax.Array, np.ndarray]] = None) -> Optional['Array']: + r = jnp.arctan(self.value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def clamp( + self, + min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + *, + out: Optional[Union['Array', jax.Array, np.ndarray]] = None + ) -> Optional['Array']: + """ + return the value between min_value and max_value, + if min_value is None, then no lower bound, + if max_value is None, then no upper bound. + """ + min_value = _as_jax_array_(min_value) + max_value = _as_jax_array_(max_value) + # lower bound first, then upper bound + r = jnp.clip(self.value, min_value, max_value) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + + def clamp_(self, + min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None): + """ + return the value between min_value and max_value, + if min_value is None, then no lower bound, + if max_value is None, then no upper bound.
+ """ + self.clamp(min_value, max_value, out=self) + return self + + def clip_(self, + min_value: Optional[Union['Array', jax.Array, np.ndarray]] = None, + max_value: Optional[Union['Array', jax.Array, np.ndarray]] = None): + """ + alias for clamp_ + """ + # Clip in place directly; an ``out=`` style call writes into ``out`` and its return value must not be assigned to ``self.value``. + self.value = jnp.clip(self.value, _as_jax_array_(min_value), _as_jax_array_(max_value)) + return self + + def clone(self) -> 'Array': + return _return(self.value.copy()) + + def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array': + self.value = jnp.copy(_as_jax_array_(src)) + return self + + def cov_with( + self, + y: Optional[Union['Array', jax.Array, np.ndarray]] = None, + rowvar: bool = True, + bias: bool = False, + ddof: Optional[int] = None, + fweights: Union['Array', jax.Array, np.ndarray] = None, + aweights: Union['Array', jax.Array, np.ndarray] = None + ) -> 'Array': + """ + Covariance of this array (optionally together with ``y``); all options are forwarded to ``jnp.cov``. + """ + y = _as_jax_array_(y) + fweights = _as_jax_array_(fweights) + aweights = _as_jax_array_(aweights) + # Pass options by keyword: positionally, ``fweights``/``aweights`` would land in the ``ddof``/``fweights`` slots and ``ddof`` would be dropped. + r = jnp.cov(self.value, y, rowvar=rowvar, bias=bias, ddof=ddof, fweights=fweights, aweights=aweights) + return _return(r) + + def expand(self, *sizes) -> 'Array': + """ + Expand an array to a new shape. + + Parameters:: + + sizes : tuple or int + The shape of the desired array. A single integer ``i`` is interpreted + as ``(i,)``. + + Returns:: + + expanded : Array + A readonly view on the original array with the given shape. It is + typically not contiguous. Furthermore, more than one element of a + expanded array may refer to a single memory location.
+ """ + l_ori = len(self.shape) + l_tar = len(sizes) + base = l_tar - l_ori + sizes_list = list(sizes) + if base < 0: + raise ValueError(f'the number of sizes provided ({len(sizes)}) must be greater or equal to the number of ' + f'dimensions in the tensor ({len(self.shape)})') + for i, v in enumerate(sizes[:base]): + if v < 0: + raise ValueError( + f'The expanded size of the tensor ({v}) isn\'t allowed in a leading, non-existing dimension {i + 1}') + for i, v in enumerate(self.shape): + sizes_list[base + i] = v if sizes_list[base + i] == -1 else sizes_list[base + i] + if v != 1 and sizes_list[base + i] != v: + raise ValueError( + f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton ' + f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}') + return _return(jnp.broadcast_to(self.value, sizes_list)) + + def tree_flatten(self): + return (self.value,), None + + @classmethod + def tree_unflatten(cls, aux_data, flat_contents): + return cls(*flat_contents) + + def zero_(self): + self.value = jnp.zeros_like(self.value) + return self + + def fill_(self, value): + self.fill(value) + return self + + def uniform_(self, low=0., high=1.): + global bm + if bm is None: from brainpy import math as bm + self.value = bm.random.uniform(low, high, self.shape) + return self + + def log_normal_(self, mean=1, std=2): + r"""Fills self tensor with numbers samples from the log-normal distribution parameterized by the given mean + :math:`\mu` and standard deviation :math:`\sigma`. Note that mean and std are the mean and standard + deviation of the underlying normal distribution, and not of the returned distribution: + + .. math:: + + f(x)=\frac{1}{x \sigma \sqrt{2 \pi}} e^{-\frac{(\ln x-\mu)^2}{2 \sigma^2}} + + Args: + mean: the mean value. + std: the standard deviation. 
+ """ + global bm + if bm is None: from brainpy import math as bm + self.value = bm.random.lognormal(mean, std, self.shape) + return self + + def normal_(self, ): + """ + Fills self tensor with elements samples from the normal distribution parameterized by mean and std. + """ + global bm + if bm is None: from brainpy import math as bm + self.value = bm.random.randn(*self.shape) + return self + + def cuda(self): + self.value = jax.device_put(self.value, jax.devices('cuda')[0]) + return self + + def cpu(self): + self.value = jax.device_put(self.value, jax.devices('cpu')[0]) + return self + + # dtype exchanging # + # ---------------- # + + def bool(self): + return jnp.asarray(self.value, dtype=jnp.bool_) + + def int(self): + return jnp.asarray(self.value, dtype=jnp.int32) + + def long(self): + return jnp.asarray(self.value, dtype=jnp.int64) + + def half(self): + return jnp.asarray(self.value, dtype=jnp.float16) + + def float(self): + return jnp.asarray(self.value, dtype=jnp.float32) + + def double(self): + return jnp.asarray(self.value, dtype=jnp.float64) @register_pytree_node_class class Array(BaseArray): - """Multiple-dimensional array in BrainPy. - - Compared to ``jax.Array``, :py:class:`~.Array` has the following advantages: + """Multiple-dimensional array in BrainPy. - - In-place updating is supported. + Compared to ``jax.Array``, :py:class:`~.Array` has the following advantages: - >>> import brainpy.math as bm - >>> a = bm.asarray([1, 2, 3.]) - >>> a[0] = 10. + - In-place updating is supported. - - Keep sharding constraints during computation. + >>> import brainpy.math as bm + >>> a = bm.asarray([1, 2, 3.]) + >>> a[0] = 10. - - More dense array operations with PyTorch syntax. + - Keep sharding constraints during computation. - """ + - More dense array operations with PyTorch syntax. 
- __slots__ = ('_value', ) + """ - def __init__(self, value, dtype: Any = None): - # array value - if isinstance(value, BaseArray): - value = value._value - elif isinstance(value, (tuple, list, np.ndarray)): - value = jnp.asarray(value) - if dtype is not None: - value = jnp.asarray(value, dtype=dtype) - self._value = value + __slots__ = ('_value',) - def _check_tracer(self): - self_value = self.value - if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'): - if len(self_value._trace.main.jaxpr_stack) == 0: - raise jax.errors.UnexpectedTracerError('This Array is modified during the transformation. ' - 'BrainPy only supports transformations for Variable. ' - 'Please declare it as a Variable.') from jax.core.escaped_tracer_error(self_value, None) - return self_value + def __init__(self, value, dtype: Any = None): + # array value + if isinstance(value, BaseArray): + value = value._value + elif isinstance(value, (tuple, list, np.ndarray)): + value = jnp.asarray(value) + if dtype is not None: + value = jnp.asarray(value, dtype=dtype) + self._value = value + + def _check_tracer(self): + self_value = self.value + if hasattr(self_value, '_trace') and hasattr(self_value._trace.main, 'jaxpr_stack'): + if len(self_value._trace.main.jaxpr_stack) == 0: + raise jax.errors.UnexpectedTracerError('This Array is modified during the transformation. ' + 'BrainPy only supports transformations for Variable. 
' + 'Please declare it as a Variable.') from jax.core.escaped_tracer_error( + self_value, None) + return self_value + + # @classmethod + # def __instancecheck__(cls, subclass): + # from brainpy.math import Variable + # if issubclass(subclass, Variable): + # return True + # if isinstance(subclass, Array): + # return True + # + # return False - # @classmethod - # def __instancecheck__(cls, subclass): - # from brainpy.math import Variable - # if issubclass(subclass, Variable): - # return True - # if isinstance(subclass, Array): - # return True - # - # return False setattr(Array, "__array_priority__", 100) @@ -1547,54 +1558,53 @@ def _check_tracer(self): @register_pytree_node_class class ShardedArray(Array): - """The sharded array, which stores data across multiple devices. - - A drawback of sharding is that the data may not be evenly distributed on shards. - - Args: - value: the array value. - dtype: the array type. - keep_sharding: keep the array sharding information using ``jax.lax.with_sharding_constraint``. Default True. - """ - - __slots__ = ('_value', '_keep_sharding') + """The sharded array, which stores data across multiple devices. - def __init__(self, value, dtype: Any = None, *, keep_sharding: bool = True): - super().__init__(value, dtype) - self._keep_sharding = keep_sharding + A drawback of sharding is that the data may not be evenly distributed on shards. - @property - def value(self): - """The value stored in this array. - - Returns: - The stored data. + Args: + value: the array value. + dtype: the array type. + keep_sharding: keep the array sharding information using ``jax.lax.with_sharding_constraint``. Default True. 
""" - v = self._value - # keep sharding constraints - if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None): - return jax.lax.with_sharding_constraint(v, v.sharding) - # return the value - return v - - @value.setter - def value(self, value): - self_value = self._check_tracer() - - if isinstance(value, BaseArray): - value = value.value - elif isinstance(value, np.ndarray): - value = jnp.asarray(value) - elif isinstance(value, jax.Array): - pass - else: - value = jnp.asarray(value) - # check - if value.shape != self_value.shape: - raise MathError(f"The shape of the original data is {self_value.shape}, " - f"while we got {value.shape}.") - if value.dtype != self_value.dtype: - raise MathError(f"The dtype of the original data is {self_value.dtype}, " - f"while we got {value.dtype}.") - self._value = value + __slots__ = ('_value', '_keep_sharding') + + def __init__(self, value, dtype: Any = None, *, keep_sharding: bool = True): + super().__init__(value, dtype) + self._keep_sharding = keep_sharding + + @property + def value(self): + """The value stored in this array. + + Returns: + The stored data. 
+ """ + v = self._value + # keep sharding constraints + if self._keep_sharding and hasattr(v, 'sharding') and (v.sharding is not None): + return jax.lax.with_sharding_constraint(v, v.sharding) + # return the value + return v + + @value.setter + def value(self, value): + self_value = self._check_tracer() + + if isinstance(value, BaseArray): + value = value.value + elif isinstance(value, np.ndarray): + value = jnp.asarray(value) + elif isinstance(value, jax.Array): + pass + else: + value = jnp.asarray(value) + # check + if value.shape != self_value.shape: + raise MathError(f"The shape of the original data is {self_value.shape}, " + f"while we got {value.shape}.") + if value.dtype != self_value.dtype: + raise MathError(f"The dtype of the original data is {self_value.dtype}, " + f"while we got {value.dtype}.") + self._value = value diff --git a/brainpy/_src/math/object_transform/__init__.py b/brainpy/_src/math/object_transform/__init__.py index 714f781ca..3e509cb99 100644 --- a/brainpy/_src/math/object_transform/__init__.py +++ b/brainpy/_src/math/object_transform/__init__.py @@ -15,15 +15,11 @@ Details please see the following. 
""" +from .autograd import * from .base import * - -from .variables import * from .collectors import * - -from .autograd import * from .controls import * -from .jit import * from .function import * +from .jit import * from .naming import * - - +from .variables import * diff --git a/brainpy/_src/math/object_transform/_utils.py b/brainpy/_src/math/object_transform/_utils.py index 52942d9d2..a16520d63 100644 --- a/brainpy/_src/math/object_transform/_utils.py +++ b/brainpy/_src/math/object_transform/_utils.py @@ -1,29 +1,30 @@ # -*- coding: utf-8 -*- from typing import Dict + from .base import BrainPyObject, ArrayCollector __all__ = [ - 'infer_dyn_vars', - 'get_brainpy_object', + 'infer_dyn_vars', + 'get_brainpy_object', ] def infer_dyn_vars(target): - if isinstance(target, BrainPyObject): - dyn_vars = target.vars().unique() - elif hasattr(target, '__self__') and isinstance(target.__self__, BrainPyObject): - dyn_vars = target.__self__.vars().unique() - else: - dyn_vars = ArrayCollector() - return dyn_vars + if isinstance(target, BrainPyObject): + dyn_vars = target.vars().unique() + elif hasattr(target, '__self__') and isinstance(target.__self__, BrainPyObject): + dyn_vars = target.__self__.vars().unique() + else: + dyn_vars = ArrayCollector() + return dyn_vars def get_brainpy_object(target) -> Dict[str, BrainPyObject]: - if isinstance(target, BrainPyObject): - return {target.name: target} - elif hasattr(target, '__self__') and isinstance(target.__self__, BrainPyObject): - target = target.__self__ - return {target.name: target} - else: - return dict() + if isinstance(target, BrainPyObject): + return {target.name: target} + elif hasattr(target, '__self__') and isinstance(target.__self__, BrainPyObject): + target = target.__self__ + return {target.name: target} + else: + return dict() diff --git a/brainpy/_src/math/object_transform/collectors.py b/brainpy/_src/math/object_transform/collectors.py index 2f347c669..c7c20a72c 100644 --- 
a/brainpy/_src/math/object_transform/collectors.py +++ b/brainpy/_src/math/object_transform/collectors.py @@ -1,209 +1,209 @@ from typing import Sequence, Dict, Union -from jax.tree_util import register_pytree_node from brainstate._compatible_import import safe_zip +from jax.tree_util import register_pytree_node from .variables import Variable __all__ = [ - 'Collector', 'ArrayCollector', 'TensorCollector', + 'Collector', 'ArrayCollector', 'TensorCollector', ] class Collector(dict): - """A Collector is a dictionary (name, var) with some additional methods to make manipulation - of collections of variables easy. A Collector is ordered by insertion order. It is the object - returned by BrainPyObject.vars() and used as input in many Collector instance: optimizers, jit, etc...""" - - def __setitem__(self, key, value): - """Overload bracket assignment to catch potential conflicts during assignment.""" - if key in self: - if id(self[key]) != id(value): - raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.') - dict.__setitem__(self, key, value) - - def replace(self, key, new_value): - """Replace the original key with the new value.""" - self.pop(key) - self[key] = new_value - - def update(self, other, **kwargs): - assert isinstance(other, (dict, list, tuple)) - if isinstance(other, dict): - for key, value in other.items(): - self[key] = value - elif isinstance(other, (tuple, list)): - num = len(self) - for i, value in enumerate(other): - self[f'_var{i + num}'] = value - else: - raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}') - for key, value in kwargs.items(): - self[key] = value - return self - - def __add__(self, other): - """Merging two dicts. - - Parameters:: - - other: dict - The other dict instance. - - Returns:: - - gather: Collector - The new collector. 
- """ - gather = type(self)(self) - gather.update(other) - return gather - - def __sub__(self, other: Union[Dict, Sequence]): - """Remove other item in the collector. - - Parameters:: - - other: dict, sequence - The items to remove. - - Returns:: - - gather: Collector - The new collector. - """ - if not isinstance(other, (dict, tuple, list)): - raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.') - gather = type(self)(self) - if isinstance(other, dict): - for key, val in other.items(): - if key in gather: - if id(val) != id(gather[key]): - raise ValueError(f'Cannot remove {key}, because we got two different values: ' - f'{val} != {gather[key]}') - gather.pop(key) + """A Collector is a dictionary (name, var) with some additional methods to make manipulation + of collections of variables easy. A Collector is ordered by insertion order. It is the object + returned by BrainPyObject.vars() and used as input in many Collector instance: optimizers, jit, etc...""" + + def __setitem__(self, key, value): + """Overload bracket assignment to catch potential conflicts during assignment.""" + if key in self: + if id(self[key]) != id(value): + raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.') + dict.__setitem__(self, key, value) + + def replace(self, key, new_value): + """Replace the original key with the new value.""" + self.pop(key) + self[key] = new_value + + def update(self, other, **kwargs): + assert isinstance(other, (dict, list, tuple)) + if isinstance(other, dict): + for key, value in other.items(): + self[key] = value + elif isinstance(other, (tuple, list)): + num = len(self) + for i, value in enumerate(other): + self[f'_var{i + num}'] = value else: - raise ValueError(f'Cannot remove {key}, because we do not find it ' - f'in {self.keys()}.') - elif isinstance(other, (list, tuple)): - id_to_keys = {} - for k, v in self.items(): - id_ = id(v) - if id_ not in id_to_keys: - id_to_keys[id_] = [] - 
id_to_keys[id_].append(k) - - keys_to_remove = [] - for key in other: - if isinstance(key, str): - keys_to_remove.append(key) - else: - keys_to_remove.extend(id_to_keys[id(key)]) - - for key in set(keys_to_remove): - if key in gather: - gather.pop(key) + raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}') + for key, value in kwargs.items(): + self[key] = value + return self + + def __add__(self, other): + """Merging two dicts. + + Parameters:: + + other: dict + The other dict instance. + + Returns:: + + gather: Collector + The new collector. + """ + gather = type(self)(self) + gather.update(other) + return gather + + def __sub__(self, other: Union[Dict, Sequence]): + """Remove other item in the collector. + + Parameters:: + + other: dict, sequence + The items to remove. + + Returns:: + + gather: Collector + The new collector. + """ + if not isinstance(other, (dict, tuple, list)): + raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.') + gather = type(self)(self) + if isinstance(other, dict): + for key, val in other.items(): + if key in gather: + if id(val) != id(gather[key]): + raise ValueError(f'Cannot remove {key}, because we got two different values: ' + f'{val} != {gather[key]}') + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') + elif isinstance(other, (list, tuple)): + id_to_keys = {} + for k, v in self.items(): + id_ = id(v) + if id_ not in id_to_keys: + id_to_keys[id_] = [] + id_to_keys[id_].append(k) + + keys_to_remove = [] + for key in other: + if isinstance(key, str): + keys_to_remove.append(key) + else: + keys_to_remove.extend(id_to_keys[id(key)]) + + for key in set(keys_to_remove): + if key in gather: + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') else: - raise ValueError(f'Cannot remove {key}, because we do not find it ' - f'in {self.keys()}.') - else: - raise 
KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}') - return gather - - def subset(self, var_type): - """Get the subset of the (key, value) pair. - - ``subset()`` can be used to get a subset of some class: - - >>> import brainpy as bp - >>> - >>> some_collector = Collector() - >>> - >>> # get all trainable variables - >>> some_collector.subset(bp.math.TrainVar) - >>> - >>> # get all Variable - >>> some_collector.subset(bp.math.Variable) - - or, it can be used to get a subset of integrators: - - >>> # get all ODE integrators - >>> some_collector.subset(bp.ode.ODEIntegrator) - - Parameters:: - - var_type : type - The type/class to match. - """ - gather = type(self)() - for key, value in self.items(): - if isinstance(value, var_type): - gather[key] = value - return gather - - def not_subset(self, var_type): - gather = type(self)() - for key, value in self.items(): - if not isinstance(value, var_type): - gather[key] = value - return gather - - def include(self, *types): - gather = type(self)() - for key, value in self.items(): - if value.__class__ in types: - gather[key] = value - return gather - - def exclude(self, *types): - gather = type(self)() - for key, value in self.items(): - if value.__class__ not in types: - gather[key] = value - return gather - - def unique(self): - """Get a new type of collector with unique values. - - If one value is assigned to two or more keys, - then only one pair of (key, value) will be returned. - """ - gather = type(self)() - seen = set() - for k, v in self.items(): - if id(v) not in seen: - seen.add(id(v)) - gather[k] = v - return gather + raise KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}') + return gather + + def subset(self, var_type): + """Get the subset of the (key, value) pair. 
+ + ``subset()`` can be used to get a subset of some class: + + >>> import brainpy as bp + >>> + >>> some_collector = Collector() + >>> + >>> # get all trainable variables + >>> some_collector.subset(bp.math.TrainVar) + >>> + >>> # get all Variable + >>> some_collector.subset(bp.math.Variable) + + or, it can be used to get a subset of integrators: + + >>> # get all ODE integrators + >>> some_collector.subset(bp.ode.ODEIntegrator) + + Parameters:: + + var_type : type + The type/class to match. + """ + gather = type(self)() + for key, value in self.items(): + if isinstance(value, var_type): + gather[key] = value + return gather + + def not_subset(self, var_type): + gather = type(self)() + for key, value in self.items(): + if not isinstance(value, var_type): + gather[key] = value + return gather + + def include(self, *types): + gather = type(self)() + for key, value in self.items(): + if value.__class__ in types: + gather[key] = value + return gather + + def exclude(self, *types): + gather = type(self)() + for key, value in self.items(): + if value.__class__ not in types: + gather[key] = value + return gather + + def unique(self): + """Get a new type of collector with unique values. + + If one value is assigned to two or more keys, + then only one pair of (key, value) will be returned. 
+ """ + gather = type(self)() + seen = set() + for k, v in self.items(): + if id(v) not in seen: + seen.add(id(v)) + gather[k] = v + return gather class ArrayCollector(Collector): - def __setitem__(self, key, value): - """Overload bracket assignment to catch potential conflicts during assignment.""" + def __setitem__(self, key, value): + """Overload bracket assignment to catch potential conflicts during assignment.""" - assert isinstance(value, Variable), type(value) - if key in self: - if id(self[key]) != id(value): - raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.') - dict.__setitem__(self, key, value) + assert isinstance(value, Variable), type(value) + if key in self: + if id(self[key]) != id(value): + raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.') + dict.__setitem__(self, key, value) - def dict(self): - """Get a dict with the key and the value data. - """ - gather = dict() - for k, v in self.items(): - gather[k] = v.value - return gather + def dict(self): + """Get a dict with the key and the value data. 
+ """ + gather = dict() + for k, v in self.items(): + gather[k] = v.value + return gather - def data(self): - """Get all data in each value.""" - return [x.value for x in self.values()] + def data(self): + """Get all data in each value.""" + return [x.value for x in self.values()] TensorCollector = ArrayCollector register_pytree_node( - ArrayCollector, - lambda x: (x.values(), x.keys()), - lambda keys, values: ArrayCollector(safe_zip(keys, values)) + ArrayCollector, + lambda x: (x.values(), x.keys()), + lambda keys, values: ArrayCollector(safe_zip(keys, values)) ) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 9de0260ed..35e30e2a6 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -3,9 +3,9 @@ import numbers from typing import Union, Sequence, Any, Dict, Callable, Optional +import brainstate import jax.numpy as jnp -import brainstate from brainpy._src.math.ndarray import Array __all__ = [ @@ -137,16 +137,16 @@ def ifelse( operands = () elif not isinstance(operands, (tuple, list)): operands = (operands,) - + # Convert non-callable branches to callables def make_callable(branch): if callable(branch): return branch else: return lambda *args: branch - + branches = [make_callable(branch) for branch in branches] - + # Convert if-elif-else chain to mutually exclusive conditions if isinstance(conditions, (list, tuple)) and len(conditions) > 0: conditions = list(conditions) @@ -161,7 +161,7 @@ def make_callable(branch): for j in range(1, i): prev_conds_false = prev_conds_false & jnp.logical_not(conditions[j]) exclusive_conditions.append(cond & prev_conds_false) - + # If we have equal number of branches and conditions, the last branch is the default case if len(branches) == len(conditions): # Replace the last condition with "all previous conditions are false" @@ -175,9 +175,9 @@ def make_callable(branch): for cond in conditions[1:]: all_false = 
all_false & jnp.logical_not(cond) exclusive_conditions.append(all_false) - + conditions = exclusive_conditions - + return brainstate.transform.ifelse(conditions, branches, *operands) diff --git a/brainpy/_src/math/object_transform/function.py b/brainpy/_src/math/object_transform/function.py index a7ac99435..67e6d2608 100644 --- a/brainpy/_src/math/object_transform/function.py +++ b/brainpy/_src/math/object_transform/function.py @@ -3,34 +3,34 @@ import warnings from typing import Union, Sequence, Dict, Callable -from .variables import Variable from .base import FunAsObject, BrainPyObject +from .variables import Variable __all__ = [ - 'Partial', - 'to_object', - 'function', + 'Partial', + 'to_object', + 'function', ] class Partial(FunAsObject): - def __init__( - self, - fun: Callable, - *args, - child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None, - dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, - **keywords - ): - super().__init__(target=fun, child_objs=child_objs, dyn_vars=dyn_vars) + def __init__( + self, + fun: Callable, + *args, + child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None, + dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, + **keywords + ): + super().__init__(target=fun, child_objs=child_objs, dyn_vars=dyn_vars) - self.fun = fun - self.args = args - self.keywords = keywords + self.fun = fun + self.args = args + self.keywords = keywords - def __call__(self, *args, **keywords): - keywords = {**self.keywords, **keywords} - return self.fun(*self.args, *args, **keywords) + def __call__(self, *args, **keywords): + keywords = {**self.keywords, **keywords} + return self.fun(*self.args, *args, **keywords) def to_object( @@ -39,35 +39,35 @@ def to_object( dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, name: str = None ): - """Transform a Python function to :py:class:`~.BrainPyObject`. 
+ """Transform a Python function to :py:class:`~.BrainPyObject`. - Parameters:: + Parameters:: - f: function, callable - The python function. - child_objs: Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject - The children objects used in this Python function. - dyn_vars: Variable, sequence of Variable, dict of Variable - The `Variable` instance used in the Python function. - name: str - The name of the created ``BrainPyObject``. + f: function, callable + The python function. + child_objs: Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject + The children objects used in this Python function. + dyn_vars: Variable, sequence of Variable, dict of Variable + The `Variable` instance used in the Python function. + name: str + The name of the created ``BrainPyObject``. - Returns:: + Returns:: - func: FunAsObject - The instance of ``BrainPyObject``. - """ + func: FunAsObject + The instance of ``BrainPyObject``. + """ - if f is None: - def wrap(func) -> FunAsObject: - return FunAsObject(target=func, child_objs=child_objs, dyn_vars=dyn_vars, name=name) + if f is None: + def wrap(func) -> FunAsObject: + return FunAsObject(target=func, child_objs=child_objs, dyn_vars=dyn_vars, name=name) - return wrap + return wrap - else: - if child_objs is None: - raise ValueError(f'"child_objs" cannot be None when "f" is provided.') - return FunAsObject(target=f, child_objs=child_objs, dyn_vars=dyn_vars, name=name) + else: + if child_objs is None: + raise ValueError(f'"child_objs" cannot be None when "f" is provided.') + return FunAsObject(target=f, child_objs=child_objs, dyn_vars=dyn_vars, name=name) def function( @@ -76,27 +76,27 @@ def function( dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, name: str = None ): - """Transform a Python function into a :py:class:`~.BrainPyObject`. - - .. deprecated:: 2.3.0 - Using :py:func:`~.to_object` instead. - - Parameters:: - - f: function, callable - The python function. 
- nodes: Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject - The children objects used in this Python function. - dyn_vars: Variable, sequence of Variable, dict of Variable - The `Variable` instance used in the Python function. - name: str - The name of the created ``BrainPyObject``. - - Returns:: - - func: FunAsObject - The instance of ``BrainPyObject``. - """ - warnings.warn('Using `brainpy.math.to_object()` instead. Will be removed after version 2.4.0.', - UserWarning) - return to_object(f, nodes, dyn_vars, name) + """Transform a Python function into a :py:class:`~.BrainPyObject`. + + .. deprecated:: 2.3.0 + Using :py:func:`~.to_object` instead. + + Parameters:: + + f: function, callable + The python function. + nodes: Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject + The children objects used in this Python function. + dyn_vars: Variable, sequence of Variable, dict of Variable + The `Variable` instance used in the Python function. + name: str + The name of the created ``BrainPyObject``. + + Returns:: + + func: FunAsObject + The instance of ``BrainPyObject``. + """ + warnings.warn('Using `brainpy.math.to_object()` instead. 
Will be removed after version 2.4.0.', + UserWarning) + return to_object(f, nodes, dyn_vars, name) diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index b567c93e3..14baa4853 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -9,9 +9,8 @@ from typing import Callable, Union, Optional, Sequence, Any, Iterable -import jax.tree - import brainstate.transform +import jax.tree from brainstate.compile._jit import Missing __all__ = [ diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 717b9af8a..2dff7f910 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -5,66 +5,64 @@ from brainpy import errors __all__ = [ - 'clear_name_cache', + 'clear_name_cache', ] - _name2id = dict() _typed_names = {} def check_name_uniqueness(name, obj): - """Check the uniqueness of the name for the object type.""" - if not name.isidentifier(): - raise errors.BrainPyError(f'"{name}" isn\'t a valid identifier ' - f'according to Python language definition. ' - f'Please choose another name.') - if name in _name2id: - if _name2id[name] != id(obj): - raise errors.UniqueNameError( - f'In BrainPy, each object should have a unique name. ' - f'However, we detect that {obj} has a used name "{name}". \n' - f'If you try to run multiple trials, you may need \n\n' - f'>>> brainpy.math.clear_name_cache() \n\n' - f'to clear all cached names. ' - ) - else: - _name2id[name] = id(obj) + """Check the uniqueness of the name for the object type.""" + if not name.isidentifier(): + raise errors.BrainPyError(f'"{name}" isn\'t a valid identifier ' + f'according to Python language definition. ' + f'Please choose another name.') + if name in _name2id: + if _name2id[name] != id(obj): + raise errors.UniqueNameError( + f'In BrainPy, each object should have a unique name. 
' + f'However, we detect that {obj} has a used name "{name}". \n' + f'If you try to run multiple trials, you may need \n\n' + f'>>> brainpy.math.clear_name_cache() \n\n' + f'to clear all cached names. ' + ) + else: + _name2id[name] = id(obj) def get_unique_name(type_: str): - """Get the unique name for the given object type.""" - if type_ not in _typed_names: - _typed_names[type_] = 0 - name = f'{type_}{_typed_names[type_]}' - _typed_names[type_] += 1 - return name + """Get the unique name for the given object type.""" + if type_ not in _typed_names: + _typed_names[type_] = 0 + name = f'{type_}{_typed_names[type_]}' + _typed_names[type_] += 1 + return name def clear_name_cache(ignore_warn=True): - """Clear the cached names.""" - _name2id.clear() - _typed_names.clear() - if not ignore_warn: - warnings.warn(f'All named models and their ids are cleared.', UserWarning) + """Clear the cached names.""" + _name2id.clear() + _typed_names.clear() + if not ignore_warn: + warnings.warn(f'All named models and their ids are cleared.', UserWarning) _fun2stack = dict() def cache_stack(func, stack): - _fun2stack[func] = stack + _fun2stack[func] = stack def clear_stack_cache(): - """Clear the cached stack.""" - for k in tuple(_fun2stack.keys()): - del _fun2stack[k] + """Clear the cached stack.""" + for k in tuple(_fun2stack.keys()): + del _fun2stack[k] def get_stack_cache(func): - if func in _fun2stack: - return _fun2stack[func] - else: - return None - + if func in _fun2stack: + return _fun2stack[func] + else: + return None diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py index d4209129e..34cb2430c 100644 --- a/brainpy/_src/math/object_transform/tests/test_autograd.py +++ b/brainpy/_src/math/object_transform/tests/test_autograd.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -import functools -import tempfile import unittest from pprint import pprint @@ -15,371 +13,370 @@ class 
TestPureFuncGrad(unittest.TestCase): - def test_grad_pure_func_1(self): - def call(a, b, c): return bm.sum(a + b + c) - - bm.random.seed(1) - a = bm.ones(10) - b = bm.random.randn(10) - c = bm.random.uniform(size=10) - f_grad = bm.grad(call, argnums=[0, 1, 2]) - grads = f_grad(a, b, c) - - for g in grads: assert (g == 1.).all() - - def test_grad_pure_func_2(self): - def call(a, b, c): return bm.sum(a + b + c) - - bm.random.seed(1) - a = bm.ones(10) - b = bm.random.randn(10) - c = bm.random.uniform(size=10) - f_grad = bm.grad(call) - assert (f_grad(a, b, c) == 1.).all() - - def test_grad_pure_func_aux1(self): - def call(a, b, c): - return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1)) - - bm.random.seed(1) - f_grad = bm.grad(call, argnums=[0, 1, 2]) - with pytest.raises(TypeError): - f_grad(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)) - - def test_grad_pure_func_aux2(self): - def call(a, b, c): - return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1)) - - bm.random.seed(1) - f_grad = bm.grad(call, argnums=[0, 1, 2], has_aux=True) - grads, aux = f_grad(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)) - for g in grads: assert (g == 1.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - - def test_grad_pure_func_return1(self): - def call(a, b, c): return bm.sum(a + b + c) - - bm.random.seed(1) - a = bm.ones(10) - b = bm.random.randn(10) - c = bm.random.uniform(size=10) - f_grad = bm.grad(call, return_value=True) - grads, returns = f_grad(a, b, c) - assert (grads == 1.).all() - assert returns == bm.sum(a + b + c) - - def test_grad_func_return_aux1(self): - def call(a, b, c): - return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1)) - - bm.random.seed(1) - a = bm.ones(10) - b = bm.random.randn(10) - c = bm.random.uniform(size=10) - f_grad = bm.grad(call, return_value=True, has_aux=True) - grads, returns, aux = f_grad(a, b, c) - assert (grads == 1.).all() - assert returns == bm.sum(a + b + c) - assert aux[0] == bm.sin(100) - 
assert aux[1] == bm.exp(0.1) - - - def test_grad_jit(self): - def call(a, b, c): return bm.sum(a + b + c) - - bm.random.seed(1) - a = bm.ones(10) - b = bm.random.randn(10) - c = bm.random.uniform(size=10) - f_grad = bm.jit(bm.grad(call)) - assert (f_grad(a, b, c) == 1.).all() + def test_grad_pure_func_1(self): + def call(a, b, c): return bm.sum(a + b + c) + + bm.random.seed(1) + a = bm.ones(10) + b = bm.random.randn(10) + c = bm.random.uniform(size=10) + f_grad = bm.grad(call, argnums=[0, 1, 2]) + grads = f_grad(a, b, c) + + for g in grads: assert (g == 1.).all() + + def test_grad_pure_func_2(self): + def call(a, b, c): return bm.sum(a + b + c) + + bm.random.seed(1) + a = bm.ones(10) + b = bm.random.randn(10) + c = bm.random.uniform(size=10) + f_grad = bm.grad(call) + assert (f_grad(a, b, c) == 1.).all() + + def test_grad_pure_func_aux1(self): + def call(a, b, c): + return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1)) + + bm.random.seed(1) + f_grad = bm.grad(call, argnums=[0, 1, 2]) + with pytest.raises(TypeError): + f_grad(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)) + + def test_grad_pure_func_aux2(self): + def call(a, b, c): + return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1)) + + bm.random.seed(1) + f_grad = bm.grad(call, argnums=[0, 1, 2], has_aux=True) + grads, aux = f_grad(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)) + for g in grads: assert (g == 1.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + def test_grad_pure_func_return1(self): + def call(a, b, c): return bm.sum(a + b + c) + + bm.random.seed(1) + a = bm.ones(10) + b = bm.random.randn(10) + c = bm.random.uniform(size=10) + f_grad = bm.grad(call, return_value=True) + grads, returns = f_grad(a, b, c) + assert (grads == 1.).all() + assert returns == bm.sum(a + b + c) + + def test_grad_func_return_aux1(self): + def call(a, b, c): + return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1)) + + bm.random.seed(1) + a = bm.ones(10) + b = 
bm.random.randn(10) + c = bm.random.uniform(size=10) + f_grad = bm.grad(call, return_value=True, has_aux=True) + grads, returns, aux = f_grad(a, b, c) + assert (grads == 1.).all() + assert returns == bm.sum(a + b + c) + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + def test_grad_jit(self): + def call(a, b, c): return bm.sum(a + b + c) + + bm.random.seed(1) + a = bm.ones(10) + b = bm.random.randn(10) + c = bm.random.uniform(size=10) + f_grad = bm.jit(bm.grad(call)) + assert (f_grad(a, b, c) == 1.).all() class TestObjectFuncGrad(unittest.TestCase): - def test_grad_ob1(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self): - return bm.sum(self.a + self.b + self.c) - - bm.random.seed(0) - - t = Test() - f_grad = bm.grad(t, grad_vars={'a': t.a, 'b': t.b, 'c': t.c}) - grads = f_grad() - for g in grads.values(): - assert (g == 1.).all() - - t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, t.b]) - grads = f_grad() - for g in grads: assert (g == 1.).all() - - t = Test() - f_grad = bm.grad(t, grad_vars=t.a) - grads = f_grad() - assert (grads == 1.).all() - - def test_grad_ob_aux(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self): - return bm.sum(self.a + self.b + self.c), (bm.sin(100), bm.exp(0.1)) - - bm.random.seed(0) - t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, t.b], has_aux=True) - grads, aux = f_grad() - for g in grads: assert (g == 1.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - - t = Test() - f_grad = bm.grad(t, grad_vars=t.a, has_aux=True) - grads, aux = f_grad() - assert (grads == 1.).all() - assert aux[0] == bm.sin(100) - assert 
aux[1] == bm.exp(0.1) - - def test_grad_ob_return(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self): - return bm.sum(self.a + self.b + self.c) - - bm.random.seed(0) - t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, t.b], return_value=True) - grads, returns = f_grad() - for g in grads: assert (g == 1.).all() - assert returns == t() - - t = Test() - f_grad = bm.grad(t, grad_vars=t.a, return_value=True) - grads, returns = f_grad() - assert (grads == 1.).all() - assert returns == t() - - def test_grad_ob_aux_return(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self): - return bm.sum(self.a + self.b + self.c), (bm.sin(100), bm.exp(0.1)) - - bm.random.seed(0) - t = Test() - f_grad = bm.grad(t, grad_vars=[t.a, t.b], - has_aux=True, return_value=True) - grads, returns, aux = f_grad() - for g in grads: assert (g == 1.).all() - assert returns == bm.sum(t.a + t.b + t.c) - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - - t = Test() - f_grad = bm.grad(t, grad_vars=t.a, - has_aux=True, return_value=True) - grads, returns, aux = f_grad() - assert (grads == 1.).all() - assert returns == bm.sum(t.a + t.b + t.c) - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - - def test_grad_ob_argnums(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - bp.math.random.seed() - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self, d): - return bm.sum(self.a + self.b + self.c + 2 * d) - - bm.random.seed(0) - - t = Test() - f_grad = bm.grad(t, 
t.vars(), argnums=0) - var_grads, arg_grads = f_grad(bm.random.random(10)) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads == 2.).all() - - t = Test() - f_grad = bm.grad(t, t.vars(), argnums=[0]) - var_grads, arg_grads = f_grad(bm.random.random(10)) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads[0] == 2.).all() - - t = Test() - f_grad = bm.grad(t, argnums=0) - arg_grads = f_grad(bm.random.random(10)) - assert (arg_grads == 2.).all() - - t = Test() - f_grad = bm.grad(t, argnums=[0]) - arg_grads = f_grad(bm.random.random(10)) - assert (arg_grads[0] == 2.).all() - - def test_grad_ob_argnums_aux(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self, d): - return bm.sum(self.a + self.b + self.c + 2 * d), (bm.sin(100), bm.exp(0.1)) - - bm.random.seed(0) - - t = Test() - f_grad = bm.grad(t, grad_vars=t.vars(), argnums=0, has_aux=True) - (var_grads, arg_grads), aux = f_grad(bm.random.random(10)) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - - t = Test() - f_grad = bm.grad(t, grad_vars=t.vars(), argnums=[0], has_aux=True) - (var_grads, arg_grads), aux = f_grad(bm.random.random(10)) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads[0] == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - - t = Test() - f_grad = bm.grad(t, argnums=0, has_aux=True) - arg_grads, aux = f_grad(bm.random.random(10)) - assert (arg_grads == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - - t = Test() - f_grad = bm.grad(t, argnums=[0], has_aux=True) - arg_grads, aux = f_grad(bm.random.random(10)) - assert (arg_grads[0] == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == 
bm.exp(0.1) - - def test_grad_ob_argnums_return(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self, d): - return bm.sum(self.a + self.b + self.c + 2 * d) - - bm.random.seed(0) - - t = Test() - f_grad = bm.grad(t, t.vars(), argnums=0, return_value=True) - d = bm.random.random(10) - (var_grads, arg_grads), loss = f_grad(d) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads == 2.).all() - assert loss == t(d) - - t = Test() - f_grad = bm.grad(t, t.vars(), argnums=[0], return_value=True) - d = bm.random.random(10) - (var_grads, arg_grads), loss = f_grad(d) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads[0] == 2.).all() - assert loss == t(d) - - t = Test() - f_grad = bm.grad(t, argnums=0, return_value=True) - d = bm.random.random(10) - arg_grads, loss = f_grad(d) - assert (arg_grads == 2.).all() - assert loss == t(d) - - t = Test() - f_grad = bm.grad(t, argnums=[0], return_value=True) - d = bm.random.random(10) - arg_grads, loss = f_grad(d) - assert (arg_grads[0] == 2.).all() - assert loss == t(d) - - def test_grad_ob_argnums_aux_return(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.a = bm.TrainVar(bm.ones(10)) - self.b = bm.TrainVar(bm.random.randn(10)) - self.c = bm.TrainVar(bm.random.uniform(size=10)) - - def __call__(self, d): - return bm.sum(self.a + self.b + self.c + 2 * d), (bm.sin(100), bm.exp(0.1)) - - bm.random.seed(0) - - t = Test() - f_grad = bm.grad(t, grad_vars=t.vars(), argnums=0, has_aux=True, return_value=True) - d = bm.random.random(10) - (var_grads, arg_grads), loss, aux = f_grad(d) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - assert loss == t(d)[0] - - t = 
Test() - f_grad = bm.grad(t, grad_vars=t.vars(), argnums=[0], has_aux=True, return_value=True) - d = bm.random.random(10) - (var_grads, arg_grads), loss, aux = f_grad(d) - for g in var_grads.values(): assert (g == 1.).all() - assert (arg_grads[0] == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - assert loss == t(d)[0] - - t = Test() - f_grad = bm.grad(t, argnums=0, has_aux=True, return_value=True) - d = bm.random.random(10) - arg_grads, loss, aux = f_grad(d) - assert (arg_grads == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - assert loss == t(d)[0] - - t = Test() - f_grad = bm.grad(t, argnums=[0], has_aux=True, return_value=True) - d = bm.random.random(10) - arg_grads, loss, aux = f_grad(d) - assert (arg_grads[0] == 2.).all() - assert aux[0] == bm.sin(100) - assert aux[1] == bm.exp(0.1) - assert loss == t(d)[0] + def test_grad_ob1(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self): + return bm.sum(self.a + self.b + self.c) + + bm.random.seed(0) + + t = Test() + f_grad = bm.grad(t, grad_vars={'a': t.a, 'b': t.b, 'c': t.c}) + grads = f_grad() + for g in grads.values(): + assert (g == 1.).all() + + t = Test() + f_grad = bm.grad(t, grad_vars=[t.a, t.b]) + grads = f_grad() + for g in grads: assert (g == 1.).all() + + t = Test() + f_grad = bm.grad(t, grad_vars=t.a) + grads = f_grad() + assert (grads == 1.).all() + + def test_grad_ob_aux(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self): + return bm.sum(self.a + self.b + self.c), (bm.sin(100), bm.exp(0.1)) + + bm.random.seed(0) + t = Test() + f_grad = bm.grad(t, grad_vars=[t.a, 
t.b], has_aux=True) + grads, aux = f_grad() + for g in grads: assert (g == 1.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + t = Test() + f_grad = bm.grad(t, grad_vars=t.a, has_aux=True) + grads, aux = f_grad() + assert (grads == 1.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + def test_grad_ob_return(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self): + return bm.sum(self.a + self.b + self.c) + + bm.random.seed(0) + t = Test() + f_grad = bm.grad(t, grad_vars=[t.a, t.b], return_value=True) + grads, returns = f_grad() + for g in grads: assert (g == 1.).all() + assert returns == t() + + t = Test() + f_grad = bm.grad(t, grad_vars=t.a, return_value=True) + grads, returns = f_grad() + assert (grads == 1.).all() + assert returns == t() + + def test_grad_ob_aux_return(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self): + return bm.sum(self.a + self.b + self.c), (bm.sin(100), bm.exp(0.1)) + + bm.random.seed(0) + t = Test() + f_grad = bm.grad(t, grad_vars=[t.a, t.b], + has_aux=True, return_value=True) + grads, returns, aux = f_grad() + for g in grads: assert (g == 1.).all() + assert returns == bm.sum(t.a + t.b + t.c) + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + t = Test() + f_grad = bm.grad(t, grad_vars=t.a, + has_aux=True, return_value=True) + grads, returns, aux = f_grad() + assert (grads == 1.).all() + assert returns == bm.sum(t.a + t.b + t.c) + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + def test_grad_ob_argnums(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, 
self).__init__() + bp.math.random.seed() + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self, d): + return bm.sum(self.a + self.b + self.c + 2 * d) + + bm.random.seed(0) + + t = Test() + f_grad = bm.grad(t, t.vars(), argnums=0) + var_grads, arg_grads = f_grad(bm.random.random(10)) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads == 2.).all() + + t = Test() + f_grad = bm.grad(t, t.vars(), argnums=[0]) + var_grads, arg_grads = f_grad(bm.random.random(10)) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads[0] == 2.).all() + + t = Test() + f_grad = bm.grad(t, argnums=0) + arg_grads = f_grad(bm.random.random(10)) + assert (arg_grads == 2.).all() + + t = Test() + f_grad = bm.grad(t, argnums=[0]) + arg_grads = f_grad(bm.random.random(10)) + assert (arg_grads[0] == 2.).all() + + def test_grad_ob_argnums_aux(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self, d): + return bm.sum(self.a + self.b + self.c + 2 * d), (bm.sin(100), bm.exp(0.1)) + + bm.random.seed(0) + + t = Test() + f_grad = bm.grad(t, grad_vars=t.vars(), argnums=0, has_aux=True) + (var_grads, arg_grads), aux = f_grad(bm.random.random(10)) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + t = Test() + f_grad = bm.grad(t, grad_vars=t.vars(), argnums=[0], has_aux=True) + (var_grads, arg_grads), aux = f_grad(bm.random.random(10)) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads[0] == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + t = Test() + f_grad = bm.grad(t, argnums=0, has_aux=True) + arg_grads, aux = 
f_grad(bm.random.random(10)) + assert (arg_grads == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + t = Test() + f_grad = bm.grad(t, argnums=[0], has_aux=True) + arg_grads, aux = f_grad(bm.random.random(10)) + assert (arg_grads[0] == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + + def test_grad_ob_argnums_return(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self, d): + return bm.sum(self.a + self.b + self.c + 2 * d) + + bm.random.seed(0) + + t = Test() + f_grad = bm.grad(t, t.vars(), argnums=0, return_value=True) + d = bm.random.random(10) + (var_grads, arg_grads), loss = f_grad(d) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads == 2.).all() + assert loss == t(d) + + t = Test() + f_grad = bm.grad(t, t.vars(), argnums=[0], return_value=True) + d = bm.random.random(10) + (var_grads, arg_grads), loss = f_grad(d) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads[0] == 2.).all() + assert loss == t(d) + + t = Test() + f_grad = bm.grad(t, argnums=0, return_value=True) + d = bm.random.random(10) + arg_grads, loss = f_grad(d) + assert (arg_grads == 2.).all() + assert loss == t(d) + + t = Test() + f_grad = bm.grad(t, argnums=[0], return_value=True) + d = bm.random.random(10) + arg_grads, loss = f_grad(d) + assert (arg_grads[0] == 2.).all() + assert loss == t(d) + + def test_grad_ob_argnums_aux_return(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.a = bm.TrainVar(bm.ones(10)) + self.b = bm.TrainVar(bm.random.randn(10)) + self.c = bm.TrainVar(bm.random.uniform(size=10)) + + def __call__(self, d): + return bm.sum(self.a + self.b + self.c + 2 * d), (bm.sin(100), bm.exp(0.1)) + + bm.random.seed(0) + + t = Test() + f_grad = 
bm.grad(t, grad_vars=t.vars(), argnums=0, has_aux=True, return_value=True) + d = bm.random.random(10) + (var_grads, arg_grads), loss, aux = f_grad(d) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + assert loss == t(d)[0] + + t = Test() + f_grad = bm.grad(t, grad_vars=t.vars(), argnums=[0], has_aux=True, return_value=True) + d = bm.random.random(10) + (var_grads, arg_grads), loss, aux = f_grad(d) + for g in var_grads.values(): assert (g == 1.).all() + assert (arg_grads[0] == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + assert loss == t(d)[0] + + t = Test() + f_grad = bm.grad(t, argnums=0, has_aux=True, return_value=True) + d = bm.random.random(10) + arg_grads, loss, aux = f_grad(d) + assert (arg_grads == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + assert loss == t(d)[0] + + t = Test() + f_grad = bm.grad(t, argnums=[0], has_aux=True, return_value=True) + d = bm.random.random(10) + arg_grads, loss, aux = f_grad(d) + assert (arg_grads[0] == 2.).all() + assert aux[0] == bm.sin(100) + assert aux[1] == bm.exp(0.1) + assert loss == t(d)[0] # class TestPureFuncJacobian(unittest.TestCase): @@ -566,522 +563,521 @@ def __call__(self, d): class TestClassFuncJacobian(unittest.TestCase): - def test_jacrev1(self): - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - self.y = bm.Variable(bm.array([10., 5.])) - - def __call__(self, ): - a = self.x[0] * self.y[0] - b = 5 * self.x[2] * self.y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r - - _jr = jax.jacrev(f1)(_x, _y) - t = Test() - br 
= bm.jacrev(t, grad_vars=t.x)() - self.assertTrue((br == _jr).all()) - - _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) - t = Test() - br = bm.jacrev(t, grad_vars=[t.x, t.y])() - self.assertTrue((br[0] == _jr[0]).all()) - self.assertTrue((br[1] == _jr[1]).all()) - - def test_jacfwd1(self): - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - self.y = bm.Variable(bm.array([10., 5.])) - - def __call__(self, ): - a = self.x[0] * self.y[0] - b = 5 * self.x[2] * self.y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r - - _jr = jax.jacfwd(f1)(_x, _y) - t = Test() - br = bm.jacfwd(t, grad_vars=t.x)() - self.assertTrue((br == _jr).all()) - - _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) - t = Test() - br = bm.jacfwd(t, grad_vars=[t.x, t.y])() - self.assertTrue((br[0] == _jr[0]).all()) - self.assertTrue((br[1] == _jr[1]).all()) - - def test_jacrev2(self): - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - - def __call__(self, y): - a = self.x[0] * y[0] - b = 5 * self.x[2] * y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r - - _jr = jax.jacrev(f1)(_x, _y) - t = Test() - br = bm.jacrev(t, grad_vars=t.x)(_y) - self.assertTrue((br == _jr).all()) - - _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) - t = Test() - var_grads, arg_grads = bm.jacrev(t, grad_vars=t.x, argnums=0)(_y) - print(var_grads, ) - 
print(arg_grads, ) - self.assertTrue((var_grads == _jr[0]).all()) - self.assertTrue((arg_grads == _jr[1]).all()) - - def test_jacfwd2(self): - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - - def __call__(self, y): - a = self.x[0] * y[0] - b = 5 * self.x[2] * y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r - - _jr = jax.jacfwd(f1)(_x, _y) - t = Test() - br = bm.jacfwd(t, grad_vars=t.x)(_y) - self.assertTrue((br == _jr).all()) - - _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) - t = Test() - var_grads, arg_grads = bm.jacfwd(t, grad_vars=t.x, argnums=0)(_y) - print(var_grads, ) - print(arg_grads, ) - self.assertTrue((var_grads == _jr[0]).all()) - self.assertTrue((arg_grads == _jr[1]).all()) - - def test_jacrev_aux1(self): - bm.enable_x64() - - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - - def __call__(self, y): - a = self.x[0] * y[0] - b = 5 * self.x[2] * y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r, (c, d) - - _jr = jax.jacrev(f1)(_x, _y) - t = Test() - br, _ = bm.jacrev(t, grad_vars=t.x, has_aux=True)(_y) - self.assertTrue((br == _jr).all()) - - t = Test() - _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) - _aux = t(_y)[1] - (var_grads, arg_grads), aux = bm.jacrev(t, grad_vars=t.x, argnums=0, has_aux=True)(_y) - print(var_grads, ) - print(arg_grads, ) - self.assertTrue((var_grads 
== _jr[0]).all()) - self.assertTrue((arg_grads == _jr[1]).all()) - self.assertTrue(bm.array_equal(aux, _aux)) - - bm.disable_x64() - - def test_jacfwd_aux1(self): - bm.enable_x64() - - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - - def __call__(self, y): - a = self.x[0] * y[0] - b = 5 * self.x[2] * y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r, (c, d) - - _jr = jax.jacfwd(f1)(_x, _y) - t = Test() - br, (c, d) = bm.jacfwd(t, grad_vars=t.x, has_aux=True)(_y) - # print(_jr) - # print(br) - a = (br == _jr) - self.assertTrue(a.all()) - - t = Test() - _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) - _aux = t(_y)[1] - (var_grads, arg_grads), aux = bm.jacfwd(t, grad_vars=t.x, argnums=0, has_aux=True)(_y) - print(var_grads, ) - print(arg_grads, ) - self.assertTrue((var_grads == _jr[0]).all()) - self.assertTrue((arg_grads == _jr[1]).all()) - self.assertTrue(bm.array_equal(aux, _aux)) - - bm.disable_x64() - - def test_jacrev_return_aux1(self): - bm.enable_x64() - - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - - def __call__(self, y): - a = self.x[0] * y[0] - b = 5 * self.x[2] * y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r, (c, d) - - _jr = jax.jacrev(f1)(_x, _y) - t = Test() - br, _ = bm.jacrev(t, grad_vars=t.x, has_aux=True)(_y) - self.assertTrue((br == _jr).all()) - - t = Test() - 
_jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) - _val, _aux = t(_y) - (var_grads, arg_grads), value, aux = bm.jacrev(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)(_y) - print(var_grads, ) - print(arg_grads, ) - self.assertTrue((var_grads == _jr[0]).all()) - self.assertTrue((arg_grads == _jr[1]).all()) - self.assertTrue(bm.array_equal(aux, _aux)) - self.assertTrue(bm.array_equal(value, _val)) - - bm.disable_x64() - - def test_jacfwd_return_aux1(self): - bm.enable_x64() - - def f1(x, y): - r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) - return r - - _x = bm.array([1., 2., 3.]) - _y = bm.array([10., 5.]) - - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.array([1., 2., 3.])) - - def __call__(self, y): - a = self.x[0] * y[0] - b = 5 * self.x[2] * y[1] - c = 4 * self.x[1] ** 2 - 2 * self.x[2] - d = self.x[2] * jnp.sin(self.x[0]) - r = jnp.asarray([a, b, c, d]) - return r, (c, d) - - _jr = jax.jacfwd(f1)(_x, _y) - t = Test() - br, _ = bm.jacfwd(t, grad_vars=t.x, has_aux=True)(_y) - self.assertTrue((br == _jr).all()) - - t = Test() - _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) - _val, _aux = t(_y) - (var_grads, arg_grads), value, aux = bm.jacfwd(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)(_y) - print(_val, ) - print('_aux: ', _aux, 'aux: ', aux) - print(var_grads, ) - print(arg_grads, ) - self.assertTrue((var_grads == _jr[0]).all()) - self.assertTrue((arg_grads == _jr[1]).all()) - self.assertTrue(bm.array_equal(aux, _aux)) - self.assertTrue(bm.array_equal(value, _val)) - - bm.disable_x64() + def test_jacrev1(self): + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + 
self.y = bm.Variable(bm.array([10., 5.])) + + def __call__(self, ): + a = self.x[0] * self.y[0] + b = 5 * self.x[2] * self.y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, d]) + return r + + _jr = jax.jacrev(f1)(_x, _y) + t = Test() + br = bm.jacrev(t, grad_vars=t.x)() + self.assertTrue((br == _jr).all()) + + _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) + t = Test() + br = bm.jacrev(t, grad_vars=[t.x, t.y])() + self.assertTrue((br[0] == _jr[0]).all()) + self.assertTrue((br[1] == _jr[1]).all()) + + def test_jacfwd1(self): + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + self.y = bm.Variable(bm.array([10., 5.])) + + def __call__(self, ): + a = self.x[0] * self.y[0] + b = 5 * self.x[2] * self.y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, d]) + return r + + _jr = jax.jacfwd(f1)(_x, _y) + t = Test() + br = bm.jacfwd(t, grad_vars=t.x)() + self.assertTrue((br == _jr).all()) + + _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) + t = Test() + br = bm.jacfwd(t, grad_vars=[t.x, t.y])() + self.assertTrue((br[0] == _jr[0]).all()) + self.assertTrue((br[1] == _jr[1]).all()) + + def test_jacrev2(self): + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + + def __call__(self, y): + a = self.x[0] * y[0] + b = 5 * self.x[2] * y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, 
d]) + return r + + _jr = jax.jacrev(f1)(_x, _y) + t = Test() + br = bm.jacrev(t, grad_vars=t.x)(_y) + self.assertTrue((br == _jr).all()) + + _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) + t = Test() + var_grads, arg_grads = bm.jacrev(t, grad_vars=t.x, argnums=0)(_y) + print(var_grads, ) + print(arg_grads, ) + self.assertTrue((var_grads == _jr[0]).all()) + self.assertTrue((arg_grads == _jr[1]).all()) + + def test_jacfwd2(self): + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + + def __call__(self, y): + a = self.x[0] * y[0] + b = 5 * self.x[2] * y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, d]) + return r + + _jr = jax.jacfwd(f1)(_x, _y) + t = Test() + br = bm.jacfwd(t, grad_vars=t.x)(_y) + self.assertTrue((br == _jr).all()) + + _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) + t = Test() + var_grads, arg_grads = bm.jacfwd(t, grad_vars=t.x, argnums=0)(_y) + print(var_grads, ) + print(arg_grads, ) + self.assertTrue((var_grads == _jr[0]).all()) + self.assertTrue((arg_grads == _jr[1]).all()) + + def test_jacrev_aux1(self): + bm.enable_x64() + + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + + def __call__(self, y): + a = self.x[0] * y[0] + b = 5 * self.x[2] * y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, d]) + return r, (c, d) + + _jr = jax.jacrev(f1)(_x, _y) + t = Test() + br, _ = bm.jacrev(t, grad_vars=t.x, 
has_aux=True)(_y) + self.assertTrue((br == _jr).all()) + + t = Test() + _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) + _aux = t(_y)[1] + (var_grads, arg_grads), aux = bm.jacrev(t, grad_vars=t.x, argnums=0, has_aux=True)(_y) + print(var_grads, ) + print(arg_grads, ) + self.assertTrue((var_grads == _jr[0]).all()) + self.assertTrue((arg_grads == _jr[1]).all()) + self.assertTrue(bm.array_equal(aux, _aux)) + + bm.disable_x64() + + def test_jacfwd_aux1(self): + bm.enable_x64() + + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + + def __call__(self, y): + a = self.x[0] * y[0] + b = 5 * self.x[2] * y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, d]) + return r, (c, d) + + _jr = jax.jacfwd(f1)(_x, _y) + t = Test() + br, (c, d) = bm.jacfwd(t, grad_vars=t.x, has_aux=True)(_y) + # print(_jr) + # print(br) + a = (br == _jr) + self.assertTrue(a.all()) + + t = Test() + _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) + _aux = t(_y)[1] + (var_grads, arg_grads), aux = bm.jacfwd(t, grad_vars=t.x, argnums=0, has_aux=True)(_y) + print(var_grads, ) + print(arg_grads, ) + self.assertTrue((var_grads == _jr[0]).all()) + self.assertTrue((arg_grads == _jr[1]).all()) + self.assertTrue(bm.array_equal(aux, _aux)) + + bm.disable_x64() + + def test_jacrev_return_aux1(self): + bm.enable_x64() + + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + + def __call__(self, y): + a = self.x[0] * y[0] + b = 5 * 
self.x[2] * y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, d]) + return r, (c, d) + + _jr = jax.jacrev(f1)(_x, _y) + t = Test() + br, _ = bm.jacrev(t, grad_vars=t.x, has_aux=True)(_y) + self.assertTrue((br == _jr).all()) + + t = Test() + _jr = jax.jacrev(f1, argnums=(0, 1))(_x, _y) + _val, _aux = t(_y) + (var_grads, arg_grads), value, aux = bm.jacrev(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)(_y) + print(var_grads, ) + print(arg_grads, ) + self.assertTrue((var_grads == _jr[0]).all()) + self.assertTrue((arg_grads == _jr[1]).all()) + self.assertTrue(bm.array_equal(aux, _aux)) + self.assertTrue(bm.array_equal(value, _val)) + + bm.disable_x64() + + def test_jacfwd_return_aux1(self): + bm.enable_x64() + + def f1(x, y): + r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])]) + return r + + _x = bm.array([1., 2., 3.]) + _y = bm.array([10., 5.]) + + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.array([1., 2., 3.])) + + def __call__(self, y): + a = self.x[0] * y[0] + b = 5 * self.x[2] * y[1] + c = 4 * self.x[1] ** 2 - 2 * self.x[2] + d = self.x[2] * jnp.sin(self.x[0]) + r = jnp.asarray([a, b, c, d]) + return r, (c, d) + + _jr = jax.jacfwd(f1)(_x, _y) + t = Test() + br, _ = bm.jacfwd(t, grad_vars=t.x, has_aux=True)(_y) + self.assertTrue((br == _jr).all()) + + t = Test() + _jr = jax.jacfwd(f1, argnums=(0, 1))(_x, _y) + _val, _aux = t(_y) + (var_grads, arg_grads), value, aux = bm.jacfwd(t, grad_vars=t.x, argnums=0, has_aux=True, return_value=True)(_y) + print(_val, ) + print('_aux: ', _aux, 'aux: ', aux) + print(var_grads, ) + print(arg_grads, ) + self.assertTrue((var_grads == _jr[0]).all()) + self.assertTrue((arg_grads == _jr[1]).all()) + self.assertTrue(bm.array_equal(aux, _aux)) + self.assertTrue(bm.array_equal(value, _val)) + + bm.disable_x64() class 
TestPureFuncVectorGrad(unittest.TestCase): - def test1(self): - f = lambda x: 3 * x ** 2 - _x = bm.ones(10) - pprint(bm.vector_grad(f, argnums=0)(_x)) - - def test2(self): - def f(x, y): - dx = x ** 2 + y ** 2 + 10 - return dx - - _x = bm.ones(5) - _y = bm.ones(5) - - g = bm.vector_grad(f, argnums=0)(_x, _y) - pprint(g) - self.assertTrue(bm.array_equal(g, 2 * _x)) - - g = bm.vector_grad(f, argnums=(0,))(_x, _y) - self.assertTrue(bm.array_equal(g[0], 2 * _x)) - - g = bm.vector_grad(f, argnums=(0, 1))(_x, _y) - pprint(g) - self.assertTrue(bm.array_equal(g[0], 2 * _x)) - self.assertTrue(bm.array_equal(g[1], 2 * _y)) - - def test3(self): - def f(x, y): - dx = x ** 2 + y ** 2 + 10 - dy = x ** 3 + y ** 3 - 10 - return dx, dy - - _x = bm.ones(5) - _y = bm.ones(5) - - g = bm.vector_grad(f, argnums=0)(_x, _y) - # pprint(g) - self.assertTrue(bm.array_equal(g, 2 * _x + 3 * _x ** 2)) - - g = bm.vector_grad(f, argnums=(0,))(_x, _y) - self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2)) - - g = bm.vector_grad(f, argnums=(0, 1))(_x, _y) - # pprint(g) - self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2)) - self.assertTrue(bm.array_equal(g[1], 2 * _y + 3 * _y ** 2)) - - def test4_2d(self): - def f(x, y): - dx = x ** 2 + y ** 2 + 10 - return dx - - _x = bm.ones((5, 5)) - _y = bm.ones((5, 5)) - - g = bm.vector_grad(f, argnums=0)(_x, _y) - pprint(g) - self.assertTrue(bm.array_equal(g, 2 * _x)) - - g = bm.vector_grad(f, argnums=(0,))(_x, _y) - self.assertTrue(bm.array_equal(g[0], 2 * _x)) - - g = bm.vector_grad(f, argnums=(0, 1))(_x, _y) - pprint(g) - self.assertTrue(bm.array_equal(g[0], 2 * _x)) - self.assertTrue(bm.array_equal(g[1], 2 * _y)) - - def test_aux1(self): - def f(x, y): - dx = x ** 2 + y ** 2 + 10 - dy = x ** 3 + y ** 3 - 10 - return dx, dy - - _x = bm.ones(5) - _y = bm.ones(5) - - g, aux = bm.vector_grad(f, has_aux=True)(_x, _y) - pprint(g, ) - pprint(aux) - self.assertTrue(bm.array_equal(g, 2 * _x)) - self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 
10)) - - def test_return1(self): - def f(x, y): - dx = x ** 2 + y ** 2 + 10 - return dx - - _x = bm.ones(5) - _y = bm.ones(5) - - g, value = bm.vector_grad(f, return_value=True)(_x, _y) - pprint(g, ) - pprint(value) - self.assertTrue(bm.array_equal(g, 2 * _x)) - self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10)) - - def test_return_aux1(self): - def f(x, y): - dx = x ** 2 + y ** 2 + 10 - dy = x ** 3 + y ** 3 - 10 - return dx, dy - - _x = bm.ones(5) - _y = bm.ones(5) - - g, value, aux = bm.vector_grad(f, has_aux=True, return_value=True)(_x, _y) - print('grad', g) - print('value', value) - print('aux', aux) - self.assertTrue(bm.array_equal(g, 2 * _x)) - self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10)) - self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 10)) + def test1(self): + f = lambda x: 3 * x ** 2 + _x = bm.ones(10) + pprint(bm.vector_grad(f, argnums=0)(_x)) + + def test2(self): + def f(x, y): + dx = x ** 2 + y ** 2 + 10 + return dx + + _x = bm.ones(5) + _y = bm.ones(5) + + g = bm.vector_grad(f, argnums=0)(_x, _y) + pprint(g) + self.assertTrue(bm.array_equal(g, 2 * _x)) + + g = bm.vector_grad(f, argnums=(0,))(_x, _y) + self.assertTrue(bm.array_equal(g[0], 2 * _x)) + + g = bm.vector_grad(f, argnums=(0, 1))(_x, _y) + pprint(g) + self.assertTrue(bm.array_equal(g[0], 2 * _x)) + self.assertTrue(bm.array_equal(g[1], 2 * _y)) + + def test3(self): + def f(x, y): + dx = x ** 2 + y ** 2 + 10 + dy = x ** 3 + y ** 3 - 10 + return dx, dy + + _x = bm.ones(5) + _y = bm.ones(5) + + g = bm.vector_grad(f, argnums=0)(_x, _y) + # pprint(g) + self.assertTrue(bm.array_equal(g, 2 * _x + 3 * _x ** 2)) + + g = bm.vector_grad(f, argnums=(0,))(_x, _y) + self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2)) + + g = bm.vector_grad(f, argnums=(0, 1))(_x, _y) + # pprint(g) + self.assertTrue(bm.array_equal(g[0], 2 * _x + 3 * _x ** 2)) + self.assertTrue(bm.array_equal(g[1], 2 * _y + 3 * _y ** 2)) + + def test4_2d(self): + def f(x, y): + dx = x ** 2 + y ** 2 
+ 10 + return dx + + _x = bm.ones((5, 5)) + _y = bm.ones((5, 5)) + + g = bm.vector_grad(f, argnums=0)(_x, _y) + pprint(g) + self.assertTrue(bm.array_equal(g, 2 * _x)) + + g = bm.vector_grad(f, argnums=(0,))(_x, _y) + self.assertTrue(bm.array_equal(g[0], 2 * _x)) + + g = bm.vector_grad(f, argnums=(0, 1))(_x, _y) + pprint(g) + self.assertTrue(bm.array_equal(g[0], 2 * _x)) + self.assertTrue(bm.array_equal(g[1], 2 * _y)) + + def test_aux1(self): + def f(x, y): + dx = x ** 2 + y ** 2 + 10 + dy = x ** 3 + y ** 3 - 10 + return dx, dy + + _x = bm.ones(5) + _y = bm.ones(5) + + g, aux = bm.vector_grad(f, has_aux=True)(_x, _y) + pprint(g, ) + pprint(aux) + self.assertTrue(bm.array_equal(g, 2 * _x)) + self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 10)) + + def test_return1(self): + def f(x, y): + dx = x ** 2 + y ** 2 + 10 + return dx + + _x = bm.ones(5) + _y = bm.ones(5) + + g, value = bm.vector_grad(f, return_value=True)(_x, _y) + pprint(g, ) + pprint(value) + self.assertTrue(bm.array_equal(g, 2 * _x)) + self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10)) + + def test_return_aux1(self): + def f(x, y): + dx = x ** 2 + y ** 2 + 10 + dy = x ** 3 + y ** 3 - 10 + return dx, dy + + _x = bm.ones(5) + _y = bm.ones(5) + + g, value, aux = bm.vector_grad(f, has_aux=True, return_value=True)(_x, _y) + print('grad', g) + print('value', value) + print('aux', aux) + self.assertTrue(bm.array_equal(g, 2 * _x)) + self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10)) + self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 10)) class TestClassFuncVectorGrad(unittest.TestCase): - def test1(self): - class Test(bp.BrainPyObject): - def __init__(self): - super(Test, self).__init__() - self.x = bm.Variable(bm.ones(5)) - self.y = bm.Variable(bm.ones(5)) + def test1(self): + class Test(bp.BrainPyObject): + def __init__(self): + super(Test, self).__init__() + self.x = bm.Variable(bm.ones(5)) + self.y = bm.Variable(bm.ones(5)) - def __call__(self, *args, **kwargs): - return 
self.x ** 2 + self.y ** 2 + 10 + def __call__(self, *args, **kwargs): + return self.x ** 2 + self.y ** 2 + 10 - t = Test() + t = Test() - g = bm.vector_grad(t, grad_vars=t.x)() - self.assertTrue(bm.array_equal(g, 2 * t.x)) + g = bm.vector_grad(t, grad_vars=t.x)() + self.assertTrue(bm.array_equal(g, 2 * t.x)) - g = bm.vector_grad(t, grad_vars=(t.x,))() - self.assertTrue(bm.array_equal(g[0], 2 * t.x)) + g = bm.vector_grad(t, grad_vars=(t.x,))() + self.assertTrue(bm.array_equal(g[0], 2 * t.x)) - g = bm.vector_grad(t, grad_vars=(t.x, t.y))() - self.assertTrue(bm.array_equal(g[0], 2 * t.x)) - self.assertTrue(bm.array_equal(g[1], 2 * t.y)) + g = bm.vector_grad(t, grad_vars=(t.x, t.y))() + self.assertTrue(bm.array_equal(g[0], 2 * t.x)) + self.assertTrue(bm.array_equal(g[1], 2 * t.y)) def vgrad(f, *x): - y, vjp_fn = jax.vjp(f, *x) - return vjp_fn(bm.ones(y.shape).value)[0] + y, vjp_fn = jax.vjp(f, *x) + return vjp_fn(bm.ones(y.shape).value)[0] class TestDebug(parameterized.TestCase): - def test_debug1(self): - a = bm.random.RandomState() - - def f(b): - print(a.value) - return a + b + a.random() - - f = bm.vector_grad(f, argnums=0) - f(1.) - - with jax.disable_jit(): - f(1.) - - def test_debug_correctness1(self): - def test_f(): - a = bm.Variable(bm.ones(2)) - b = bm.Variable(bm.zeros(2)) - - def f1(c): - a.value += 1 - b.value += 10 - return a * b * c - - return a, b, bm.vector_grad(f1, argnums=0)(1.) 
- - r1 = test_f() - print(r1) - - with jax.disable_jit(): - r2 = test_f() - print(r2) - self.assertTrue(bm.allclose(r1[0], r2[0])) - self.assertTrue(bm.allclose(r1[1], r2[1])) - self.assertTrue(bm.allclose(r1[2], r2[2])) - - def f1(c, a, b): - a += 1 - b += 10 - return a * b * c - - r3 = vgrad(f1, 1., bm.ones(2).value, bm.zeros(2).value) - self.assertTrue(bm.allclose(r1[2], r3)) - - def _bench_f2(self, dd): - a = bm.Variable(bm.ones(2)) - b = bm.Variable(bm.zeros(2)) - - @bm.jit - def run_fun(d): - def f1(c): - a.value += d - b.value += 10 - return a * b * c - - return a, b, bm.vector_grad(f1, argnums=0)(1.) - - return run_fun(dd) - - # def test_debug_correctness2(self): - # r1 = self._bench_f2(1.) - # print(r1) - # - # with jax.disable_jit(): - # r2 = self._bench_f2(1.) - # print(r2) - # - # self.assertTrue(bm.allclose(r1[0], r2[0])) - # self.assertTrue(bm.allclose(r1[1], r2[1])) - # self.assertTrue(bm.allclose(r1[2], r2[2])) - + def test_debug1(self): + a = bm.random.RandomState() + + def f(b): + print(a.value) + return a + b + a.random() + + f = bm.vector_grad(f, argnums=0) + f(1.) + + with jax.disable_jit(): + f(1.) + + def test_debug_correctness1(self): + def test_f(): + a = bm.Variable(bm.ones(2)) + b = bm.Variable(bm.zeros(2)) + + def f1(c): + a.value += 1 + b.value += 10 + return a * b * c + + return a, b, bm.vector_grad(f1, argnums=0)(1.) + + r1 = test_f() + print(r1) + + with jax.disable_jit(): + r2 = test_f() + print(r2) + self.assertTrue(bm.allclose(r1[0], r2[0])) + self.assertTrue(bm.allclose(r1[1], r2[1])) + self.assertTrue(bm.allclose(r1[2], r2[2])) + + def f1(c, a, b): + a += 1 + b += 10 + return a * b * c + + r3 = vgrad(f1, 1., bm.ones(2).value, bm.zeros(2).value) + self.assertTrue(bm.allclose(r1[2], r3)) + + def _bench_f2(self, dd): + a = bm.Variable(bm.ones(2)) + b = bm.Variable(bm.zeros(2)) + + @bm.jit + def run_fun(d): + def f1(c): + a.value += d + b.value += 10 + return a * b * c + + return a, b, bm.vector_grad(f1, argnums=0)(1.) 
+ + return run_fun(dd) + + # def test_debug_correctness2(self): + # r1 = self._bench_f2(1.) + # print(r1) + # + # with jax.disable_jit(): + # r2 = self._bench_f2(1.) + # print(r2) + # + # self.assertTrue(bm.allclose(r1[0], r2[0])) + # self.assertTrue(bm.allclose(r1[1], r2[1])) + # self.assertTrue(bm.allclose(r1[2], r2[2])) # class TestHessian(unittest.TestCase): # def test_hessian5(self): @@ -1130,5 +1126,3 @@ def f1(c): # self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4)) # # bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index ce4d03c84..877c78491 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -9,293 +9,290 @@ class TestCollectionFunction(unittest.TestCase): - def test_f_nodes(self): - class C(bp.DynamicalSystem): - def __init__(self): - super(C, self).__init__() - - class B(bp.DynamicalSystem): - def __init__(self): - super(B, self).__init__() - - self.child1 = C() - self.child2 = C() - - class A(bp.DynamicalSystem): - def __init__(self): - super(A, self).__init__() - - self.child1 = B() - self.child2 = B() - - net = bp.Network(a1=A(), a2=A()) - print(net.nodes(level=2)) - self.assertTrue(len(net.nodes(level=0)) == 1) - self.assertTrue(len(net.nodes(level=0, include_self=False)) == 0) - self.assertTrue(len(net.nodes(level=1)) == (1 + 2)) - self.assertTrue(len(net.nodes(level=1, include_self=False)) == 2) - self.assertTrue(len(net.nodes(level=2)) == (1 + 2 + 4)) - self.assertTrue(len(net.nodes(level=2, include_self=False)) == (2 + 4)) - self.assertTrue(len(net.nodes(level=3)) == (1 + 2 + 4 + 8)) - self.assertTrue(len(net.nodes(level=3, include_self=False)) == (2 + 4 + 8)) - - def test_f_vars(self): - class C(bp.DynamicalSystem): - def __init__(self): - super(C, self).__init__() - - self.var1 = bm.Variable(bm.zeros(1)) - self.var2 = bm.Variable(bm.zeros(1)) - - class 
B(bp.DynamicalSystem): - def __init__(self): - super(B, self).__init__() - - self.child1 = C() - self.child2 = C() - - self.var1 = bm.Variable(bm.zeros(1)) - self.var2 = bm.Variable(bm.zeros(1)) - - class A(bp.DynamicalSystem): - def __init__(self): - super(A, self).__init__() - - self.child1 = B() - self.child2 = B() - - self.var1 = bm.Variable(bm.zeros(1)) - self.var2 = bm.Variable(bm.zeros(1)) - - net = bp.Network(a1=A(), a2=A()) - print(net.vars(level=2)) - self.assertTrue(len(net.vars(level=0)) == 0) - self.assertTrue(len(net.vars(level=0, include_self=False)) == 0) - self.assertTrue(len(net.vars(level=1)) == 2 * 2) - self.assertTrue(len(net.vars(level=1, include_self=False)) == 2 * 2) - self.assertTrue(len(net.vars(level=2)) == (2 + 4) * 2) - self.assertTrue(len(net.vars(level=2, include_self=False)) == (2 + 4) * 2) - self.assertTrue(len(net.vars(level=3)) == (2 + 4 + 8) * 2) - self.assertTrue(len(net.vars(level=3, include_self=False)) == (2 + 4 + 8) * 2) + def test_f_nodes(self): + class C(bp.DynamicalSystem): + def __init__(self): + super(C, self).__init__() + + class B(bp.DynamicalSystem): + def __init__(self): + super(B, self).__init__() + + self.child1 = C() + self.child2 = C() + + class A(bp.DynamicalSystem): + def __init__(self): + super(A, self).__init__() + + self.child1 = B() + self.child2 = B() + + net = bp.Network(a1=A(), a2=A()) + print(net.nodes(level=2)) + self.assertTrue(len(net.nodes(level=0)) == 1) + self.assertTrue(len(net.nodes(level=0, include_self=False)) == 0) + self.assertTrue(len(net.nodes(level=1)) == (1 + 2)) + self.assertTrue(len(net.nodes(level=1, include_self=False)) == 2) + self.assertTrue(len(net.nodes(level=2)) == (1 + 2 + 4)) + self.assertTrue(len(net.nodes(level=2, include_self=False)) == (2 + 4)) + self.assertTrue(len(net.nodes(level=3)) == (1 + 2 + 4 + 8)) + self.assertTrue(len(net.nodes(level=3, include_self=False)) == (2 + 4 + 8)) + + def test_f_vars(self): + class C(bp.DynamicalSystem): + def __init__(self): + super(C, 
self).__init__() + + self.var1 = bm.Variable(bm.zeros(1)) + self.var2 = bm.Variable(bm.zeros(1)) + + class B(bp.DynamicalSystem): + def __init__(self): + super(B, self).__init__() + + self.child1 = C() + self.child2 = C() + + self.var1 = bm.Variable(bm.zeros(1)) + self.var2 = bm.Variable(bm.zeros(1)) + + class A(bp.DynamicalSystem): + def __init__(self): + super(A, self).__init__() + + self.child1 = B() + self.child2 = B() + + self.var1 = bm.Variable(bm.zeros(1)) + self.var2 = bm.Variable(bm.zeros(1)) + + net = bp.Network(a1=A(), a2=A()) + print(net.vars(level=2)) + self.assertTrue(len(net.vars(level=0)) == 0) + self.assertTrue(len(net.vars(level=0, include_self=False)) == 0) + self.assertTrue(len(net.vars(level=1)) == 2 * 2) + self.assertTrue(len(net.vars(level=1, include_self=False)) == 2 * 2) + self.assertTrue(len(net.vars(level=2)) == (2 + 4) * 2) + self.assertTrue(len(net.vars(level=2, include_self=False)) == (2 + 4) * 2) + self.assertTrue(len(net.vars(level=3)) == (2 + 4 + 8) * 2) + self.assertTrue(len(net.vars(level=3, include_self=False)) == (2 + 4 + 8) * 2) class TestNodeList(unittest.TestCase): - def test_NodeList_1(self): - bm.random.seed() + def test_NodeList_1(self): + bm.random.seed() - class Object(bp.DynamicalSystem): - def __init__(self): - super().__init__() + class Object(bp.DynamicalSystem): + def __init__(self): + super().__init__() - self.l1 = bp.layers.Dense(5, 10) - self.ls = bm.NodeList([bp.layers.Dense(10, 4), - bp.layers.Activation(bm.tanh), - bp.layers.Dropout(0.1), - bp.layers.Dense(4, 5), - bp.layers.Activation(bm.relu)]) + self.l1 = bp.layers.Dense(5, 10) + self.ls = bm.NodeList([bp.layers.Dense(10, 4), + bp.layers.Activation(bm.tanh), + bp.layers.Dropout(0.1), + bp.layers.Dense(4, 5), + bp.layers.Activation(bm.relu)]) - def update(self, x): - x = self.l1(x) - for l in self.ls: - x = l(x) - return x + def update(self, x): + x = self.l1(x) + for l in self.ls: + x = l(x) + return x - with bm.environment(mode=bm.NonBatchingMode()): - obj 
= Object() - self.assertTrue(len(obj.vars()) == 0) - self.assertTrue(len(obj.nodes()) == 7) + with bm.environment(mode=bm.NonBatchingMode()): + obj = Object() + self.assertTrue(len(obj.vars()) == 0) + self.assertTrue(len(obj.nodes()) == 7) - print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) - # print(jax.tree_util.tree_structure(obj)) + print(obj.nodes().keys()) + print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) + # print(jax.tree_util.tree_structure(obj)) - with bm.environment(mode=bm.TrainingMode()): - obj = Object() - self.assertTrue(len(obj.vars()) == 6) - self.assertTrue(len(obj.nodes()) == 7) + with bm.environment(mode=bm.TrainingMode()): + obj = Object() + self.assertTrue(len(obj.vars()) == 6) + self.assertTrue(len(obj.nodes()) == 7) - print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) - # print(jax.tree_util.tree_structure(obj)) + print(obj.nodes().keys()) + print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) + # print(jax.tree_util.tree_structure(obj)) class TestNodeDict(unittest.TestCase): - def test_NodeDict_1(self): - bm.random.seed() - - class Object(bp.DynamicalSystem): - def __init__(self): - super().__init__() - - self.l1 = bp.layers.Dense(5, 10) - self.ls = bm.NodeDict( - { - 'l1': bp.layers.Dense(10, 4), - 'l2': bp.layers.Activation(bm.tanh), - 'l3': bp.layers.Dropout(0.1), - 'l4': bp.layers.Dense(4, 5), - 'l5': bp.layers.Activation(bm.relu) - } - ) - - def update(self, x): - x = self.l1(x) - for l in self.ls: - x = l(x) - return x - - with bm.environment(mode=bm.NonBatchingMode()): - obj = Object() - - self.assertTrue(len(obj.vars()) == 0) - self.assertTrue(len(obj.nodes()) == 7) - self.assertTrue(len(jax.tree_util.tree_leaves(obj)) == 1) - - print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", - obj.nodes(method='relative').keys()) - # 
print(jax.tree_util.tree_structure(obj)) - - with bm.environment(mode=bm.TrainingMode()): - obj = Object() - self.assertTrue(len(obj.vars()) == 6) - self.assertTrue(len(obj.nodes()) == 7) - - print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", - obj.nodes(method='relative').keys()) - # print(jax.tree_util.tree_structure(obj)) + def test_NodeDict_1(self): + bm.random.seed() + + class Object(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.l1 = bp.layers.Dense(5, 10) + self.ls = bm.NodeDict( + { + 'l1': bp.layers.Dense(10, 4), + 'l2': bp.layers.Activation(bm.tanh), + 'l3': bp.layers.Dropout(0.1), + 'l4': bp.layers.Dense(4, 5), + 'l5': bp.layers.Activation(bm.relu) + } + ) + + def update(self, x): + x = self.l1(x) + for l in self.ls: + x = l(x) + return x + + with bm.environment(mode=bm.NonBatchingMode()): + obj = Object() + + self.assertTrue(len(obj.vars()) == 0) + self.assertTrue(len(obj.nodes()) == 7) + self.assertTrue(len(jax.tree_util.tree_leaves(obj)) == 1) + + print(obj.nodes().keys()) + print("obj.nodes(method='relative'): ", + obj.nodes(method='relative').keys()) + # print(jax.tree_util.tree_structure(obj)) + + with bm.environment(mode=bm.TrainingMode()): + obj = Object() + self.assertTrue(len(obj.vars()) == 6) + self.assertTrue(len(obj.nodes()) == 7) + + print(obj.nodes().keys()) + print("obj.nodes(method='relative'): ", + obj.nodes(method='relative').keys()) + # print(jax.tree_util.tree_structure(obj)) class TestVarList(unittest.TestCase): - def test_ListVar_1(self): - bm.random.seed() + def test_ListVar_1(self): + bm.random.seed() - class Object(bp.DynamicalSystem): - def __init__(self): - super().__init__() - self.vs = bm.VarList([bm.Variable(1.), - bm.Variable(2.), - bm.Variable(bm.ones(10))]) + class Object(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.vs = bm.VarList([bm.Variable(1.), + bm.Variable(2.), + bm.Variable(bm.ones(10))]) - def update(self): - self.vs[0] += 10. 
- self.vs[1] += 10. - self.vs[2] += 10. + def update(self): + self.vs[0] += 10. + self.vs[1] += 10. + self.vs[2] += 10. - obj = Object() - self.assertTrue(len(obj.vars()) == 3) - self.assertTrue(len(obj.nodes()) == 1) + obj = Object() + self.assertTrue(len(obj.vars()) == 3) + self.assertTrue(len(obj.nodes()) == 1) - @bm.jit - def f2(): - obj() + @bm.jit + def f2(): + obj() - f2() - print(obj.vs) - self.assertTrue(obj.vs[0].value == 11.) - self.assertTrue(obj.vs[1].value == 12.) - self.assertTrue(bm.allclose(obj.vs[2].value, bm.ones(10) * 11.)) + f2() + print(obj.vs) + self.assertTrue(obj.vs[0].value == 11.) + self.assertTrue(obj.vs[1].value == 12.) + self.assertTrue(bm.allclose(obj.vs[2].value, bm.ones(10) * 11.)) class TestVarDict(unittest.TestCase): - def test_DictVar_1(self): - bm.random.seed() + def test_DictVar_1(self): + bm.random.seed() - class Object(bp.DynamicalSystem): - def __init__(self): - super().__init__() - self.vs = bm.VarDict({'a': bm.Variable(1.), - 'b': bm.Variable(2.), - 'c': bm.Variable(bm.ones(10))}) + class Object(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.vs = bm.VarDict({'a': bm.Variable(1.), + 'b': bm.Variable(2.), + 'c': bm.Variable(bm.ones(10))}) - def update(self): - self.vs['a'] += 10. - self.vs['b'] += 10. - self.vs['c'] += 10. + def update(self): + self.vs['a'] += 10. + self.vs['b'] += 10. + self.vs['c'] += 10. - obj = Object() - print(obj.vars()) - self.assertTrue(len(obj.vars()) == 3) - self.assertTrue(len(obj.nodes()) == 1) + obj = Object() + print(obj.vars()) + self.assertTrue(len(obj.vars()) == 3) + self.assertTrue(len(obj.nodes()) == 1) - @bm.jit - def f1(): - obj() + @bm.jit + def f1(): + obj() - f1() - print(obj.vs) - self.assertTrue(obj.vs['a'].value == 11.) - self.assertTrue(obj.vs['b'].value == 12.) - self.assertTrue(bm.allclose(obj.vs['c'].value, bm.ones(10) * 11.)) + f1() + print(obj.vs) + self.assertTrue(obj.vs['a'].value == 11.) + self.assertTrue(obj.vs['b'].value == 12.) 
+ self.assertTrue(bm.allclose(obj.vs['c'].value, bm.ones(10) * 11.)) class TestRegisterBPObjectAsPyTree(unittest.TestCase): - def test1(self): - bm.set(bp_object_as_pytree=True) + def test1(self): + bm.set(bp_object_as_pytree=True) - hh = bp.dyn.HH(1) - hh.reset() + hh = bp.dyn.HH(1) + hh.reset() - tree = jax.tree.structure(hh) - leaves = jax.tree.leaves(hh) - # tree = jax.tree.structure(hh) - # leaves = jax.tree.leaves(hh) + tree = jax.tree.structure(hh) + leaves = jax.tree.leaves(hh) + # tree = jax.tree.structure(hh) + # leaves = jax.tree.leaves(hh) - print(tree) - print(leaves) - print(jax.tree.unflatten(tree, leaves)) - # print(jax.tree.unflatten(tree, leaves)) - print() + print(tree) + print(leaves) + print(jax.tree.unflatten(tree, leaves)) + # print(jax.tree.unflatten(tree, leaves)) + print() class TestStateSavingAndLoading(unittest.TestCase): - def test_load_states(self): - class Object(bp.DynamicalSystem): - def __init__(self): - super().__init__() - - self.l1 = bp.layers.Dense(5, 10) - self.ls = bm.NodeList([bp.layers.Dense(10, 4), - bp.layers.Activation(bm.tanh), - bp.layers.Dropout(0.1), - bp.layers.Dense(4, 5), - bp.layers.Activation(bm.relu)]) - self.lif = bp.dyn.LifRef(5) - - def update(self, x): - x = self.l1(x) - for l in self.ls: - x = l(x) - return x - - with bm.training_environment(): - obj = Object() - variables = {k: dict(n.vars()) for k, n in obj.nodes(include_self=False).items()} - variables = {k: v for k, v in variables.items() if len(v) > 0} - - all_states = obj.state_dict() - all_states = {k: v for k, v in all_states.items() if len(v) > 0} - print(set(all_states.keys())) - print(set(variables.keys())) - - def not_close(x, y): - assert not bm.allclose(x, y) - def all_close(x, y): - assert bm.allclose(x, y) - - jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) - # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) - - random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) 
- jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) - # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) - # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) - - obj.load_state_dict(random_state) - jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array) - # jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array) - - - - + def test_load_states(self): + class Object(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.l1 = bp.layers.Dense(5, 10) + self.ls = bm.NodeList([bp.layers.Dense(10, 4), + bp.layers.Activation(bm.tanh), + bp.layers.Dropout(0.1), + bp.layers.Dense(4, 5), + bp.layers.Activation(bm.relu)]) + self.lif = bp.dyn.LifRef(5) + + def update(self, x): + x = self.l1(x) + for l in self.ls: + x = l(x) + return x + + with bm.training_environment(): + obj = Object() + variables = {k: dict(n.vars()) for k, n in obj.nodes(include_self=False).items()} + variables = {k: v for k, v in variables.items() if len(v) > 0} + + all_states = obj.state_dict() + all_states = {k: v for k, v in all_states.items() if len(v) > 0} + print(set(all_states.keys())) + print(set(variables.keys())) + + def not_close(x, y): + assert not bm.allclose(x, y) + + def all_close(x, y): + assert bm.allclose(x, y) + + jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + + random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + + obj.load_state_dict(random_state) + jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array) + # jax.tree.map(all_close, random_state, variables, 
is_leaf=bm.is_bp_array) diff --git a/brainpy/_src/math/object_transform/tests/test_circular_reference.py b/brainpy/_src/math/object_transform/tests/test_circular_reference.py index 8ef89dfca..2d1591b7e 100644 --- a/brainpy/_src/math/object_transform/tests/test_circular_reference.py +++ b/brainpy/_src/math/object_transform/tests/test_circular_reference.py @@ -6,71 +6,71 @@ class HH(bp.dyn.NeuDyn): - def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, - gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs): - super(HH, self).__init__(size=size, **kwargs) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - self.phi = phi - - # variables - self.V = bp.math.ones(self.num) * -65. - self.h = bp.math.ones(self.num) * 0.6 - self.n = bp.math.ones(self.num) * 0.32 - self.inputs = bp.math.zeros(self.num) - self.spikes = bp.math.zeros(self.num, dtype=bp.math.bool_) - - self.pre = None - - @bp.odeint - def integral(self, V, h, n, t, Iext): - alpha = 0.07 * bp.math.exp(-(V + 58) / 20) - beta = 1 / (bp.math.exp(-0.1 * (V + 28)) + 1) - dhdt = self.phi * (alpha * (1 - h) - beta * h) - - alpha = -0.01 * (V + 34) / (bp.math.exp(-0.1 * (V + 34)) - 1) - beta = 0.125 * bp.math.exp(-(V + 44) / 80) - dndt = self.phi * (alpha * (1 - n) - beta * n) - - m_alpha = -0.1 * (V + 35) / (bp.math.exp(-0.1 * (V + 35)) - 1) - m_beta = 4 * bp.math.exp(-(V + 60) / 18) - m = m_alpha / (m_alpha + m_beta) - INa = self.gNa * m ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + Iext) / self.C - - return dVdt, dhdt, dndt - - def update(self, t, dt): - V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) - self.spikes[:] = bp.math.logical_and(self.V < self.V_th, V >= self.V_th) - self.V[:] = V - self.h[:] = h - self.n[:] = n - self.inputs[:] = 0. 
+ def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, + gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs): + super(HH, self).__init__(size=size, **kwargs) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + self.phi = phi + + # variables + self.V = bp.math.ones(self.num) * -65. + self.h = bp.math.ones(self.num) * 0.6 + self.n = bp.math.ones(self.num) * 0.32 + self.inputs = bp.math.zeros(self.num) + self.spikes = bp.math.zeros(self.num, dtype=bp.math.bool_) + + self.pre = None + + @bp.odeint + def integral(self, V, h, n, t, Iext): + alpha = 0.07 * bp.math.exp(-(V + 58) / 20) + beta = 1 / (bp.math.exp(-0.1 * (V + 28)) + 1) + dhdt = self.phi * (alpha * (1 - h) - beta * h) + + alpha = -0.01 * (V + 34) / (bp.math.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * bp.math.exp(-(V + 44) / 80) + dndt = self.phi * (alpha * (1 - n) - beta * n) + + m_alpha = -0.1 * (V + 35) / (bp.math.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * bp.math.exp(-(V + 60) / 18) + m = m_alpha / (m_alpha + m_beta) + INa = self.gNa * m ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + Iext) / self.C + + return dVdt, dhdt, dndt + + def update(self, t, dt): + V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) + self.spikes[:] = bp.math.logical_and(self.V < self.V_th, V >= self.V_th) + self.V[:] = V + self.h[:] = h + self.n[:] = n + self.inputs[:] = 0. 
def test_nodes(): - A = HH(1, name='X') - B = HH(1, name='Y') - A.pre = B - B.pre = A - - net = bp.Network(A, B) - abs_nodes = net.nodes(method='absolute') - rel_nodes = net.nodes(method='relative') - print() - pprint(abs_nodes) - pprint(rel_nodes) - - assert len(abs_nodes) == 3 - assert len(rel_nodes) == 5 + A = HH(1, name='X') + B = HH(1, name='Y') + A.pre = B + B.pre = A + + net = bp.Network(A, B) + abs_nodes = net.nodes(method='absolute') + rel_nodes = net.nodes(method='relative') + print() + pprint(abs_nodes) + pprint(rel_nodes) + + assert len(abs_nodes) == 3 + assert len(rel_nodes) == 5 diff --git a/brainpy/_src/math/object_transform/tests/test_collector.py b/brainpy/_src/math/object_transform/tests/test_collector.py index 17ba00ec9..f597ec387 100644 --- a/brainpy/_src/math/object_transform/tests/test_collector.py +++ b/brainpy/_src/math/object_transform/tests/test_collector.py @@ -4,286 +4,286 @@ from pprint import pprint import jax.numpy as jnp + import brainpy as bp class GABAa_without_Variable(bp.synapses.TwoEndConn): - def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., - alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs): - super(GABAa_without_Variable, self).__init__(pre=pre, post=post, **kwargs) - - # parameters - self.g_max = g_max - self.E = E - self.alpha = alpha - self.beta = beta - self.T = T - self.T_duration = T_duration - self.delay = delay - - # connections - self.conn = conn(pre.size, post.size) - self.conn_mat = self.conn.requires('conn_mat') - self.size = jnp.shape(self.conn_mat) - - # variables - self.t_last_pre_spike = jnp.ones(self.size) * -1e7 - self.s = jnp.zeros(self.size) - - self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s) - - def update(self, tdi): - spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat - self.t_last_pre_spike[:] = jnp.where(spike, tdi.t, self.t_last_pre_spike) - TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T - self.s[:] = 
self.int_s(self.s, tdi.t, TT) - self.post.inputs -= jnp.sum(self.s, axis=0) * (self.post.V - self.E) + def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., + alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs): + super(GABAa_without_Variable, self).__init__(pre=pre, post=post, **kwargs) + + # parameters + self.g_max = g_max + self.E = E + self.alpha = alpha + self.beta = beta + self.T = T + self.T_duration = T_duration + self.delay = delay + + # connections + self.conn = conn(pre.size, post.size) + self.conn_mat = self.conn.requires('conn_mat') + self.size = jnp.shape(self.conn_mat) + + # variables + self.t_last_pre_spike = jnp.ones(self.size) * -1e7 + self.s = jnp.zeros(self.size) + + self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s) + + def update(self, tdi): + spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat + self.t_last_pre_spike[:] = jnp.where(spike, tdi.t, self.t_last_pre_spike) + TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T + self.s[:] = self.int_s(self.s, tdi.t, TT) + self.post.inputs -= jnp.sum(self.s, axis=0) * (self.post.V - self.E) class HH_without_Variable(bp.dyn.NeuDyn): - def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, - gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs): - super(HH_without_Variable, self).__init__(size=size, **kwargs) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - self.phi = phi - - # variables - self.V = jnp.ones(self.num) * -65. 
- self.h = jnp.ones(self.num) * 0.6 - self.n = jnp.ones(self.num) * 0.32 - self.inputs = jnp.zeros(self.num) - self.spikes = jnp.zeros(self.num, dtype=jnp.bool_) - - self.integral = bp.odeint(self.derivative) - - def derivative(self, V, h, n, t, Iext): - alpha = 0.07 * jnp.exp(-(V + 58) / 20) - beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) - dhdt = self.phi * (alpha * (1 - h) - beta * h) - - alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) - beta = 0.125 * jnp.exp(-(V + 44) / 80) - dndt = self.phi * (alpha * (1 - n) - beta * n) - - m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) - m_beta = 4 * jnp.exp(-(V + 60) / 18) - m = m_alpha / (m_alpha + m_beta) - INa = self.gNa * m ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + Iext) / self.C - - return dVdt, dhdt, dndt - - def update(self, tdi): - V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs) - self.spikes[:] = jnp.logical_and(self.V < self.V_th, V >= self.V_th) - self.V[:] = V - self.h[:] = h - self.n[:] = n - self.inputs[:] = 0. + def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, + gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs): + super(HH_without_Variable, self).__init__(size=size, **kwargs) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + self.phi = phi + + # variables + self.V = jnp.ones(self.num) * -65. 
+ self.h = jnp.ones(self.num) * 0.6 + self.n = jnp.ones(self.num) * 0.32 + self.inputs = jnp.zeros(self.num) + self.spikes = jnp.zeros(self.num, dtype=jnp.bool_) + + self.integral = bp.odeint(self.derivative) + + def derivative(self, V, h, n, t, Iext): + alpha = 0.07 * jnp.exp(-(V + 58) / 20) + beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) + dhdt = self.phi * (alpha * (1 - h) - beta * h) + + alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * jnp.exp(-(V + 44) / 80) + dndt = self.phi * (alpha * (1 - n) - beta * n) + + m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * jnp.exp(-(V + 60) / 18) + m = m_alpha / (m_alpha + m_beta) + INa = self.gNa * m ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + Iext) / self.C + + return dVdt, dhdt, dndt + + def update(self, tdi): + V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs) + self.spikes[:] = jnp.logical_and(self.V < self.V_th, V >= self.V_th) + self.V[:] = V + self.h[:] = h + self.n[:] = n + self.inputs[:] = 0. 
def test_subset_integrator(): - neu = HH_without_Variable(10) - syn = GABAa_without_Variable(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) - syn.g_max = 0.1 / neu.num - net = bp.Network(neu, syn) + neu = HH_without_Variable(10) + syn = GABAa_without_Variable(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) + syn.g_max = 0.1 / neu.num + net = bp.Network(neu, syn) - ints = net.nodes(level=-1).subset(bp.Integrator) - print() - print(ints) + ints = net.nodes(level=-1).subset(bp.Integrator) + print() + print(ints) - ode_ints = ints.subset(bp.ode.ODEIntegrator).unique() - print(ode_ints) - assert len(ode_ints) == 2 + ode_ints = ints.subset(bp.ode.ODEIntegrator).unique() + print(ode_ints) + assert len(ode_ints) == 2 def test_neu_vars_1(): - neu = HH_without_Variable(10) - vars = neu.vars() + neu = HH_without_Variable(10) + vars = neu.vars() - print() - print(vars) - assert len(vars) == 0 + print() + print(vars) + assert len(vars) == 0 class HH_with_Variable(bp.dyn.NeuDyn): - def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, - gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs): - super(HH_with_Variable, self).__init__(size=size, **kwargs) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - self.phi = phi - - # variables - self.V = bp.math.Variable(jnp.ones(self.num) * -65.) 
- self.h = bp.math.Variable(jnp.ones(self.num) * 0.6) - self.n = bp.math.Variable(jnp.ones(self.num) * 0.32) - self.inputs = bp.math.Variable(jnp.zeros(self.num)) - self.spikes = bp.math.Variable(jnp.zeros(self.num, dtype=jnp.bool_)) - - self.integral = bp.odeint(self.derivative) - - def derivative(self, V, h, n, t, Iext): - alpha = 0.07 * jnp.exp(-(V + 58) / 20) - beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) - dhdt = self.phi * (alpha * (1 - h) - beta * h) - - alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) - beta = 0.125 * jnp.exp(-(V + 44) / 80) - dndt = self.phi * (alpha * (1 - n) - beta * n) - - m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) - m_beta = 4 * jnp.exp(-(V + 60) / 18) - m = m_alpha / (m_alpha + m_beta) - INa = self.gNa * m ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + Iext) / self.C - - return dVdt, dhdt, dndt - - def update(self, tdi): - V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs) - self.spikes[:] = jnp.logical_and(self.V < self.V_th, V >= self.V_th) - self.V[:] = V - self.h[:] = h - self.n[:] = n - self.inputs[:] = 0. + def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, + gNa=35., gK=9., gL=0.1, V_th=20., phi=5.0, **kwargs): + super(HH_with_Variable, self).__init__(size=size, **kwargs) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + self.phi = phi + + # variables + self.V = bp.math.Variable(jnp.ones(self.num) * -65.) 
+ self.h = bp.math.Variable(jnp.ones(self.num) * 0.6) + self.n = bp.math.Variable(jnp.ones(self.num) * 0.32) + self.inputs = bp.math.Variable(jnp.zeros(self.num)) + self.spikes = bp.math.Variable(jnp.zeros(self.num, dtype=jnp.bool_)) + + self.integral = bp.odeint(self.derivative) + + def derivative(self, V, h, n, t, Iext): + alpha = 0.07 * jnp.exp(-(V + 58) / 20) + beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) + dhdt = self.phi * (alpha * (1 - h) - beta * h) + + alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * jnp.exp(-(V + 44) / 80) + dndt = self.phi * (alpha * (1 - n) - beta * n) + + m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * jnp.exp(-(V + 60) / 18) + m = m_alpha / (m_alpha + m_beta) + INa = self.gNa * m ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + Iext) / self.C + + return dVdt, dhdt, dndt + + def update(self, tdi): + V, h, n = self.integral(self.V, self.h, self.n, tdi.t, self.inputs) + self.spikes[:] = jnp.logical_and(self.V < self.V_th, V >= self.V_th) + self.V[:] = V + self.h[:] = h + self.n[:] = n + self.inputs[:] = 0. 
def test_neu_vars_2(): - neu = HH_with_Variable(10) - vars = neu.vars() - print() - print(vars.keys()) + neu = HH_with_Variable(10) + vars = neu.vars() + print() + print(vars.keys()) - vars = neu.vars(method='relative') - print() - print(vars.keys()) + vars = neu.vars(method='relative') + print() + print(vars.keys()) def test_neu_nodes_1(): - neu = HH_with_Variable(10) - print() - print(neu.nodes().keys()) - assert len(neu.nodes(level=-1, include_self=False)) == 1 + neu = HH_with_Variable(10) + print() + print(neu.nodes().keys()) + assert len(neu.nodes(level=-1, include_self=False)) == 1 - print() - print(neu.nodes(method='relative').keys()) - assert len(neu.nodes(method='relative', include_self=False)) == 1 + print() + print(neu.nodes(method='relative').keys()) + assert len(neu.nodes(method='relative', include_self=False)) == 1 class GABAa_with_Variable(bp.synapses.TwoEndConn): - def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., - alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs): - super(GABAa_with_Variable, self).__init__(pre=pre, post=post, **kwargs) - - # parameters - self.g_max = g_max - self.E = E - self.alpha = alpha - self.beta = beta - self.T = T - self.T_duration = T_duration - self.delay = delay - - # connections - self.conn = conn(pre.size, post.size) - self.conn_mat = self.conn.requires('conn_mat') - self.size = jnp.shape(self.conn_mat) - - # variables - self.t_last_pre_spike = bp.math.Variable(jnp.ones(self.size) * -1e7) - self.s = bp.math.Variable(jnp.zeros(self.size)) - self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s) - - def update(self, tdi): - spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat - self.t_last_pre_spike[:] = jnp.where(spike, tdi.t, self.t_last_pre_spike.value) - TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T - self.s[:] = self.int_s(self.s, tdi.t, TT) - self.post.inputs -= jnp.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E) + def 
__init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., + alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs): + super(GABAa_with_Variable, self).__init__(pre=pre, post=post, **kwargs) + + # parameters + self.g_max = g_max + self.E = E + self.alpha = alpha + self.beta = beta + self.T = T + self.T_duration = T_duration + self.delay = delay + + # connections + self.conn = conn(pre.size, post.size) + self.conn_mat = self.conn.requires('conn_mat') + self.size = jnp.shape(self.conn_mat) + + # variables + self.t_last_pre_spike = bp.math.Variable(jnp.ones(self.size) * -1e7) + self.s = bp.math.Variable(jnp.zeros(self.size)) + self.int_s = bp.odeint(lambda s, t, TT: self.alpha * TT * (1 - s) - self.beta * s) + + def update(self, tdi): + spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat + self.t_last_pre_spike[:] = jnp.where(spike, tdi.t, self.t_last_pre_spike.value) + TT = ((tdi.t - self.t_last_pre_spike) < self.T_duration) * self.T + self.s[:] = self.int_s(self.s, tdi.t, TT) + self.post.inputs -= jnp.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E) def test_net_1(): - neu = HH_without_Variable(10) - syn = GABAa_without_Variable(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) - net = bp.Network(neu=neu, syn=syn) + neu = HH_without_Variable(10) + syn = GABAa_without_Variable(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) + net = bp.Network(neu=neu, syn=syn) - # variables - print() - pprint(list(net.vars().keys())) - assert len(net.vars()) == 0 + # variables + print() + pprint(list(net.vars().keys())) + assert len(net.vars()) == 0 - print() - pprint(list(net.vars(method='relative').keys())) - assert len(net.vars(method='relative')) == 0 + print() + pprint(list(net.vars(method='relative').keys())) + assert len(net.vars(method='relative')) == 0 - # nodes - print() - pprint(list(net.nodes().unique().keys())) - # assert len(net.nodes()) == 8 + # nodes + print() + pprint(list(net.nodes().unique().keys())) + # 
assert len(net.nodes()) == 8 - print() - pprint(list(net.nodes(method='relative').unique().keys())) - # assert len(net.nodes(method='relative')) == 12 + print() + pprint(list(net.nodes(method='relative').unique().keys())) + # assert len(net.nodes(method='relative')) == 12 def test_net_vars_2(): - neu = HH_with_Variable(10) - syn = GABAa_with_Variable(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) - net = bp.Network(neu=neu, syn=syn) + neu = HH_with_Variable(10) + syn = GABAa_with_Variable(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) + net = bp.Network(neu=neu, syn=syn) - # variables - print() - pprint(list(net.vars().keys())) - # assert len(net.vars()) == 3 + # variables + print() + pprint(list(net.vars().keys())) + # assert len(net.vars()) == 3 - print() - pprint(list(net.vars(method='relative').keys())) - # assert len(net.vars(method='relative')) == 3 + print() + pprint(list(net.vars(method='relative').keys())) + # assert len(net.vars(method='relative')) == 3 - # nodes - print() - pprint(list(net.nodes().keys())) - # assert len(net.nodes()) == 8 + # nodes + print() + pprint(list(net.nodes().keys())) + # assert len(net.nodes()) == 8 - print() - pprint(list(net.nodes(method='relative').keys())) - # assert len(net.nodes(method='relative')) == 6 + print() + pprint(list(net.nodes(method='relative').keys())) + # assert len(net.nodes(method='relative')) == 6 def test_hidden_variables(): - class BPClass(bp.BrainPyObject): - _excluded_vars = ('_rng_', ) - - def __init__(self): - super(BPClass, self).__init__() + class BPClass(bp.BrainPyObject): + _excluded_vars = ('_rng_',) - self._rng_ = bp.math.random.RandomState() - self.rng = bp.math.random.RandomState() + def __init__(self): + super(BPClass, self).__init__() - model = BPClass() + self._rng_ = bp.math.random.RandomState() + self.rng = bp.math.random.RandomState() - print(model.vars(level=-1).keys()) - assert len(model.vars(level=-1)) == 1 + model = BPClass() + 
print(model.vars(level=-1).keys()) + assert len(model.vars(level=-1)) == 1 diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py index b504d5e9d..8272a4371 100644 --- a/brainpy/_src/math/object_transform/tests/test_controls.py +++ b/brainpy/_src/math/object_transform/tests/test_controls.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import tempfile import unittest from functools import partial @@ -12,329 +11,326 @@ class TestLoop(parameterized.TestCase): - @parameterized.named_parameters( - {"testcase_name": "_jit_scan={}_jit_f={}_unroll={}".format(jit_scan, jit_f, unroll), - "jit_scan": jit_scan, - "jit_f": jit_f, - "unroll": unroll} - for jit_scan in [False, True] - for jit_f in [False, True] - for unroll in [1, 2] - ) - def test_for_loop(self, jit_scan, jit_f, unroll): - rng = bm.random.RandomState(123) - - c = bm.Variable(rng.randn(4)) - d = rng.randn(2) - all_a = rng.randn(5, 3) - - def f(a): - assert a.shape == (3,) - assert c.shape == (4,) - b = bm.cos(bm.sum(bm.sin(a)) + bm.sum(bm.cos(c)) + bm.sum(bm.tan(d))) - c.value = bm.sin(c * b) - assert b.shape == () - return b - - if jit_f: - f = bm.jit(f) - scan = partial(bm.for_loop, f, unroll=unroll, ) - if jit_scan: - scan = bm.jit(scan) - ans = scan(operands=all_a) - print(ans) - print(c) - - def test_for_loop_progress_bar(self): - xs = bm.arange(100) - ys = bm.for_loop(lambda a: a, xs, progress_bar=True) - self.assertTrue(bm.allclose(xs, ys)) - - def test_for_loop2(self): - class MyClass(bp.DynamicalSystem): - def __init__(self): - super().__init__() - self.a = bm.Variable(bm.zeros(1)) - - def update(self): - self.a += 1 - - cls = MyClass() - indices = bm.arange(10) - bm.for_loop(cls.step_run, indices) - self.assertTrue(bm.allclose(cls.a, 10.)) + @parameterized.named_parameters( + {"testcase_name": "_jit_scan={}_jit_f={}_unroll={}".format(jit_scan, jit_f, unroll), + "jit_scan": jit_scan, + "jit_f": jit_f, + "unroll": unroll} + for 
jit_scan in [False, True] + for jit_f in [False, True] + for unroll in [1, 2] + ) + def test_for_loop(self, jit_scan, jit_f, unroll): + rng = bm.random.RandomState(123) + + c = bm.Variable(rng.randn(4)) + d = rng.randn(2) + all_a = rng.randn(5, 3) + + def f(a): + assert a.shape == (3,) + assert c.shape == (4,) + b = bm.cos(bm.sum(bm.sin(a)) + bm.sum(bm.cos(c)) + bm.sum(bm.tan(d))) + c.value = bm.sin(c * b) + assert b.shape == () + return b + + if jit_f: + f = bm.jit(f) + scan = partial(bm.for_loop, f, unroll=unroll, ) + if jit_scan: + scan = bm.jit(scan) + ans = scan(operands=all_a) + print(ans) + print(c) + + def test_for_loop_progress_bar(self): + xs = bm.arange(100) + ys = bm.for_loop(lambda a: a, xs, progress_bar=True) + self.assertTrue(bm.allclose(xs, ys)) + + def test_for_loop2(self): + class MyClass(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.a = bm.Variable(bm.zeros(1)) + + def update(self): + self.a += 1 + + cls = MyClass() + indices = bm.arange(10) + bm.for_loop(cls.step_run, indices) + self.assertTrue(bm.allclose(cls.a, 10.)) class TestScan(unittest.TestCase): - def test1(self): - a = bm.Variable(1) - - def f(carray, x): - carray += x - a.value += 1. - return carray, a.value - - carry, outs = bm.scan(f, bm.zeros(2), bm.arange(10)) - self.assertTrue(bm.allclose(carry, 45.)) - expected = bm.arange(1, 11).astype(outs.dtype) - expected = bm.expand_dims(expected, axis=-1) - self.assertTrue(bm.allclose(outs, expected)) - - def test2(self): - a = bm.Variable(1) - - def f(carray, x): - carray += x - a.value += 1. 
- return carray, a.value - - @bm.jit - def f_outer(carray, x): - carry, outs = bm.scan(f, carray, x, unroll=2) - return carry, outs - - carry, outs = f_outer(bm.zeros(2), bm.arange(10)) - self.assertTrue(bm.allclose(carry, 45.)) - expected = bm.arange(1, 11).astype(outs.dtype) - expected = bm.expand_dims(expected, axis=-1) - self.assertTrue(bm.allclose(outs, expected)) - - def test_disable_jit(self): - def cumsum(res, el): - res = res + el - print(res) - return res, res # ("carryover", "accumulated") - - a = bm.array([1, 2, 3, 5, 7, 11, 13, 17]).value - result_init = 0 - with jax.disable_jit(): - final, result = jax.lax.scan(cumsum, result_init, a) - - b = bm.array([1, 2, 3, 5, 7, 11, 13, 17]) - result_init = 0 - with jax.disable_jit(): - final, result = bm.scan(cumsum, result_init, b) - - - - def test_array_aware_of_bp_array(self): - def cumsum(res, el): - res = bm.asarray(res + el) - return res, res # ("carryover", "accumulated") - - b = bm.array([1, 2, 3, 5, 7, 11, 13, 17]) - result_init = 0 - with jax.disable_jit(): - final, result = bm.scan(cumsum, result_init, b) + def test1(self): + a = bm.Variable(1) + + def f(carray, x): + carray += x + a.value += 1. + return carray, a.value + + carry, outs = bm.scan(f, bm.zeros(2), bm.arange(10)) + self.assertTrue(bm.allclose(carry, 45.)) + expected = bm.arange(1, 11).astype(outs.dtype) + expected = bm.expand_dims(expected, axis=-1) + self.assertTrue(bm.allclose(outs, expected)) + + def test2(self): + a = bm.Variable(1) + + def f(carray, x): + carray += x + a.value += 1. 
+ return carray, a.value + + @bm.jit + def f_outer(carray, x): + carry, outs = bm.scan(f, carray, x, unroll=2) + return carry, outs + + carry, outs = f_outer(bm.zeros(2), bm.arange(10)) + self.assertTrue(bm.allclose(carry, 45.)) + expected = bm.arange(1, 11).astype(outs.dtype) + expected = bm.expand_dims(expected, axis=-1) + self.assertTrue(bm.allclose(outs, expected)) + + def test_disable_jit(self): + def cumsum(res, el): + res = res + el + print(res) + return res, res # ("carryover", "accumulated") + + a = bm.array([1, 2, 3, 5, 7, 11, 13, 17]).value + result_init = 0 + with jax.disable_jit(): + final, result = jax.lax.scan(cumsum, result_init, a) + + b = bm.array([1, 2, 3, 5, 7, 11, 13, 17]) + result_init = 0 + with jax.disable_jit(): + final, result = bm.scan(cumsum, result_init, b) + + def test_array_aware_of_bp_array(self): + def cumsum(res, el): + res = bm.asarray(res + el) + return res, res # ("carryover", "accumulated") + + b = bm.array([1, 2, 3, 5, 7, 11, 13, 17]) + result_init = 0 + with jax.disable_jit(): + final, result = bm.scan(cumsum, result_init, b) class TestCond(unittest.TestCase): - def test1(self): - bm.random.seed(1) - bm.cond(True, lambda: bm.random.random(10), lambda: bm.random.random(10), ()) - bm.cond(False, lambda: bm.random.random(10), lambda: bm.random.random(10), ()) + def test1(self): + bm.random.seed(1) + bm.cond(True, lambda: bm.random.random(10), lambda: bm.random.random(10), ()) + bm.cond(False, lambda: bm.random.random(10), lambda: bm.random.random(10), ()) class TestIfElse(unittest.TestCase): - def test1(self): - def f(a): - return bm.ifelse(conditions=[a < 0, a < 2, a < 5, a < 10, a < 20], - branches=[lambda: 1, - lambda: 2, - lambda: 3, - lambda: 4, - lambda: 5]) - - self.assertTrue(f(3) == 3) - self.assertTrue(f(1) == 2) - self.assertTrue(f(-1) == 1) - - def test2(self): - def f(a): - return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0, a > -1], - branches=[1, 2, 3, 4, 5]) - - self.assertTrue(f(3) == 3) - 
self.assertTrue(f(1) == 4) - self.assertTrue(f(-1) == 5) - - def test_dyn_vars1(self): - var_a = bm.Variable(bm.zeros(1)) - - def f(a): - def f1(): - var_a.value += 1 - return 1 - - return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0, a > -1], - branches=[f1, - lambda: 2, lambda: 3, - lambda: 4, lambda: 5]) - - self.assertTrue(f(11) == 1) - print(var_a) - self.assertTrue(bm.all(var_a == 1)) - self.assertTrue(f(1) == 4) - self.assertTrue(f(-1) == 5) - - def test_vmap(self): - def f(operands): - f = lambda a: bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0], - branches=[lambda _: 1, - lambda _: 2, - lambda _: 3, - lambda _: 4, - lambda _: 5, ], - operands=a) - return vmap(f)(operands) - - r = f(bm.random.randint(-20, 20, 200)) - self.assertTrue(r.size == 200) - - def test_vmap2(self): - def f2(): - f = lambda a: bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0], - branches=[1, 2, 3, 4, lambda _: 5], - operands=a) - return vmap(f)(bm.random.randint(-20, 20, 200)) - - self.assertTrue(f2().size == 200) - - def test_grad1(self): - def F2(x): - return bm.ifelse(conditions=(x >= 10,), - branches=[lambda x: x, - lambda x: x ** 2, ], - operands=x) - - self.assertTrue(bm.grad(F2)(9.0) == 18.) - self.assertTrue(bm.grad(F2)(11.0) == 1.) - - def test_grad2(self): - def F3(x): - return bm.ifelse(conditions=(x >= 10, x >= 0), - branches=[lambda x: x, - lambda x: x ** 2, - lambda x: x ** 4, ], - operands=x) - - self.assertTrue(bm.grad(F3)(9.0) == 18.) - self.assertTrue(bm.grad(F3)(11.0) == 1.) 
+ def test1(self): + def f(a): + return bm.ifelse(conditions=[a < 0, a < 2, a < 5, a < 10, a < 20], + branches=[lambda: 1, + lambda: 2, + lambda: 3, + lambda: 4, + lambda: 5]) + + self.assertTrue(f(3) == 3) + self.assertTrue(f(1) == 2) + self.assertTrue(f(-1) == 1) + + def test2(self): + def f(a): + return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0, a > -1], + branches=[1, 2, 3, 4, 5]) + + self.assertTrue(f(3) == 3) + self.assertTrue(f(1) == 4) + self.assertTrue(f(-1) == 5) + + def test_dyn_vars1(self): + var_a = bm.Variable(bm.zeros(1)) + + def f(a): + def f1(): + var_a.value += 1 + return 1 + + return bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0, a > -1], + branches=[f1, + lambda: 2, lambda: 3, + lambda: 4, lambda: 5]) + + self.assertTrue(f(11) == 1) + print(var_a) + self.assertTrue(bm.all(var_a == 1)) + self.assertTrue(f(1) == 4) + self.assertTrue(f(-1) == 5) + + def test_vmap(self): + def f(operands): + f = lambda a: bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0], + branches=[lambda _: 1, + lambda _: 2, + lambda _: 3, + lambda _: 4, + lambda _: 5, ], + operands=a) + return vmap(f)(operands) + + r = f(bm.random.randint(-20, 20, 200)) + self.assertTrue(r.size == 200) + + def test_vmap2(self): + def f2(): + f = lambda a: bm.ifelse(conditions=[a > 10, a > 5, a > 2, a > 0], + branches=[1, 2, 3, 4, lambda _: 5], + operands=a) + return vmap(f)(bm.random.randint(-20, 20, 200)) + + self.assertTrue(f2().size == 200) + + def test_grad1(self): + def F2(x): + return bm.ifelse(conditions=(x >= 10,), + branches=[lambda x: x, + lambda x: x ** 2, ], + operands=x) + + self.assertTrue(bm.grad(F2)(9.0) == 18.) + self.assertTrue(bm.grad(F2)(11.0) == 1.) + + def test_grad2(self): + def F3(x): + return bm.ifelse(conditions=(x >= 10, x >= 0), + branches=[lambda x: x, + lambda x: x ** 2, + lambda x: x ** 4, ], + operands=x) + + self.assertTrue(bm.grad(F3)(9.0) == 18.) + self.assertTrue(bm.grad(F3)(11.0) == 1.) 
class TestWhile(unittest.TestCase): - def test1(self): - bm.random.seed() - - a = bm.Variable(bm.zeros(1)) - b = bm.Variable(bm.ones(1)) - - def cond(x, y): - return x < 6. - - def body(x, y): - a.value += x - b.value *= y - return x + b[0], y + 1. - - res = bm.while_loop(body, cond, operands=(1., 1.)) - print() - print(res) - - def test2(self): - bm.random.seed() - - a = bm.Variable(bm.zeros(1)) - b = bm.Variable(bm.ones(1)) - - def cond(x, y): - return x < 6. - - def body(x, y): - a.value += x - b.value *= y - return x + b[0], y + 1. - - res = bm.while_loop(body, cond, operands=(1., 1.)) - print() - print(res) - - with jax.disable_jit(): - a = bm.Variable(bm.zeros(1)) - b = bm.Variable(bm.ones(1)) - - res2 = bm.while_loop(body, cond, operands=(1., 1.)) - print(res2) - self.assertTrue(bm.array_equal(res2[0], res[0])) - self.assertTrue(bm.array_equal(res2[1], res[1])) - - def test3(self): - bm.random.seed() - - a = bm.Variable(bm.zeros(1)) - b = bm.Variable(bm.ones(1)) - - def cond(x, y): - return bm.all(a.value < 6.) - - def body(x, y): - a.value += x - b.value *= y - - res = bm.while_loop(body, cond, operands=(1., 1.)) - self.assertTrue(bm.allclose(a, 6.)) - self.assertTrue(bm.allclose(b, 1.)) - print() - print(res) - print(a) - print(b) - - # def test4(self): - # bm.random.seed() - # - # a = bm.Variable(bm.zeros(1)) - # b = bm.Variable(bm.ones(1)) - # - # def cond(x, y): - # a.value += 1 - # return bm.all(a.value < 6.) - # - # def body(x, y): - # a.value += x - # b.value *= y - # - # res = bm.while_loop(body, cond, operands=(1., 1.)) - # self.assertTrue(bm.allclose(a, 7.)) # Corrected: condition function increments a each time before checking - # self.assertTrue(bm.allclose(b, 1.)) - # print(res) - # print(a) - # print(b) - # print() - - def test5(self): - bm.random.seed() - - a = bm.Variable(bm.zeros(1)) - b = bm.Variable(bm.ones(1)) - c = bm.Variable(bm.ones(1)) - - def cond(x, y): - a.value += 1 - return bm.all(a.value < 6.) 
- - def body(x, y): - a.value += x - b.value *= y - return x + 1, y + 1 - - @bm.jit - def run(a, b): - x, y = bm.while_loop(body, cond, operands=(a, b)) - return c + x - - # Test that JIT compilation fails when condition function has write states - with self.assertRaises(ValueError) as cm: - run(0., 1.) - - self.assertIn("cond_fun should not have any write states", str(cm.exception)) - + def test1(self): + bm.random.seed() + + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + def cond(x, y): + return x < 6. + + def body(x, y): + a.value += x + b.value *= y + return x + b[0], y + 1. + + res = bm.while_loop(body, cond, operands=(1., 1.)) + print() + print(res) + + def test2(self): + bm.random.seed() + + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + def cond(x, y): + return x < 6. + + def body(x, y): + a.value += x + b.value *= y + return x + b[0], y + 1. + + res = bm.while_loop(body, cond, operands=(1., 1.)) + print() + print(res) + + with jax.disable_jit(): + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + res2 = bm.while_loop(body, cond, operands=(1., 1.)) + print(res2) + self.assertTrue(bm.array_equal(res2[0], res[0])) + self.assertTrue(bm.array_equal(res2[1], res[1])) + + def test3(self): + bm.random.seed() + + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + def cond(x, y): + return bm.all(a.value < 6.) + + def body(x, y): + a.value += x + b.value *= y + + res = bm.while_loop(body, cond, operands=(1., 1.)) + self.assertTrue(bm.allclose(a, 6.)) + self.assertTrue(bm.allclose(b, 1.)) + print() + print(res) + print(a) + print(b) + + # def test4(self): + # bm.random.seed() + # + # a = bm.Variable(bm.zeros(1)) + # b = bm.Variable(bm.ones(1)) + # + # def cond(x, y): + # a.value += 1 + # return bm.all(a.value < 6.) 
+ # + # def body(x, y): + # a.value += x + # b.value *= y + # + # res = bm.while_loop(body, cond, operands=(1., 1.)) + # self.assertTrue(bm.allclose(a, 7.)) # Corrected: condition function increments a each time before checking + # self.assertTrue(bm.allclose(b, 1.)) + # print(res) + # print(a) + # print(b) + # print() + + def test5(self): + bm.random.seed() + + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + c = bm.Variable(bm.ones(1)) + + def cond(x, y): + a.value += 1 + return bm.all(a.value < 6.) + + def body(x, y): + a.value += x + b.value *= y + return x + 1, y + 1 + + @bm.jit + def run(a, b): + x, y = bm.while_loop(body, cond, operands=(a, b)) + return c + x + + # Test that JIT compilation fails when condition function has write states + with self.assertRaises(ValueError) as cm: + run(0., 1.) + + self.assertIn("cond_fun should not have any write states", str(cm.exception)) diff --git a/brainpy/_src/math/object_transform/tests/test_namechecking.py b/brainpy/_src/math/object_transform/tests/test_namechecking.py index 70b60cbb3..0b603ef6e 100644 --- a/brainpy/_src/math/object_transform/tests/test_namechecking.py +++ b/brainpy/_src/math/object_transform/tests/test_namechecking.py @@ -1,20 +1,13 @@ # -*- coding: utf-8 -*- -import pytest import brainpy as bp class LIF(bp.dyn.NeuDyn): - pass - + pass # def test_check_name(): # LIF(1, name='X_LIF') # # with pytest.raises(bp.errors.UniqueNameError): # LIF(1, name='X_LIF') - - - - - diff --git a/brainpy/_src/math/object_transform/tests/test_naming.py b/brainpy/_src/math/object_transform/tests/test_naming.py index ddcae5b93..0fbcb98a3 100644 --- a/brainpy/_src/math/object_transform/tests/test_naming.py +++ b/brainpy/_src/math/object_transform/tests/test_naming.py @@ -14,18 +14,19 @@ # ============================================================================== +import unittest + import brainpy as bp import brainpy.math as bm -import unittest class TestNaming(unittest.TestCase): - def 
test_clear_name_cache(self): - lif = bp.dyn.LifRef(1, name='a') - with self.assertRaises(bp.errors.UniqueNameError): - lif = bp.dyn.LifRef(1, name='a') - bm.clear_name_cache(ignore_warn=True) - lif = bp.dyn.LifRef(1, name='a') - bm.clear_name_cache() - bm.clear_buffer_memory(array=False, compilation=True) + def test_clear_name_cache(self): + lif = bp.dyn.LifRef(1, name='a') + with self.assertRaises(bp.errors.UniqueNameError): + lif = bp.dyn.LifRef(1, name='a') + bm.clear_name_cache(ignore_warn=True) + lif = bp.dyn.LifRef(1, name='a') + bm.clear_name_cache() + bm.clear_buffer_memory(array=False, compilation=True) diff --git a/brainpy/_src/math/object_transform/tests/test_variable.py b/brainpy/_src/math/object_transform/tests/test_variable.py index ba020f7c1..88a7e2fb7 100644 --- a/brainpy/_src/math/object_transform/tests/test_variable.py +++ b/brainpy/_src/math/object_transform/tests/test_variable.py @@ -1,38 +1,35 @@ -import brainpy.math as bm import unittest - -class TestVar(unittest.TestCase): - def test1(self): - class A(bm.BrainPyObject): - def __init__(self): - super().__init__() - self.a = bm.Variable(1) - self.f1 = bm.jit(self.f) - self.f2 = bm.jit(self.ff) - self.f3 = bm.jit(self.fff) - - def f(self): - b = self.tracing_variable('b', bm.ones, (1,)) - self.a += (b * 2) - return self.a.value - - def ff(self): - self.b += 1. - - def fff(self): - self.f() - self.ff() - self.b *= self.a - return self.b.value - - print() - f_jit = bm.jit(A().f) - f_jit() - - print() - a = A() - - +import brainpy.math as bm +class TestVar(unittest.TestCase): + def test1(self): + class A(bm.BrainPyObject): + def __init__(self): + super().__init__() + self.a = bm.Variable(1) + self.f1 = bm.jit(self.f) + self.f2 = bm.jit(self.ff) + self.f3 = bm.jit(self.fff) + + def f(self): + b = self.tracing_variable('b', bm.ones, (1,)) + self.a += (b * 2) + return self.a.value + + def ff(self): + self.b += 1. 
+ + def fff(self): + self.f() + self.ff() + self.b *= self.a + return self.b.value + + print() + f_jit = bm.jit(A().f) + f_jit() + + print() + a = A() diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 632c6d79e..bb7346688 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -12,7 +12,7 @@ class Empty(object): - pass + pass empty = Empty() @@ -25,59 +25,59 @@ def _partial_fun( static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = () ): - static_args, dyn_args = [], [] - for i, arg in enumerate(args): - if i in static_argnums: - static_args.append(arg) - else: - static_args.append(empty) - dyn_args.append(arg) - static_kwargs, dyn_kwargs = {}, {} - for k, arg in kwargs.items(): - if k in static_argnames: - static_kwargs[k] = arg - else: - dyn_kwargs[k] = arg - del args, kwargs, static_argnums, static_argnames + static_args, dyn_args = [], [] + for i, arg in enumerate(args): + if i in static_argnums: + static_args.append(arg) + else: + static_args.append(empty) + dyn_args.append(arg) + static_kwargs, dyn_kwargs = {}, {} + for k, arg in kwargs.items(): + if k in static_argnames: + static_kwargs[k] = arg + else: + dyn_kwargs[k] = arg + del args, kwargs, static_argnums, static_argnames - @wraps(fun) - def new_fun(*dynargs, **dynkwargs): - args = [] - i = 0 - for arg in static_args: - if arg == empty: - args.append(dynargs[i]) - i += 1 - else: - args.append(arg) - return fun(*args, **static_kwargs, **dynkwargs) + @wraps(fun) + def new_fun(*dynargs, **dynkwargs): + args = [] + i = 0 + for arg in static_args: + if arg == empty: + args.append(dynargs[i]) + i += 1 + else: + args.append(arg) + return fun(*args, **static_kwargs, **dynkwargs) - return new_fun, dyn_args, dyn_kwargs + return new_fun, dyn_args, dyn_kwargs def dynvar_deprecation(dyn_vars=None): - if dyn_vars is not None: - warnings.warn('\n' - 'From brainpy>=2.4.0, users no longer need 
to provide ``dyn_vars`` into ' - 'transformation functions like "jit", "grad", "for_loop", etc. ' - 'Because these transformations are capable of automatically collecting them.', - UserWarning) + if dyn_vars is not None: + warnings.warn('\n' + 'From brainpy>=2.4.0, users no longer need to provide ``dyn_vars`` into ' + 'transformation functions like "jit", "grad", "for_loop", etc. ' + 'Because these transformations are capable of automatically collecting them.', + UserWarning) def node_deprecation(child_objs=None): - if child_objs is not None: - warnings.warn('\n' - 'From brainpy>=2.4.0, users no longer need to provide ``child_objs`` into ' - 'transformation functions like "jit", "grad", "for_loop", etc. ' - 'Because these transformations are capable of automatically collecting them.', - UserWarning) + if child_objs is not None: + warnings.warn('\n' + 'From brainpy>=2.4.0, users no longer need to provide ``child_objs`` into ' + 'transformation functions like "jit", "grad", "for_loop", etc. 
' + 'Because these transformations are capable of automatically collecting them.', + UserWarning) def abstract(x): - if callable(x): - return x - else: - return jax.api_util.shaped_abstractify(x) + if callable(x): + return x + else: + return jax.api_util.shaped_abstractify(x) def evaluate_dyn_vars( @@ -88,20 +88,20 @@ def evaluate_dyn_vars( use_eval_shape: bool = True, **kwargs ) -> Tuple[VariableStack, Any]: - # arguments - if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun(f, args, kwargs, - static_argnums=static_argnums, - static_argnames=static_argnames) - else: - f2, args, kwargs = f, args, kwargs - # stack - with VariableStack() as stack: - if use_eval_shape: - rets = jax.eval_shape(f2, *args, **kwargs) + # arguments + if len(static_argnums) or len(static_argnames): + f2, args, kwargs = _partial_fun(f, args, kwargs, + static_argnums=static_argnums, + static_argnames=static_argnames) else: - rets = f2(*args, **kwargs) - return stack, rets + f2, args, kwargs = f, args, kwargs + # stack + with VariableStack() as stack: + if use_eval_shape: + rets = jax.eval_shape(f2, *args, **kwargs) + else: + rets = f2(*args, **kwargs) + return stack, rets def evaluate_dyn_vars_with_cache( @@ -112,24 +112,25 @@ def evaluate_dyn_vars_with_cache( with_return: bool = False, **kwargs ): - # TODO: better way for cache mechanism - stack = get_stack_cache(f) - if stack is None or with_return: - if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun(f, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) - else: - f2, args, kwargs = f, args, kwargs - - with jax.ensure_compile_time_eval(): - with VariableStack() as stack: - rets = eval_shape(f2, *args, **kwargs) - cache_stack(f, stack) # cache - del args, kwargs, f2 - if with_return: - return stack, rets - else: - return stack - return stack + # TODO: better way for cache mechanism + stack = get_stack_cache(f) + if stack is None or with_return: + if 
len(static_argnums) or len(static_argnames): + f2, args, kwargs = _partial_fun(f, args, kwargs, static_argnums=static_argnums, + static_argnames=static_argnames) + else: + f2, args, kwargs = f, args, kwargs + + with jax.ensure_compile_time_eval(): + with VariableStack() as stack: + rets = eval_shape(f2, *args, **kwargs) + cache_stack(f, stack) # cache + del args, kwargs, f2 + if with_return: + return stack, rets + else: + return stack + return stack def _partial_fun2( @@ -139,41 +140,41 @@ def _partial_fun2( static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = () ): - num_args = len(args) - - # arguments - static_args = dict() - dyn_args = [] - dyn_arg_ids = dict() - static_argnums = list(static_argnums) - dyn_i = 0 - for i in range(num_args): - if i in static_argnums: - static_argnums.remove(i) - static_args[i] = args[i] - else: - dyn_args.append(args[i]) - dyn_arg_ids[i] = dyn_i - dyn_i += 1 - if len(static_argnums) > 0: - raise ValueError(f"Invalid static_argnums: {static_argnums}") - - # keyword arguments - static_kwargs, dyn_kwargs = {}, {} - for k, arg in kwargs.items(): - if k in static_argnames: - static_kwargs[k] = arg - else: - dyn_kwargs[k] = arg - del args, kwargs, static_argnums, static_argnames + num_args = len(args) + + # arguments + static_args = dict() + dyn_args = [] + dyn_arg_ids = dict() + static_argnums = list(static_argnums) + dyn_i = 0 + for i in range(num_args): + if i in static_argnums: + static_argnums.remove(i) + static_args[i] = args[i] + else: + dyn_args.append(args[i]) + dyn_arg_ids[i] = dyn_i + dyn_i += 1 + if len(static_argnums) > 0: + raise ValueError(f"Invalid static_argnums: {static_argnums}") + + # keyword arguments + static_kwargs, dyn_kwargs = {}, {} + for k, arg in kwargs.items(): + if k in static_argnames: + static_kwargs[k] = arg + else: + dyn_kwargs[k] = arg + del args, kwargs, static_argnums, static_argnames - @wraps(fun) - def new_fun(*dynargs, **dynkwargs): - return fun(*[dynargs[dyn_arg_ids[id_]] if id_ 
in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], - **static_kwargs, - **dynkwargs) + @wraps(fun) + def new_fun(*dynargs, **dynkwargs): + return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], + **static_kwargs, + **dynkwargs) - return new_fun, dyn_args, dyn_kwargs + return new_fun, dyn_args, dyn_kwargs def eval_shape( @@ -184,45 +185,45 @@ def eval_shape( with_stack: bool = False, **kwargs ): - """Compute the shape/dtype of ``fun`` without any FLOPs. - - Args: - fun: The callable function. - *args: The positional arguments. - **kwargs: The keyword arguments. - with_stack: Whether evaluate the function within a local variable stack. - static_argnums: The static argument indices. - static_argnames: The static argument names. - - Returns: - The variable stack and the functional returns. - """ - # reorganize the function - if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) - else: - f2 = fun - - # evaluate the function - fun_in_eval_shape.append(fun) - try: - if with_stack: - with VariableStack() as stack: - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) + """Compute the shape/dtype of ``fun`` without any FLOPs. + + Args: + fun: The callable function. + *args: The positional arguments. + **kwargs: The keyword arguments. + with_stack: Whether evaluate the function within a local variable stack. + static_argnums: The static argument indices. + static_argnames: The static argument names. + + Returns: + The variable stack and the functional returns. 
+ """ + # reorganize the function + if len(static_argnums) or len(static_argnames): + f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, + static_argnames=static_argnames) + else: + f2 = fun + + # evaluate the function + fun_in_eval_shape.append(fun) + try: + if with_stack: + with VariableStack() as stack: + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) else: - returns = jax.eval_shape(f2, *args, **kwargs) + stack = None + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) + finally: + fun_in_eval_shape.pop() + del f2 + if with_stack: + return stack, returns else: - stack = None - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) - else: - returns = jax.eval_shape(f2, *args, **kwargs) - finally: - fun_in_eval_shape.pop() - del f2 - if with_stack: - return stack, returns - else: - return returns - + return returns diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index f7b147cc7..17717dd7a 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,223 +1,221 @@ from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple +import brainstate import jax import numpy as np +from brainstate._state import record_state_value_read, record_state_value_write from jax import numpy as jnp from jax.dtypes import canonicalize_dtype from jax.tree_util import register_pytree_node_class -from brainpy._src.math.ndarray import Array, BaseArray +from brainpy._src.math.ndarray import BaseArray from brainpy._src.math.sharding import BATCH_AXIS from brainpy.errors import MathError -import brainstate -from brainstate._state import record_state_value_read, record_state_value_write - - __all__ = [ - 'Variable', - 'TrainVar', - 'Parameter', - 'VariableView', + 'Variable', 
+ 'TrainVar', + 'Parameter', + 'VariableView', - 'VarList', 'var_list', - 'VarDict', 'var_dict', + 'VarList', 'var_list', + 'VarDict', 'var_dict', ] class VariableStack(dict): - """Variable stack, for collecting all :py:class:`~.Variable` used in the program. - - :py:class:`~.VariableStack` supports all features of python dict. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._values = dict() - - def add(self, var: 'Variable'): - """Add a new :py:class:`~.Variable`.""" - assert isinstance(var, Variable), f'must be instance of {Variable}' - id_ = id(var) - if id_ not in self: - self[id_] = var - self._values[id_] = var._value - - def collect_values(self): - """Collect the value of each variable once again.""" - for id_, var in self.items(): - self._values[id_] = var._value - - def assign_org_values(self): - """Assign the original value for each variable.""" - for id_, var in self.items(): - if id_ in self._values: - var._value = self._values[id_] + """Variable stack, for collecting all :py:class:`~.Variable` used in the program. - def assign(self, data: Union[Dict, Sequence], check: bool = True): - """Assign the value for each :math:`~.Variable` according to the given ``data``. - - Args: - data: dict, list, tuple. The data of all variables - check: bool. Check whether the shape and type of the given data are consistent with original data. + :py:class:`~.VariableStack` supports all features of python dict. """ - if isinstance(data, dict): - assert len(data) == len(self), 'Data length mismatch. 
' - if check: + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._values = dict() + + def add(self, var: 'Variable'): + """Add a new :py:class:`~.Variable`.""" + assert isinstance(var, Variable), f'must be instance of {Variable}' + id_ = id(var) + if id_ not in self: + self[id_] = var + self._values[id_] = var._value + + def collect_values(self): + """Collect the value of each variable once again.""" + for id_, var in self.items(): + self._values[id_] = var._value + + def assign_org_values(self): + """Assign the original value for each variable.""" + for id_, var in self.items(): + if id_ in self._values: + var._value = self._values[id_] + + def assign(self, data: Union[Dict, Sequence], check: bool = True): + """Assign the value for each :math:`~.Variable` according to the given ``data``. + + Args: + data: dict, list, tuple. The data of all variables + check: bool. Check whether the shape and type of the given data are consistent with original data. + """ + if isinstance(data, dict): + assert len(data) == len(self), 'Data length mismatch. ' + if check: + for id_, elem in self.items(): + elem.value = data[id_] + else: + for id_, elem in self.items(): + elem._value = data[id_] + elif isinstance(data, (tuple, list)): + assert len(data) == len(self), 'Data length mismatch. ' + if check: + for i, elem in enumerate(self.values()): + elem.value = data[i] + else: + for i, elem in enumerate(self.values()): + elem._value = data[i] + else: + raise TypeError + + def call_on_subset(self, cond: Callable, call: Callable) -> dict: + """Call a function on the subset of this :py:class:`~VariableStack`. 
+ + >>> import brainpy.math as bm + >>> stack = VariableStack(a=bm.Variable(1), b=bm.random.RandomState(1)) + >>> stack.call_on_subset(lambda a: isinstance(a, bm.random.RandomState), + >>> lambda a: a.split_key()) + {'b': Array([3819641963, 2025898573], dtype=uint32)} + + Args: + cond: The function to determine whether the element belongs to the wanted subset. + call: The function to call if the element belongs to the wanted subset. + + Returns: + A dict containing the results of ``call`` function for each element in the ``cond`` constrained subset. + """ + res = dict() for id_, elem in self.items(): - elem.value = data[id_] - else: + if cond(elem): + res[id_] = call(elem) + return res + + def separate_by_instance(self, cls: type) -> Tuple['VariableStack', 'VariableStack']: + """Separate all variables into two groups: (variables that are instances of the given ``cls``, + variables that are not instances of the given ``cls``). + + >>> import brainpy.math as bm + >>> stack = VariableStack(a=bm.Variable(1), b=bm.random.RandomState(1)) + >>> stack.separate_by_instance(bm.random.RandomState) + ({'b': RandomState(key=([0, 1], dtype=uint32))}, + {'a': Variable(value=Array([0.]), dtype=float32)}) + >>> stack.separate_by_instance(bm.Variable) + ({'a': Variable(value=Array([0.]), dtype=float32), + 'b': RandomState(key=([0, 1], dtype=uint32))}, + {}) + + Args: + cls: The class type. + + Returns: + A tuple with two elements: + + - VariableStack of variables that are instances of the given ``cls`` + - VariableStack of variables that are not instances of the given ``cls`` + """ + is_instances = type(self)() + not_instances = type(self)() for id_, elem in self.items(): - elem._value = data[id_] - elif isinstance(data, (tuple, list)): - assert len(data) == len(self), 'Data length mismatch. 
' - if check: - for i, elem in enumerate(self.values()): - elem.value = data[i] - else: - for i, elem in enumerate(self.values()): - elem._value = data[i] - else: - raise TypeError - - def call_on_subset(self, cond: Callable, call: Callable) -> dict: - """Call a function on the subset of this :py:class:`~VariableStack`. - - >>> import brainpy.math as bm - >>> stack = VariableStack(a=bm.Variable(1), b=bm.random.RandomState(1)) - >>> stack.call_on_subset(lambda a: isinstance(a, bm.random.RandomState), - >>> lambda a: a.split_key()) - {'b': Array([3819641963, 2025898573], dtype=uint32)} - - Args: - cond: The function to determine whether the element belongs to the wanted subset. - call: The function to call if the element belongs to the wanted subset. - - Returns: - A dict containing the results of ``call`` function for each element in the ``cond`` constrained subset. - """ - res = dict() - for id_, elem in self.items(): - if cond(elem): - res[id_] = call(elem) - return res - - def separate_by_instance(self, cls: type) -> Tuple['VariableStack', 'VariableStack']: - """Separate all variables into two groups: (variables that are instances of the given ``cls``, - variables that are not instances of the given ``cls``). - - >>> import brainpy.math as bm - >>> stack = VariableStack(a=bm.Variable(1), b=bm.random.RandomState(1)) - >>> stack.separate_by_instance(bm.random.RandomState) - ({'b': RandomState(key=([0, 1], dtype=uint32))}, - {'a': Variable(value=Array([0.]), dtype=float32)}) - >>> stack.separate_by_instance(bm.Variable) - ({'a': Variable(value=Array([0.]), dtype=float32), - 'b': RandomState(key=([0, 1], dtype=uint32))}, - {}) - - Args: - cls: The class type. 
- - Returns: - A tuple with two elements: - - - VariableStack of variables that are instances of the given ``cls`` - - VariableStack of variables that are not instances of the given ``cls`` - """ - is_instances = type(self)() - not_instances = type(self)() - for id_, elem in self.items(): - if isinstance(elem, cls): - is_instances[id_] = elem - else: - not_instances[id_] = elem - return is_instances, not_instances - - def subset_by_instance(self, cls: type) -> 'VariableStack': - """Collect all variables which are instances of the given class type.""" - new_dict = type(self)() - for id_, elem in self.items(): - if isinstance(elem, cls): - new_dict[id_] = elem - return new_dict - - def subset_by_not_instance(self, cls: type) -> 'VariableStack': - """Collect all variables which are not instance of the given class type.""" - new_dict = type(self)() - for id_, elem in self.items(): - if not isinstance(elem, cls): - new_dict[id_] = elem - return new_dict - - instance_of = subset_by_instance - not_instance_of = subset_by_not_instance - - def dict_data_of_subset(self, subset_cond: Callable) -> dict: - """Get data of the given subset constrained by function ``subset_cond``. - - Args: - subset_cond: A function to determine whether the element is in the subset wanted. - - Returns: - A dict of data for elements of the wanted subset. 
- """ - res = dict() - for id_, elem in self.items(): - if subset_cond(elem): - res[id_] = elem.value - return res - - def dict_data(self) -> dict: - """Get all data in the collected variables with a python dict structure.""" - new_dict = dict() - for id_, elem in tuple(self.items()): - new_dict[id_] = elem.value - return new_dict - - def list_data(self) -> list: - """Get all data in the collected variables with a python list structure.""" - new_list = list() - for elem in tuple(self.values()): - new_list.append(elem.value if isinstance(elem, BaseArray) else elem) - return new_list - - def remove_by_id(self, *ids, error_when_absent=False): - """Remove or pop variables in the stack by the given ids.""" - if error_when_absent: - for id_ in ids: - self.pop(id_) - else: - for id_ in ids: - self.pop(id_, None) - - remove_var_by_id = remove_by_id + if isinstance(elem, cls): + is_instances[id_] = elem + else: + not_instances[id_] = elem + return is_instances, not_instances + + def subset_by_instance(self, cls: type) -> 'VariableStack': + """Collect all variables which are instances of the given class type.""" + new_dict = type(self)() + for id_, elem in self.items(): + if isinstance(elem, cls): + new_dict[id_] = elem + return new_dict - @classmethod - def num_of_stack(self): - return len(var_stack_list) + def subset_by_not_instance(self, cls: type) -> 'VariableStack': + """Collect all variables which are not instance of the given class type.""" + new_dict = type(self)() + for id_, elem in self.items(): + if not isinstance(elem, cls): + new_dict[id_] = elem + return new_dict - @classmethod - def is_first_stack(self): - return len(var_stack_list) == 0 + instance_of = subset_by_instance + not_instance_of = subset_by_not_instance - def __enter__(self) -> 'VariableStack': - self.collect_values() # recollect the original value of each variable - var_stack_list.append(self) - return self + def dict_data_of_subset(self, subset_cond: Callable) -> dict: + """Get data of the given 
subset constrained by function ``subset_cond``. - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - var_stack_list.pop() - self.assign_org_values() # reassign the original value for each variable - self._values.clear() + Args: + subset_cond: A function to determine whether the element is in the subset wanted. - def __add__(self, other: dict): - new_dict = VariableStack(self) - new_dict.update(other) - new_dict._values.update(self._values) - if isinstance(other, VariableStack): - new_dict._values.update(other._values) - return new_dict + Returns: + A dict of data for elements of the wanted subset. + """ + res = dict() + for id_, elem in self.items(): + if subset_cond(elem): + res[id_] = elem.value + return res + + def dict_data(self) -> dict: + """Get all data in the collected variables with a python dict structure.""" + new_dict = dict() + for id_, elem in tuple(self.items()): + new_dict[id_] = elem.value + return new_dict + + def list_data(self) -> list: + """Get all data in the collected variables with a python list structure.""" + new_list = list() + for elem in tuple(self.values()): + new_list.append(elem.value if isinstance(elem, BaseArray) else elem) + return new_list + + def remove_by_id(self, *ids, error_when_absent=False): + """Remove or pop variables in the stack by the given ids.""" + if error_when_absent: + for id_ in ids: + self.pop(id_) + else: + for id_ in ids: + self.pop(id_, None) + + remove_var_by_id = remove_by_id + + @classmethod + def num_of_stack(self): + return len(var_stack_list) + + @classmethod + def is_first_stack(self): + return len(var_stack_list) == 0 + + def __enter__(self) -> 'VariableStack': + self.collect_values() # recollect the original value of each variable + var_stack_list.append(self) + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + var_stack_list.pop() + self.assign_org_values() # reassign the original value for each variable + self._values.clear() + + 
def __add__(self, other: dict): + new_dict = VariableStack(self) + new_dict.update(other) + new_dict._values.update(self._values) + if isinstance(other, VariableStack): + new_dict._values.update(other._values) + return new_dict var_stack_list: List[VariableStack] = [] @@ -225,379 +223,378 @@ def __add__(self, other: dict): @register_pytree_node_class class Variable(brainstate.State, BaseArray): - """The pointer to specify the dynamical variable. - - Initializing an instance of ``Variable`` by two ways: - - >>> import brainpy.math as bm - >>> # 1. init a Variable by the concreate data - >>> v1 = bm.Variable(bm.zeros(10)) - >>> # 2. init a Variable by the data shape - >>> v2 = bm.Variable(10) - - Note that when initializing a `Variable` by the data shape, - all values in this `Variable` will be initialized as zeros. - - Args: - value_or_size: Shape, Array, int. The value or the size of the value. - dtype: Any. The type of the data. - batch_axis: optional, int. The batch axis. - axis_names: sequence of str. The name for each axis. - """ - - def __init__( - self, - value_or_size: Any, - dtype: type = None, - batch_axis: int = None, - *, - axis_names: Optional[Sequence[str]] = None, - ): - if isinstance(value_or_size, int): - value = jnp.zeros(value_or_size, dtype=dtype) - elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]): - value = jnp.zeros(value_or_size, dtype=dtype) - else: - value = value_or_size - - if isinstance(value, BaseArray): - value = value.value - super().__init__(value) - - # check batch axis - if isinstance(value, Variable): - if value.batch_axis is not None and batch_axis is not None: - if batch_axis != value.batch_axis: - raise ValueError(f'"batch_axis" is not consistent. 
Got batch_axis in the given value ' - f'is {value.batch_axis}, but the specified batch_axis is {batch_axis}') - batch_axis = value.batch_axis - - # assign batch axis - self._batch_axis = batch_axis - if batch_axis is not None: - if batch_axis >= np.ndim(self._value): - raise MathError(f'This variables has {np.ndim(self._value)} dimension, ' - f'but the batch axis is set to be {batch_axis}.') - - # ready to trace the variable - if axis_names is not None: - if len(axis_names) + 1 == self.ndim: - axis_names = list(axis_names) - axis_names.insert(self.batch_axis, BATCH_AXIS) - assert len(axis_names) == self.ndim - axis_names = tuple(axis_names) - self.axis_names = axis_names - - @property - def size_without_batch(self): - if self.batch_axis is None: - return self.size - else: - sizes = self.size - return sizes[:self.batch_size] + sizes[self.batch_axis + 1:] + """The pointer to specify the dynamical variable. - @property - def batch_axis(self) -> Optional[int]: - return self._batch_axis + Initializing an instance of ``Variable`` by two ways: - @batch_axis.setter - def batch_axis(self, val): - raise ValueError(f'Cannot set "batch_axis" after creating a {self.__class__.__name__} instance.') - - @property - def batch_size(self) -> Optional[int]: - if self.batch_axis is None: - return None - else: - return self.shape[self.batch_axis] - - @batch_size.setter - def batch_size(self, val): - raise ValueError(f'Cannot set "batch_size" manually.') - - @property - def value(self): - record_state_value_read(self) - return self._read_value() - - @value.setter - def value(self, v): - _value = self.value - ext_shape = jnp.shape(v) - int_shape = jnp.shape(_value) - if self._batch_axis is not None: - ext_shape = ext_shape[:self._batch_axis] + ext_shape[self._batch_axis + 1:] - int_shape = int_shape[:self._batch_axis] + int_shape[self._batch_axis + 1:] - if ext_shape != int_shape: - error = f"The shape of the original data is {int_shape}, while we got {ext_shape}" - error += f' with 
batch_axis={self._batch_axis}.' - raise MathError(error) - ext_dtype = _get_dtype(v) - int_dtype = self.dtype - if ext_dtype != int_dtype: - raise MathError(f"The dtype of the original data is {int_dtype}, " - f"while we got {ext_dtype}.") - if isinstance(v, BaseArray): - v = v.value - elif isinstance(v, np.ndarray): - v = jnp.asarray(v) - else: - v = v - - if isinstance(v, brainstate.State): # value checking - v = v.value - self._check_value_tree(v) # check the tree structure - record_state_value_write(self) # record the value by the stack (>= level) - self._been_writen = True # set the flag - self._write_value(v) # write the value - - def tree_flatten(self): - """Flattens this variable. - - Returns: - A pair where the first element is a list of leaf values - and the second element is a treedef representing the - structure of the flattened tree. - """ - return (self._value,), None + >>> import brainpy.math as bm + >>> # 1. init a Variable by the concreate data + >>> v1 = bm.Variable(bm.zeros(10)) + >>> # 2. init a Variable by the data shape + >>> v2 = bm.Variable(10) - @classmethod - def tree_unflatten(cls, aux_data, flat_contents): - """Reconstructs a variable from the aux_data and the leaves. + Note that when initializing a `Variable` by the data shape, + all values in this `Variable` will be initialized as zeros. Args: - aux_data: - flat_contents: - - Returns: - The variable. + value_or_size: Shape, Array, int. The value or the size of the value. + dtype: Any. The type of the data. + batch_axis: optional, int. The batch axis. + axis_names: sequence of str. The name for each axis. """ - return cls(*flat_contents) - - def clone(self) -> 'Variable': - """Clone the variable. 
""" - r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis) - return r - - def __eq__(self, other): - """Override State's __eq__ to use BaseArray behavior for element-wise comparison.""" - from brainpy._src.math.ndarray import _check_input_array, _return - return _return(self.value == _check_input_array(other)) - - def __ne__(self, other): - """Override State's __ne__ to use BaseArray behavior for element-wise comparison.""" - from brainpy._src.math.ndarray import _check_input_array, _return - return _return(self.value != _check_input_array(other)) - - def __lt__(self, other): - """Override State's __lt__ to use BaseArray behavior for element-wise comparison.""" - from brainpy._src.math.ndarray import _check_input_array, _return - return _return(self.value < _check_input_array(other)) - - def __le__(self, other): - """Override State's __le__ to use BaseArray behavior for element-wise comparison.""" - from brainpy._src.math.ndarray import _check_input_array, _return - return _return(self.value <= _check_input_array(other)) - - def __gt__(self, other): - """Override State's __gt__ to use BaseArray behavior for element-wise comparison.""" - from brainpy._src.math.ndarray import _check_input_array, _return - return _return(self.value > _check_input_array(other)) - - def __ge__(self, other): - """Override State's __ge__ to use BaseArray behavior for element-wise comparison.""" - from brainpy._src.math.ndarray import _check_input_array, _return - return _return(self.value >= _check_input_array(other)) + def __init__( + self, + value_or_size: Any, + dtype: type = None, + batch_axis: int = None, + *, + axis_names: Optional[Sequence[str]] = None, + ): + if isinstance(value_or_size, int): + value = jnp.zeros(value_or_size, dtype=dtype) + elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]): + value = jnp.zeros(value_or_size, dtype=dtype) + else: + value = value_or_size + + if isinstance(value, BaseArray): + 
value = value.value + super().__init__(value) + + # check batch axis + if isinstance(value, Variable): + if value.batch_axis is not None and batch_axis is not None: + if batch_axis != value.batch_axis: + raise ValueError(f'"batch_axis" is not consistent. Got batch_axis in the given value ' + f'is {value.batch_axis}, but the specified batch_axis is {batch_axis}') + batch_axis = value.batch_axis + + # assign batch axis + self._batch_axis = batch_axis + if batch_axis is not None: + if batch_axis >= np.ndim(self._value): + raise MathError(f'This variables has {np.ndim(self._value)} dimension, ' + f'but the batch axis is set to be {batch_axis}.') + + # ready to trace the variable + if axis_names is not None: + if len(axis_names) + 1 == self.ndim: + axis_names = list(axis_names) + axis_names.insert(self.batch_axis, BATCH_AXIS) + assert len(axis_names) == self.ndim + axis_names = tuple(axis_names) + self.axis_names = axis_names + + @property + def size_without_batch(self): + if self.batch_axis is None: + return self.size + else: + sizes = self.size + return sizes[:self.batch_size] + sizes[self.batch_axis + 1:] + + @property + def batch_axis(self) -> Optional[int]: + return self._batch_axis + + @batch_axis.setter + def batch_axis(self, val): + raise ValueError(f'Cannot set "batch_axis" after creating a {self.__class__.__name__} instance.') + + @property + def batch_size(self) -> Optional[int]: + if self.batch_axis is None: + return None + else: + return self.shape[self.batch_axis] + + @batch_size.setter + def batch_size(self, val): + raise ValueError(f'Cannot set "batch_size" manually.') + + @property + def value(self): + record_state_value_read(self) + return self._read_value() + + @value.setter + def value(self, v): + _value = self.value + ext_shape = jnp.shape(v) + int_shape = jnp.shape(_value) + if self._batch_axis is not None: + ext_shape = ext_shape[:self._batch_axis] + ext_shape[self._batch_axis + 1:] + int_shape = int_shape[:self._batch_axis] + 
int_shape[self._batch_axis + 1:] + if ext_shape != int_shape: + error = f"The shape of the original data is {int_shape}, while we got {ext_shape}" + error += f' with batch_axis={self._batch_axis}.' + raise MathError(error) + ext_dtype = _get_dtype(v) + int_dtype = self.dtype + if ext_dtype != int_dtype: + raise MathError(f"The dtype of the original data is {int_dtype}, " + f"while we got {ext_dtype}.") + if isinstance(v, BaseArray): + v = v.value + elif isinstance(v, np.ndarray): + v = jnp.asarray(v) + else: + v = v + + if isinstance(v, brainstate.State): # value checking + v = v.value + self._check_value_tree(v) # check the tree structure + record_state_value_write(self) # record the value by the stack (>= level) + self._been_writen = True # set the flag + self._write_value(v) # write the value + + def tree_flatten(self): + """Flattens this variable. + + Returns: + A pair where the first element is a list of leaf values + and the second element is a treedef representing the + structure of the flattened tree. + """ + return (self._value,), None + + @classmethod + def tree_unflatten(cls, aux_data, flat_contents): + """Reconstructs a variable from the aux_data and the leaves. + + Args: + aux_data: + flat_contents: + + Returns: + The variable. + """ + return cls(*flat_contents) + + def clone(self) -> 'Variable': + """Clone the variable. 
""" + r = type(self)(jnp.array(self.value, copy=True), batch_axis=self.batch_axis) + return r + + def __eq__(self, other): + """Override State's __eq__ to use BaseArray behavior for element-wise comparison.""" + from brainpy._src.math.ndarray import _check_input_array, _return + return _return(self.value == _check_input_array(other)) + + def __ne__(self, other): + """Override State's __ne__ to use BaseArray behavior for element-wise comparison.""" + from brainpy._src.math.ndarray import _check_input_array, _return + return _return(self.value != _check_input_array(other)) + + def __lt__(self, other): + """Override State's __lt__ to use BaseArray behavior for element-wise comparison.""" + from brainpy._src.math.ndarray import _check_input_array, _return + return _return(self.value < _check_input_array(other)) + + def __le__(self, other): + """Override State's __le__ to use BaseArray behavior for element-wise comparison.""" + from brainpy._src.math.ndarray import _check_input_array, _return + return _return(self.value <= _check_input_array(other)) + + def __gt__(self, other): + """Override State's __gt__ to use BaseArray behavior for element-wise comparison.""" + from brainpy._src.math.ndarray import _check_input_array, _return + return _return(self.value > _check_input_array(other)) + + def __ge__(self, other): + """Override State's __ge__ to use BaseArray behavior for element-wise comparison.""" + from brainpy._src.math.ndarray import _check_input_array, _return + return _return(self.value >= _check_input_array(other)) def _get_dtype(v): - if hasattr(v, 'dtype'): - dtype = v.dtype - else: - dtype = canonicalize_dtype(type(v)) - return dtype + if hasattr(v, 'dtype'): + dtype = v.dtype + else: + dtype = canonicalize_dtype(type(v)) + return dtype def _as_jax_array_(obj): - return obj.value if isinstance(obj, BaseArray) else obj + return obj.value if isinstance(obj, BaseArray) else obj @register_pytree_node_class class TrainVar(Variable): - """The pointer to specify the 
trainable variable. - """ - - def __init__( - self, - value_or_size: Any, - dtype: type = None, - batch_axis: int = None, - *, - axis_names: Optional[Sequence[str]] = None, - ): - super().__init__( - value_or_size, - dtype=dtype, - batch_axis=batch_axis, - axis_names=axis_names, - ) + """The pointer to specify the trainable variable. + """ + + def __init__( + self, + value_or_size: Any, + dtype: type = None, + batch_axis: int = None, + *, + axis_names: Optional[Sequence[str]] = None, + ): + super().__init__( + value_or_size, + dtype=dtype, + batch_axis=batch_axis, + axis_names=axis_names, + ) @register_pytree_node_class class Parameter(Variable): - """The pointer to specify the parameter. - """ - - def __init__( - self, - value_or_size: Any, - dtype: type = None, - batch_axis: int = None, - *, - axis_names: Optional[Sequence[str]] = None, - ): - super().__init__( - value_or_size, - dtype=dtype, - batch_axis=batch_axis, - axis_names=axis_names, - ) - - -class VariableView(Variable): - """A view of a Variable instance. - - This class is used to create a subset view of ``brainpy.math.Variable``. - - >>> import brainpy.math as bm - >>> bm.random.seed(123) - >>> origin = bm.Variable(bm.random.random(5)) - >>> view = bm.VariableView(origin, slice(None, 2, None)) # origin[:2] - VariableView([0.02920651, 0.19066381], dtype=float32) - - ``VariableView`` can be used to update the subset of the original - Variable instance, and make operations on this subset of the Variable. - - >>> view[:] = 1. - >>> view - VariableView([1., 1.], dtype=float32) - >>> origin - Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32) - >>> view + 10 - Array([11., 11.], dtype=float32) - >>> view *= 10 - VariableView([10., 10.], dtype=float32) - - The above example demonstrates that the updating of an ``VariableView`` instance - is actually made in the original ``Variable`` instance. - - Moreover, it's worthy to note that ``VariableView`` is not a PyTree. 
- """ - _need_record = False - - def __init__( - self, - value: Variable, - index: Any, - ): - self.index = jax.tree_util.tree_map(_as_jax_array_, index, is_leaf=lambda a: isinstance(a, BaseArray)) - if not isinstance(value, Variable): - raise ValueError('Must be instance of Variable.') - super().__init__(value.value, batch_axis=value.batch_axis) - self._value = value - - def __repr__(self) -> str: - print_code = repr(self._value) - prefix = f'{self.__class__.__name__}' - blank = " " * (len(prefix) + 1) - lines = print_code.split("\n") - lines[0] = prefix + "(" + lines[0] - for i in range(1, len(lines)): - lines[i] = blank + lines[i] - lines[-1] += "," - lines.append(blank + f'index={self.index})') - print_code = "\n".join(lines) - return print_code - - @property - def value(self): - return self._value[self.index] - - @value.setter - def value(self, v): - int_shape = self.shape - if self.batch_axis is None: - ext_shape = v.shape - else: - ext_shape = v.shape[:self.batch_axis] + v.shape[self.batch_axis + 1:] - int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:] - if ext_shape != int_shape: - error = f"The shape of the original data is {self.shape}, while we got {v.shape}" - if self.batch_axis is None: - error += '. Do you forget to set "batch_axis" when initialize this variable?' - else: - error += f' with batch_axis={self.batch_axis}.' - raise MathError(error) - if v.dtype != self._value.dtype: - raise MathError(f"The dtype of the original data is {self._value.dtype}, " - f"while we got {v.dtype}.") - self._value[self.index] = v.value if isinstance(v, BaseArray) else v - + """The pointer to specify the parameter. + """ -@register_pytree_node_class -class VarList(list): - """A sequence of :py:class:`~.Variable`, which is compatible with - :py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`. 
+ def __init__( + self, + value_or_size: Any, + dtype: type = None, + batch_axis: int = None, + *, + axis_names: Optional[Sequence[str]] = None, + ): + super().__init__( + value_or_size, + dtype=dtype, + batch_axis=batch_axis, + axis_names=axis_names, + ) - Actually, :py:class:`~.VarList` is a python list. - :py:class:`~.VarList` is specifically designed to store Variable instances. +class VariableView(Variable): + """A view of a Variable instance. - """ + This class is used to create a subset view of ``brainpy.math.Variable``. - def __init__(self, seq=()): - super().__init__() - self.extend(seq) + >>> import brainpy.math as bm + >>> bm.random.seed(123) + >>> origin = bm.Variable(bm.random.random(5)) + >>> view = bm.VariableView(origin, slice(None, 2, None)) # origin[:2] + VariableView([0.02920651, 0.19066381], dtype=float32) + + ``VariableView`` can be used to update the subset of the original + Variable instance, and make operations on this subset of the Variable. + + >>> view[:] = 1. + >>> view + VariableView([1., 1.], dtype=float32) + >>> origin + Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32) + >>> view + 10 + Array([11., 11.], dtype=float32) + >>> view *= 10 + VariableView([10., 10.], dtype=float32) + + The above example demonstrates that the updating of an ``VariableView`` instance + is actually made in the original ``Variable`` instance. + + Moreover, it's worthy to note that ``VariableView`` is not a PyTree. 
+ """ + _need_record = False + + def __init__( + self, + value: Variable, + index: Any, + ): + self.index = jax.tree_util.tree_map(_as_jax_array_, index, is_leaf=lambda a: isinstance(a, BaseArray)) + if not isinstance(value, Variable): + raise ValueError('Must be instance of Variable.') + super().__init__(value.value, batch_axis=value.batch_axis) + self._value = value + + def __repr__(self) -> str: + print_code = repr(self._value) + prefix = f'{self.__class__.__name__}' + blank = " " * (len(prefix) + 1) + lines = print_code.split("\n") + lines[0] = prefix + "(" + lines[0] + for i in range(1, len(lines)): + lines[i] = blank + lines[i] + lines[-1] += "," + lines.append(blank + f'index={self.index})') + print_code = "\n".join(lines) + return print_code + + @property + def value(self): + return self._value[self.index] + + @value.setter + def value(self, v): + int_shape = self.shape + if self.batch_axis is None: + ext_shape = v.shape + else: + ext_shape = v.shape[:self.batch_axis] + v.shape[self.batch_axis + 1:] + int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:] + if ext_shape != int_shape: + error = f"The shape of the original data is {self.shape}, while we got {v.shape}" + if self.batch_axis is None: + error += '. Do you forget to set "batch_axis" when initialize this variable?' + else: + error += f' with batch_axis={self.batch_axis}.' 
+ raise MathError(error) + if v.dtype != self._value.dtype: + raise MathError(f"The dtype of the original data is {self._value.dtype}, " + f"while we got {v.dtype}.") + self._value[self.index] = v.value if isinstance(v, BaseArray) else v - def append(self, element) -> 'VarList': - if not isinstance(element, Variable): - raise TypeError(f'element must be an instance of {Variable.__name__}.') - super().append(element) - return self - def extend(self, iterable) -> 'VarList': - for element in iterable: - self.append(element) - return self +@register_pytree_node_class +class VarList(list): + """A sequence of :py:class:`~.Variable`, which is compatible with + :py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`. - def __setitem__(self, key, value) -> 'VarList': - """Override the item setting. + Actually, :py:class:`~.VarList` is a python list. - This function ensures that the Variable appended in the :py:class:`~.VarList` will not be overridden, - and only the value can be changed for each element. + :py:class:`~.VarList` is specifically designed to store Variable instances. 
- >>> import brainpy.math as bm - >>> l = bm.var_list([bm.Variable(1), bm.Variable(2)]) - >>> print(id(l[0]), id(l[1])) - 2077748389472 2077748389552 - >>> l[1] = bm.random.random(2) - >>> l[0] = bm.random.random(1) - >>> print(id(l[0]), id(l[1])) # still the original Variable instances - 2077748389472 2077748389552 """ - if isinstance(key, int): - self[key].value = value - else: - super().__setitem__(key, value) - return self - def tree_flatten(self): - return tuple(self), None - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(children) + def __init__(self, seq=()): + super().__init__() + self.extend(seq) + + def append(self, element) -> 'VarList': + if not isinstance(element, Variable): + raise TypeError(f'element must be an instance of {Variable.__name__}.') + super().append(element) + return self + + def extend(self, iterable) -> 'VarList': + for element in iterable: + self.append(element) + return self + + def __setitem__(self, key, value) -> 'VarList': + """Override the item setting. + + This function ensures that the Variable appended in the :py:class:`~.VarList` will not be overridden, + and only the value can be changed for each element. 
+ + >>> import brainpy.math as bm + >>> l = bm.var_list([bm.Variable(1), bm.Variable(2)]) + >>> print(id(l[0]), id(l[1])) + 2077748389472 2077748389552 + >>> l[1] = bm.random.random(2) + >>> l[0] = bm.random.random(1) + >>> print(id(l[0]), id(l[1])) # still the original Variable instances + 2077748389472 2077748389552 + """ + if isinstance(key, int): + self[key].value = value + else: + super().__setitem__(key, value) + return self + + def tree_flatten(self): + return tuple(self), None + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(children) var_list = VarList @@ -605,62 +602,62 @@ def tree_unflatten(cls, aux_data, children): @register_pytree_node_class class VarDict(dict): - """A dictionary of :py:class:`~.Variable`, which is compatible with - :py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`. - - Actually, :py:class:`~.VarDict` is a python dict. - - :py:class:`~.VarDict` is specifically designed to store Variable instances. - - """ - - def _check_elem(self, elem): - if not isinstance(elem, Variable): - raise TypeError(f'Element should be {Variable.__name__}, but got {type(elem)}.') - return elem + """A dictionary of :py:class:`~.Variable`, which is compatible with + :py:func:`.vars()` operation in a :py:class:`~.BrainPyObject`. - def __init__(self, *args, **kwargs): - super().__init__() - self.update(*args, **kwargs) + Actually, :py:class:`~.VarDict` is a python dict. - def update(self, *args, **kwargs) -> 'VarDict': - for arg in args: - if isinstance(arg, dict): - for k, v in arg.items(): - self[k] = v - elif isinstance(arg, tuple): - assert len(arg) == 2 - self[arg[0]] = args[1] - for k, v in kwargs.items(): - self[k] = v - return self + :py:class:`~.VarDict` is specifically designed to store Variable instances. - def __setitem__(self, key, value) -> 'VarDict': - """Override the item setting. - - This function ensures that the Variable appended in the :py:class:`~.VarList` will not be overridden. 
- - >>> import brainpy.math as bm - >>> d = bm.var_dict({'a': bm.Variable(1), 'b': bm.Variable(2)}) - >>> print(id(d['a']), id(d['b'])) - 2077667833504 2077748488176 - >>> d['b'] = bm.random.random(2) - >>> d['a'] = bm.random.random(1) - >>> print(id(d['a']), id(d['b'])) # still the original Variable instances - 2077667833504 2077748488176 """ - if key in self: - self[key].value = value - else: - super().__setitem__(key, self._check_elem(value)) - return self - - def tree_flatten(self): - return tuple(self.values()), tuple(self.keys()) - @classmethod - def tree_unflatten(cls, keys, values): - return cls(jax.util.safe_zip(keys, values)) + def _check_elem(self, elem): + if not isinstance(elem, Variable): + raise TypeError(f'Element should be {Variable.__name__}, but got {type(elem)}.') + return elem + + def __init__(self, *args, **kwargs): + super().__init__() + self.update(*args, **kwargs) + + def update(self, *args, **kwargs) -> 'VarDict': + for arg in args: + if isinstance(arg, dict): + for k, v in arg.items(): + self[k] = v + elif isinstance(arg, tuple): + assert len(arg) == 2 + self[arg[0]] = args[1] + for k, v in kwargs.items(): + self[k] = v + return self + + def __setitem__(self, key, value) -> 'VarDict': + """Override the item setting. + + This function ensures that the Variable appended in the :py:class:`~.VarList` will not be overridden. 
+ + >>> import brainpy.math as bm + >>> d = bm.var_dict({'a': bm.Variable(1), 'b': bm.Variable(2)}) + >>> print(id(d['a']), id(d['b'])) + 2077667833504 2077748488176 + >>> d['b'] = bm.random.random(2) + >>> d['a'] = bm.random.random(1) + >>> print(id(d['a']), id(d['b'])) # still the original Variable instances + 2077667833504 2077748488176 + """ + if key in self: + self[key].value = value + else: + super().__setitem__(key, self._check_elem(value)) + return self + + def tree_flatten(self): + return tuple(self.values()), tuple(self.keys()) + + @classmethod + def tree_unflatten(cls, keys, values): + return cls(jax.util.safe_zip(keys, values)) var_dict = VarDict diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index 1d99ae423..335293b7a 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp from jax.tree_util import tree_map -import numpy as np from brainpy import check, tools from .compat_numpy import fill_diagonal @@ -15,14 +14,14 @@ from .ndarray import Array, _return __all__ = [ - 'shared_args_over_time', - 'remove_diag', - 'clip_by_norm', - 'exprel', - 'is_float_type', - # 'reduce', - 'add_axis', - 'add_axes', + 'shared_args_over_time', + 'remove_diag', + 'clip_by_norm', + 'exprel', + 'is_float_type', + # 'reduce', + 'add_axis', + 'add_axes', ] @@ -31,116 +30,116 @@ def shared_args_over_time(num_step: Optional[int] = None, dt: Optional[float] = None, t0: float = 0., include_dt: bool = True): - """Form a shared argument over time for the inference of a :py:class:`~.DynamicalSystem`. - - Parameters:: - - num_step: int - The number of time step. Provide either ``duration`` or ``num_step``. - duration: float - The total duration. Provide either ``duration`` or ``num_step``. - dt: float - The duration for each time step. - t0: float - The start time. - include_dt: bool - Produce the time steps at every time step. 
- - Returns:: - - shared: DotDict - The shared arguments over the given time. - """ - dt = get_dt() if dt is None else dt - check.is_float(dt, 'dt', allow_none=False) - if duration is None: - check.is_integer(num_step, 'num_step', allow_none=False) - else: - check.is_float(duration, 'duration', allow_none=False) - num_step = int(duration / dt) - r = tools.DotDict(i=jnp.arange(num_step, dtype=get_int())) - r['t'] = r['i'] * dt + t0 - if include_dt: - r['dt'] = jnp.ones_like(r['t']) * dt - return r + """Form a shared argument over time for the inference of a :py:class:`~.DynamicalSystem`. + + Parameters:: + + num_step: int + The number of time step. Provide either ``duration`` or ``num_step``. + duration: float + The total duration. Provide either ``duration`` or ``num_step``. + dt: float + The duration for each time step. + t0: float + The start time. + include_dt: bool + Produce the time steps at every time step. + + Returns:: + + shared: DotDict + The shared arguments over the given time. + """ + dt = get_dt() if dt is None else dt + check.is_float(dt, 'dt', allow_none=False) + if duration is None: + check.is_integer(num_step, 'num_step', allow_none=False) + else: + check.is_float(duration, 'duration', allow_none=False) + num_step = int(duration / dt) + r = tools.DotDict(i=jnp.arange(num_step, dtype=get_int())) + r['t'] = r['i'] * dt + t0 + if include_dt: + r['dt'] = jnp.ones_like(r['t']) * dt + return r def remove_diag(arr): - """Remove the diagonal of the matrix. + """Remove the diagonal of the matrix. - Parameters:: + Parameters:: - arr: ArrayType - The matrix with the shape of `(M, N)`. + arr: ArrayType + The matrix with the shape of `(M, N)`. - Returns:: + Returns:: - arr: Array - The matrix without diagonal which has the shape of `(M, N-1)`. 
- """ - if arr.ndim != 2: - raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') - eyes = _return(jnp.ones(arr.shape, dtype=bool)) - fill_diagonal(eyes, False) - return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) + arr: Array + The matrix without diagonal which has the shape of `(M, N-1)`. + """ + if arr.ndim != 2: + raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') + eyes = _return(jnp.ones(arr.shape, dtype=bool)) + fill_diagonal(eyes, False) + return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) def clip_by_norm(t, clip_norm, axis=None): - def f(l): - return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm) + def f(l): + return l * clip_norm / jnp.maximum(jnp.sqrt(jnp.sum(l * l, axis=axis, keepdims=True)), clip_norm) - return tree_map(f, t) + return tree_map(f, t) def _exprel(x, threshold): - def true_f(x): - x2 = x * x - return 1. + x / 2. + x2 / 6. + x2 * x / 24.0 # + x2 * x2 / 120. + def true_f(x): + x2 = x * x + return 1. + x / 2. + x2 / 6. + x2 * x / 24.0 # + x2 * x2 / 120. - def false_f(x): - return (jnp.exp(x) - 1) / x + def false_f(x): + return (jnp.exp(x) - 1) / x - # return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x) - # return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) - return jax.lax.select(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) + # return jax.lax.cond(jnp.abs(x) < threshold, true_f, false_f, x) + # return jnp.where(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) + return jax.lax.select(jnp.abs(x) <= threshold, 1. + x / 2. + x * x / 6., (jnp.exp(x) - 1) / x) def exprel(x, threshold: float = None): - """Relative error exponential, ``(exp(x) - 1)/x``. - - When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can - suffer from catastrophic loss of precision. 
``exprel(x)`` is implemented to avoid the loss of - precision that occurs when ``x`` is near zero. - - Args: - x: ndarray. Input array. ``x`` must contain real numbers. - threshold: float. - - Returns: - ``(exp(x) - 1)/x``, computed element-wise. - """ - x = as_jax(x) - if threshold is None: - if hasattr(x, 'dtype') and x.dtype == jnp.float64: - threshold = 1e-8 - else: - threshold = 1e-5 - return _exprel(x, threshold) + """Relative error exponential, ``(exp(x) - 1)/x``. + + When ``x`` is near zero, ``exp(x)`` is near 1, so the numerical calculation of ``exp(x) - 1`` can + suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of + precision that occurs when ``x`` is near zero. + + Args: + x: ndarray. Input array. ``x`` must contain real numbers. + threshold: float. + + Returns: + ``(exp(x) - 1)/x``, computed element-wise. + """ + x = as_jax(x) + if threshold is None: + if hasattr(x, 'dtype') and x.dtype == jnp.float64: + threshold = 1e-8 + else: + threshold = 1e-5 + return _exprel(x, threshold) def is_float_type(x: Union[Array, jax.Array]): - return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") + return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16") def add_axis(x: Union[Array, jax.Array], new_position: int): - x = as_jax(x) - return jnp.expand_dims(x, new_position) + x = as_jax(x) + return jnp.expand_dims(x, new_position) def add_axes(x: Union[Array, jax.Array], n_axes, pos2len): - x = as_jax(x) - repeats = [1] * n_axes - for axis_position, axis_length in pos2len.items(): - x = add_axis(x, axis_position) - repeats[axis_position] = axis_length - return jnp.tile(x, repeats) + x = as_jax(x) + repeats = [1] * n_axes + for axis_position, axis_length in pos2len.items(): + x = add_axis(x, axis_position) + repeats[axis_position] = axis_length + return jnp.tile(x, repeats) diff --git a/brainpy/_src/math/pre_syn_post.py b/brainpy/_src/math/pre_syn_post.py index 4239ad345..ee4e1138c 100644 --- 
a/brainpy/_src/math/pre_syn_post.py +++ b/brainpy/_src/math/pre_syn_post.py @@ -4,323 +4,323 @@ import jax.numpy as jnp from jax import vmap, jit, ops as jops -from brainpy._src.math.interoperability import as_jax from brainpy._src.math import event +from brainpy._src.math.interoperability import as_jax from brainpy.errors import MathError __all__ = [ - # pre-to-post - 'pre2post_sum', - 'pre2post_prod', - 'pre2post_max', - 'pre2post_min', - 'pre2post_mean', - - # pre-to-post event operator - 'pre2post_event_sum', - 'pre2post_csr_event_sum', - - # pre-to-syn - 'pre2syn', - - # syn-to-post - 'syn2post_sum', 'syn2post', - 'syn2post_prod', - 'syn2post_max', - 'syn2post_min', - 'syn2post_mean', - 'syn2post_softmax', + # pre-to-post + 'pre2post_sum', + 'pre2post_prod', + 'pre2post_max', + 'pre2post_min', + 'pre2post_mean', + + # pre-to-post event operator + 'pre2post_event_sum', + 'pre2post_csr_event_sum', + + # pre-to-syn + 'pre2syn', + + # syn-to-post + 'syn2post_sum', 'syn2post', + 'syn2post_prod', + 'syn2post_max', + 'syn2post_min', + 'syn2post_mean', + 'syn2post_softmax', ] def _raise_pre_ids_is_none(pre_ids): - if pre_ids is None: - raise MathError(f'pre2post synaptic computation needs "pre_ids" ' - f'when providing heterogeneous "pre_values" ' - f'(brainpy.math.ndim(pre_values) != 0).') + if pre_ids is None: + raise MathError(f'pre2post synaptic computation needs "pre_ids" ' + f'when providing heterogeneous "pre_values" ' + f'(brainpy.math.ndim(pre_values) != 0).') def pre2post_event_sum(events, pre2post, post_num: int, values=1.): - """The pre-to-post event-driven synaptic summation with `CSR` synapse structure. + """The pre-to-post event-driven synaptic summation with `CSR` synapse structure. - When ``values`` is a scalar, this function is equivalent to + When ``values`` is a scalar, this function is equivalent to - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. 
code-block:: python - post_val = np.zeros(post_num) - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[j]] += values + post_val = np.zeros(post_num) + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[j]] += values - When ``values`` is a vector (with the length of ``len(post_ids)``), - this function is equivalent to + When ``values`` is a vector (with the length of ``len(post_ids)``), + this function is equivalent to - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - post_val = np.zeros(post_num) + post_val = np.zeros(post_num) - post_ids, idnptr = pre2post - for i in range(pre_num): - if events[i]: - for j in range(idnptr[i], idnptr[i+1]): - post_val[post_ids[j]] += values[j] + post_ids, idnptr = pre2post + for i in range(pre_num): + if events[i]: + for j in range(idnptr[i], idnptr[i+1]): + post_val[post_ids[j]] += values[j] - Parameters:: + Parameters:: - events: ArrayType - The events, must be bool. - pre2post: tuple of ArrayType, tuple of ArrayType - A tuple contains the connection information of pre-to-post. - post_num: int - The number of post-synaptic group. - values: float, ArrayType - The value to make summation. + events: ArrayType + The events, must be bool. + pre2post: tuple of ArrayType, tuple of ArrayType + A tuple contains the connection information of pre-to-post. + post_num: int + The number of post-synaptic group. + values: float, ArrayType + The value to make summation. - Returns:: + Returns:: - out: ArrayType - A tensor with the shape of ``post_num``. 
- """ - indices, idnptr = pre2post - events = as_jax(events) - indices = as_jax(indices) - idnptr = as_jax(idnptr) - values = as_jax(values) - return event.csrmv(values, indices, idnptr, events, - shape=(events.shape[0], post_num), - transpose=True) + out: ArrayType + A tensor with the shape of ``post_num``. + """ + indices, idnptr = pre2post + events = as_jax(events) + indices = as_jax(indices) + idnptr = as_jax(idnptr) + values = as_jax(values) + return event.csrmv(values, indices, idnptr, events, + shape=(events.shape[0], post_num), + transpose=True) pre2post_csr_event_sum = pre2post_event_sum def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic summation. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] += pre_values[pre_ids[i]] - - Parameters:: - - pre_values: float, ArrayType - The pre-synaptic values. - post_ids: ArrayType - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - pre_ids: optional, ArrayType - The connected pre-synaptic neuron ids. - - Returns:: - - post_val: ArrayType - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_jax(pre_values) - post_ids = as_jax(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_jax(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].add(pre_values) + """The pre-to-post synaptic summation. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] += pre_values[pre_ids[i]] + + Parameters:: + + pre_values: float, ArrayType + The pre-synaptic values. + post_ids: ArrayType + The connected post-synaptic neuron ids. + post_num: int + Output dimension. 
The number of post-synaptic neurons. + pre_ids: optional, ArrayType + The connected pre-synaptic neuron ids. + + Returns:: + + post_val: ArrayType + The value with the size of post-synaptic neurons. + """ + out = jnp.zeros(post_num) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_jax(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].add(pre_values) def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic production. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] *= pre_values[pre_ids[i]] - - Parameters:: - - pre_values: float, ArrayType - The pre-synaptic values. - pre_ids: ArrayType - The connected pre-synaptic neuron ids. - post_ids: ArrayType - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns:: - - post_val: ArrayType - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_jax(pre_values) - post_ids = as_jax(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_jax(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].multiply(pre_values) + """The pre-to-post synaptic production. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] *= pre_values[pre_ids[i]] + + Parameters:: + + pre_values: float, ArrayType + The pre-synaptic values. + pre_ids: ArrayType + The connected pre-synaptic neuron ids. + post_ids: ArrayType + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. 
+ + Returns:: + + post_val: ArrayType + The value with the size of post-synaptic neurons. + """ + out = jnp.zeros(post_num) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_jax(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].multiply(pre_values) def pre2post_min(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic minimization. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] = np.minimum(post_val[j], pre_values[pre_ids[i]]) - - Parameters:: - - pre_values: float, ArrayType - The pre-synaptic values. - pre_ids: ArrayType - The connected pre-synaptic neuron ids. - post_ids: ArrayType - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns:: - - post_val: ArrayType - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_jax(pre_values) - post_ids = as_jax(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_jax(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].min(pre_values) + """The pre-to-post synaptic minimization. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] = np.minimum(post_val[j], pre_values[pre_ids[i]]) + + Parameters:: + + pre_values: float, ArrayType + The pre-synaptic values. + pre_ids: ArrayType + The connected pre-synaptic neuron ids. + post_ids: ArrayType + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + + Returns:: + + post_val: ArrayType + The value with the size of post-synaptic neurons. 
+ """ + out = jnp.zeros(post_num) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_jax(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].min(pre_values) def pre2post_max(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic maximization. - - This function is equivalent to: - - .. highlight:: python - .. code-block:: python - - post_val = np.zeros(post_num) - for i, j in zip(pre_ids, post_ids): - post_val[j] = np.maximum(post_val[j], pre_values[pre_ids[i]]) - - Parameters:: - - pre_values: float, ArrayType - The pre-synaptic values. - pre_ids: ArrayType - The connected pre-synaptic neuron ids. - post_ids: ArrayType - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns:: - - post_val: ArrayType - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_jax(pre_values) - post_ids = as_jax(post_ids) - if jnp.ndim(pre_values) != 0: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_jax(pre_ids) - pre_values = pre_values[pre_ids] - return out.at[post_ids].max(pre_values) + """The pre-to-post synaptic maximization. + + This function is equivalent to: + + .. highlight:: python + .. code-block:: python + + post_val = np.zeros(post_num) + for i, j in zip(pre_ids, post_ids): + post_val[j] = np.maximum(post_val[j], pre_values[pre_ids[i]]) + + Parameters:: + + pre_values: float, ArrayType + The pre-synaptic values. + pre_ids: ArrayType + The connected pre-synaptic neuron ids. + post_ids: ArrayType + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + + Returns:: + + post_val: ArrayType + The value with the size of post-synaptic neurons. 
+ """ + out = jnp.zeros(post_num) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) + if jnp.ndim(pre_values) != 0: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_jax(pre_ids) + pre_values = pre_values[pre_ids] + return out.at[post_ids].max(pre_values) def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): - """The pre-to-post synaptic mean computation. - - Parameters:: - - pre_values: float, ArrayType - The pre-synaptic values. - pre_ids: ArrayType - The connected pre-synaptic neuron ids. - post_ids: ArrayType - The connected post-synaptic neuron ids. - post_num: int - Output dimension. The number of post-synaptic neurons. - - Returns:: - - post_val: ArrayType - The value with the size of post-synaptic neurons. - """ - out = jnp.zeros(post_num) - pre_values = as_jax(pre_values) - post_ids = as_jax(post_ids) - if jnp.ndim(pre_values) == 0: - return out.at[post_ids].set(pre_values) - # return out.at[jnp.unique(post_ids)].set(pre_values) - else: - _raise_pre_ids_is_none(pre_ids) - pre_ids = as_jax(pre_ids) - pre_values = pre2syn(pre_values, pre_ids) - return syn2post_mean(pre_values, post_ids, post_num) + """The pre-to-post synaptic mean computation. + + Parameters:: + + pre_values: float, ArrayType + The pre-synaptic values. + pre_ids: ArrayType + The connected pre-synaptic neuron ids. + post_ids: ArrayType + The connected post-synaptic neuron ids. + post_num: int + Output dimension. The number of post-synaptic neurons. + + Returns:: + + post_val: ArrayType + The value with the size of post-synaptic neurons. 
+ """ + out = jnp.zeros(post_num) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) + if jnp.ndim(pre_values) == 0: + return out.at[post_ids].set(pre_values) + # return out.at[jnp.unique(post_ids)].set(pre_values) + else: + _raise_pre_ids_is_none(pre_ids) + pre_ids = as_jax(pre_ids) + pre_values = pre2syn(pre_values, pre_ids) + return syn2post_mean(pre_values, post_ids, post_num) _pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None)) def pre2syn(pre_values, pre_ids): - """The pre-to-syn computation. + """The pre-to-syn computation. - Change the pre-synaptic data to the data with the dimension of synapses. + Change the pre-synaptic data to the data with the dimension of synapses. - This function is equivalent to: + This function is equivalent to: - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - syn_val = np.zeros(len(pre_ids)) - for syn_i, pre_i in enumerate(pre_ids): - syn_val[i] = pre_values[pre_i] + syn_val = np.zeros(len(pre_ids)) + for syn_i, pre_i in enumerate(pre_ids): + syn_val[i] = pre_values[pre_i] - Parameters:: + Parameters:: - pre_values: float, ArrayType - The pre-synaptic value. - pre_ids: ArrayType - The pre-synaptic neuron index. + pre_values: float, ArrayType + The pre-synaptic value. + pre_ids: ArrayType + The pre-synaptic neuron index. - Returns:: + Returns:: - syn_val: ArrayType - The synaptic value. - """ - pre_values = as_jax(pre_values) - pre_ids = as_jax(pre_ids) - if jnp.ndim(pre_values) == 0: - return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values - else: - return _pre2syn(pre_ids, pre_values) + syn_val: ArrayType + The synaptic value. 
+ """ + pre_values = as_jax(pre_values) + pre_ids = as_jax(pre_ids) + if jnp.ndim(pre_values) == 0: + return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values + else: + return _pre2syn(pre_ids, pre_values) _jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3)) @@ -330,210 +330,210 @@ def pre2syn(pre_values, pre_ids): def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=False): - """The syn-to-post summation computation. + """The syn-to-post summation computation. - This function is equivalent to: + This function is equivalent to: - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] += syn_values[syn_i] + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] += syn_values[syn_i] - Parameters:: + Parameters:: - syn_values: ArrayType - The synaptic values. - post_ids: ArrayType - The post-synaptic neuron ids. - post_num: int - The number of the post-synaptic neurons. + syn_values: ArrayType + The synaptic values. + post_ids: ArrayType + The post-synaptic neuron ids. + post_num: int + The number of the post-synaptic neurons. - Returns:: + Returns:: - post_val: ArrayType - The post-synaptic value. - """ - post_ids = as_jax(post_ids) - syn_values = as_jax(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + post_val: ArrayType + The post-synaptic value. + """ + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) syn2post = syn2post_sum def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=False): - """The syn-to-post product computation. 
+ """The syn-to-post product computation. - This function is equivalent to: + This function is equivalent to: - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] *= syn_values[syn_i] + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] *= syn_values[syn_i] - Parameters:: + Parameters:: - syn_values: ArrayType - The synaptic values. - post_ids: ArrayType - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + syn_values: ArrayType + The synaptic values. + post_ids: ArrayType + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. - Returns:: + Returns:: - post_val: ArrayType - The post-synaptic value. - """ - post_ids = as_jax(post_ids) - syn_values = as_jax(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) + post_val: ArrayType + The post-synaptic value. 
+ """ + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=False): - """The syn-to-post maximum computation. + """The syn-to-post maximum computation. - This function is equivalent to: + This function is equivalent to: - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) - Parameters:: + Parameters:: - syn_values: ArrayType - The synaptic values. - post_ids: ArrayType - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + syn_values: ArrayType + The synaptic values. + post_ids: ArrayType + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. - Returns:: + Returns:: - post_val: ArrayType - The post-synaptic value. 
- """ - post_ids = as_jax(post_ids) - syn_values = as_jax(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) + post_val: ArrayType + The post-synaptic value. + """ + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=False): - """The syn-to-post minimization computation. + """The syn-to-post minimization computation. - This function is equivalent to: + This function is equivalent to: - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - post_val = np.zeros(post_num) - for syn_i, post_i in enumerate(post_ids): - post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) + post_val = np.zeros(post_num) + for syn_i, post_i in enumerate(post_ids): + post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) - Parameters:: + Parameters:: - syn_values: ArrayType - The synaptic values. - post_ids: ArrayType - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + syn_values: ArrayType + The synaptic values. + post_ids: ArrayType + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. 
+ indices_are_sorted: whether ``post_ids`` is known to be sorted. - Returns:: + Returns:: - post_val: ArrayType - The post-synaptic value. - """ - post_ids = as_jax(post_ids) - syn_values = as_jax(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) + post_val: ArrayType + The post-synaptic value. + """ + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False): - """The syn-to-post mean computation. - - Parameters:: - - syn_values: ArrayType - The synaptic values. - post_ids: ArrayType - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns:: - - post_val: ArrayType - The post-synaptic value. - """ - post_ids = as_jax(post_ids) - syn_values = as_jax(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) - return jnp.nan_to_num(nominator / denominator) + """The syn-to-post mean computation. + + Parameters:: + + syn_values: ArrayType + The synaptic values. + post_ids: ArrayType + The post-synaptic neuron ids. If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. 
+ Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns:: + + post_val: ArrayType + The post-synaptic value. + """ + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted) + return jnp.nan_to_num(nominator / denominator) def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=False): - """The syn-to-post softmax computation. - - Parameters:: - - syn_values: ArrayType - The synaptic values. - post_ids: ArrayType - The post-synaptic neuron ids. If ``post_ids`` is generated by - ``brainpy.conn.TwoEndConnector``, then it has sorted indices. - Otherwise, this function cannot guarantee indices are sorted. - You's better set ``indices_are_sorted=False``. - post_num: int - The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. - - Returns:: - - post_val: ArrayType - The post-synaptic value. - """ - post_ids = as_jax(post_ids) - syn_values = as_jax(syn_values) - if syn_values.dtype == jnp.bool_: - syn_values = jnp.asarray(syn_values, dtype=jnp.int32) - syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) - syn_values = syn_values - syn_maxs[post_ids] - syn_values = jnp.exp(syn_values) - normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) - softmax = syn_values / normalizers[post_ids] - return jnp.nan_to_num(softmax) + """The syn-to-post softmax computation. + + Parameters:: + + syn_values: ArrayType + The synaptic values. + post_ids: ArrayType + The post-synaptic neuron ids. 
If ``post_ids`` is generated by + ``brainpy.conn.TwoEndConnector``, then it has sorted indices. + Otherwise, this function cannot guarantee indices are sorted. + You's better set ``indices_are_sorted=False``. + post_num: int + The number of the post-synaptic neurons. + indices_are_sorted: whether ``post_ids`` is known to be sorted. + + Returns:: + + post_val: ArrayType + The post-synaptic value. + """ + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) + if syn_values.dtype == jnp.bool_: + syn_values = jnp.asarray(syn_values, dtype=jnp.int32) + syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) + syn_values = syn_values - syn_maxs[post_ids] + syn_values = jnp.exp(syn_values) + normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) + softmax = syn_values / normalizers[post_ids] + return jnp.nan_to_num(softmax) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 871d20f73..e482abe07 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -21,4806 +21,4805 @@ from .object_transform.variables import Variable __all__ = [ - 'RandomState', 'Generator', 'DEFAULT', - - 'seed', 'default_rng', 'split_key', 'split_keys', - - # numpy compatibility - 'rand', 'randint', 'random_integers', 'randn', 'random', - 'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta', - 'exponential', 'gamma', 'gumbel', 'laplace', 'logistic', 'normal', 'pareto', - 'poisson', 'standard_cauchy', 'standard_exponential', 'standard_gamma', - 'standard_normal', 'standard_t', 'uniform', 'truncated_normal', 'bernoulli', - 'lognormal', 'binomial', 'chisquare', 'dirichlet', 'geometric', 'f', - 'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal', - 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power', - 'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min', - 'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical', 
'canonicalize_shape', - - # pytorch compatibility - 'rand_like', 'randint_like', 'randn_like', + 'RandomState', 'Generator', 'DEFAULT', + + 'seed', 'default_rng', 'split_key', 'split_keys', + + # numpy compatibility + 'rand', 'randint', 'random_integers', 'randn', 'random', + 'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta', + 'exponential', 'gamma', 'gumbel', 'laplace', 'logistic', 'normal', 'pareto', + 'poisson', 'standard_cauchy', 'standard_exponential', 'standard_gamma', + 'standard_normal', 'standard_t', 'uniform', 'truncated_normal', 'bernoulli', + 'lognormal', 'binomial', 'chisquare', 'dirichlet', 'geometric', 'f', + 'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal', + 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power', + 'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min', + 'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical', 'canonicalize_shape', + + # pytorch compatibility + 'rand_like', 'randint_like', 'randn_like', ] JAX_RAND_KEY = jax.Array def _formalize_key(key): - if isinstance(key, int): - return jr.PRNGKey(key) - elif isinstance(key, (BaseArray, jnp.ndarray, np.ndarray)): - if key.dtype != jnp.uint32: - raise TypeError('key must be a int or an array with two uint32.') - if key.size != 2: - raise TypeError('key must be a int or an array with two uint32.') - return jnp.asarray(key) - else: - raise TypeError('key must be a int or an array with two uint32.') + if isinstance(key, int): + return jr.PRNGKey(key) + elif isinstance(key, (BaseArray, jnp.ndarray, np.ndarray)): + if key.dtype != jnp.uint32: + raise TypeError('key must be a int or an array with two uint32.') + if key.size != 2: + raise TypeError('key must be a int or an array with two uint32.') + return jnp.asarray(key) + else: + raise TypeError('key must be a int or an array with two uint32.') def _size2shape(size): - if size is None: - return () - elif isinstance(size, (tuple, list)): - return 
tuple(size) - else: - return (size,) + if size is None: + return () + elif isinstance(size, (tuple, list)): + return tuple(size) + else: + return (size,) def _check_shape(name, shape, *param_shapes): - if param_shapes: - shape_ = lax.broadcast_shapes(shape, *param_shapes) - if shape != shape_: - msg = ("{} parameter shapes must be broadcast-compatible with shape " - "argument, and the result of broadcasting the shapes must equal " - "the shape argument, but got result {} for shape argument {}.") - raise ValueError(msg.format(name, shape_, shape)) + if param_shapes: + shape_ = lax.broadcast_shapes(shape, *param_shapes) + if shape != shape_: + msg = ("{} parameter shapes must be broadcast-compatible with shape " + "argument, and the result of broadcasting the shapes must equal " + "the shape argument, but got result {} for shape argument {}.") + raise ValueError(msg.format(name, shape_, shape)) def _as_jax_array(a): - return a.value if isinstance(a, BaseArray) else a + return a.value if isinstance(a, BaseArray) else a def _is_python_scalar(x): - if hasattr(x, 'aval'): - return x.aval.weak_type - elif np.ndim(x) == 0: - return True - elif isinstance(x, (bool, int, float, complex)): - return True - else: - return False + if hasattr(x, 'aval'): + return x.aval.weak_type + elif np.ndim(x) == 0: + return True + elif isinstance(x, (bool, int, float, complex)): + return True + else: + return False python_scalar_dtypes = { - bool: np.dtype('bool'), - int: np.dtype('int64'), - float: np.dtype('float64'), - complex: np.dtype('complex128'), + bool: np.dtype('bool'), + int: np.dtype('int64'), + float: np.dtype('float64'), + complex: np.dtype('complex128'), } def _dtype(x, *, canonicalize: bool = False): - """Return the dtype object for a value or type, optionally canonicalized based on X64 mode.""" - if x is None: - raise ValueError(f"Invalid argument to dtype: {x}.") - elif isinstance(x, type) and x in python_scalar_dtypes: - dt = python_scalar_dtypes[x] - elif type(x) in 
def _const(example, val):
    """Build the constant ``val`` with a dtype derived from ``example``.

    Args:
        example: Value whose type (for scalars, as decided by
            ``_is_python_scalar``) or ``dtype`` attribute selects the result dtype.
        val: The constant to convert.

    Returns:
        For a scalar ``example``: ``val`` cast with the scalar type of
        ``example``, returned as-is when its canonical dtype already matches,
        else wrapped in ``np.array``. For any other ``example``:
        ``np.array(val, canonical dtype of example.dtype)``.
    """
    if not _is_python_scalar(example):
        return np.array(val, dtypes.canonicalize_dtype(example.dtype))

    target = dtypes.canonicalize_dtype(type(example))
    val = dtypes.scalar_type_of(example)(val)
    if target == _dtype(val, canonicalize=True):
        return val
    return np.array(val, target)
def _stirling_approx_tail(k):
    """Evaluate the Stirling tail term used by the BTRS binomial sampler.

    Args:
        k: Integer or integer array.

    Returns:
        A table lookup where ``k < 10``, and the truncated series
        ``(1/12 - (1/360 - (1/1260)/(k+1)^2)/(k+1)^2)/(k+1)`` elsewhere.
    """
    table = jnp.array([0.08106146679532726,
                       0.04134069595540929,
                       0.02767792568499834,
                       0.02079067210376509,
                       0.01664469118982119,
                       0.01387612882307075,
                       0.01189670994589177,
                       0.01041126526197209,
                       0.009255462182712733,
                       0.008330563433362871, ])
    k_plus_1 = k + 1
    k_plus_1_sq = (k + 1) ** 2
    # Built inside-out, with the same operations and order as the one-line form.
    innermost = (1.0 / 1260) / k_plus_1_sq
    middle = (1.0 / 360 - innermost) / k_plus_1_sq
    series = (1.0 / 12 - middle) / k_plus_1
    return jnp.where(k < 10, table[k], series)
def _binomial_btrs(key, p, n):
    """Draw one binomial sample by transformed rejection sampling (BTRS).

    Based on the transformed rejection sampling algorithm (BTRS) from the
    following reference:

    Hormann, "The Generation of Binomial Random Variates"
    (https://core.ac.uk/download/pdf/11007254.pdf)

    Args:
        key: JAX PRNG key; it is split on every loop iteration.
        p: Success probability.
        n: Number of trials. Must expose ``.dtype``, since candidates are
            cast to ``n.dtype``.

    Returns:
        The candidate ``k`` held in the loop state when the loop stops.
    """

    def _btrs_body_fn(val):
        # Propose a new candidate k from two fresh uniforms; the previous
        # candidate, u and v are discarded.
        _, key, _, _ = val
        key, key_u, key_v = jr.split(key, 3)
        u = jr.uniform(key_u)
        v = jr.uniform(key_v)
        u = u - 0.5  # centre u on zero
        k = jnp.floor(
            (2 * tr_params.a / (0.5 - jnp.abs(u)) + tr_params.b) * u + tr_params.c
        ).astype(n.dtype)
        return k, key, u, v

    def _btrs_cond_fn(val):
        # Returns True while the loop must keep proposing candidates.
        def accept_fn(k, u, v):
            # See acceptance condition in Step 3. (Page 3) of TRS algorithm
            # v <= f(k) * g_grad(u) / alpha

            m = tr_params.m
            log_p = tr_params.log_p
            log1_p = tr_params.log1_p
            # See: formula for log(f(k)) at bottom of Page 5.
            log_f = (
                (n + 1.0) * jnp.log((n - m + 1.0) / (n - k + 1.0))
                + (k + 0.5) * (jnp.log((n - k + 1.0) / (k + 1.0)) + log_p - log1_p)
                + (_stirling_approx_tail(k) - _stirling_approx_tail(n - k))
                + tr_params.log_h
            )
            g = (tr_params.a / (0.5 - jnp.abs(u)) ** 2) + tr_params.b
            return jnp.log((v * tr_params.alpha) / g) <= log_f

        k, key, u, v = val
        early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
        early_reject = (k < 0) | (k > n)
        # If either shortcut fires, continue exactly when it was not an early
        # accept (i.e. the candidate was out of range); otherwise continue
        # exactly when the full acceptance test fails.
        return lax.cond(
            early_accept | early_reject,
            (),
            lambda _: ~early_accept,
            (k, u, v),
            lambda x: ~accept_fn(*x),
        )

    # tr_params is read by both nested functions through the closure.
    tr_params = _get_tr_params(n, p)
    ret = lax.while_loop(
        _btrs_cond_fn, _btrs_body_fn, (-1, key, 1.0, 1.0)
    )  # use k=-1 initially so that cond_fn returns True
    return ret[0]
def _binomial_inversion(key, p, n):
    """Draw one binomial sample by accumulating geometric waiting times.

    Each iteration adds ``floor(log1p(-u) / log1p(-p)) + 1`` (``u`` uniform)
    to a running total and increments a counter that starts at ``-1``; the
    loop stops once the total exceeds ``n``.

    Args:
        key: JAX PRNG key, split on every iteration.
        p: Success probability.
        n: Number of trials.

    Returns:
        The counter value when the loop stops.
    """
    log_q = jnp.log1p(-p)

    def _still_within_trials(state):
        return state[2] <= n

    def _advance(state):
        count, rng, waited = state
        rng, sub_rng = jr.split(rng)
        draw = jr.uniform(sub_rng)
        gap = jnp.floor(jnp.log1p(-draw) / log_q) + 1
        return count + 1, rng, waited + gap

    final_state = lax.while_loop(_still_within_trials, _advance, (-1, key, 0.0))
    return final_state[0]
def _binomial_dispatch(key, p, n):
    """Choose a binomial sampler for one ``(p, n)`` pair and run it.

    When ``p`` is finite with ``0 < p < 1`` and ``n > 0``, the probability is
    folded to ``min(p, 1 - p)``; ``_binomial_inversion`` is used when
    ``n * folded < 10`` and ``_binomial_btrs`` otherwise, and the draw is
    mirrored back (``n - k``) if ``p`` was above 0.5. In every other case the
    result is ``n`` when ``p`` is finite and positive with ``n > 0`` (so
    ``p >= 1``), and ``0`` otherwise.

    Args:
        key: JAX PRNG key.
        p: Success probability.
        n: Number of trials.

    Returns:
        The sampled count.
    """

    def _sample(rng, prob, trials):
        lower_half = prob <= 0.5
        folded = jnp.where(lower_half, prob, 1 - prob)
        expected = trials * folded
        operands = (rng, folded, trials)
        draws = lax.cond(
            expected < 10,
            operands,
            lambda args: _binomial_inversion(*args),
            operands,
            lambda args: _binomial_btrs(*args),
        )
        return jnp.where(lower_half, draws, trials - draws)

    # Return 0 for nan `p` or negative `n`, since nan values are not allowed for integer types
    valid = jnp.isfinite(p) & (n > 0) & (p > 0)
    return lax.cond(
        valid & (p < 1),
        (key, p, n),
        lambda args: _sample(*args),
        (),
        lambda _: jnp.where(valid, n, 0),
    )
def _reshape(x, shape):
    """Reshape ``x``, staying in NumPy for host-side values.

    Args:
        x: Value to reshape.
        shape: Target shape.

    Returns:
        ``np.reshape(x, shape)`` when ``x`` is an ``int``, ``float``,
        ``np.ndarray`` or ``np.generic``; ``jnp.reshape(x, shape)`` otherwise.
    """
    host_types = (int, float, np.ndarray, np.generic)
    backend = np if isinstance(x, host_types) else jnp
    return backend.reshape(x, shape)
@partial(jit, static_argnums=(3, 4))
def _multinomial(key, p, n, n_max, shape=()):
    """Sample multinomial counts by drawing ``n_max`` categorical indices.

    Args:
        key: JAX PRNG key forwarded to ``_categorical``.
        p: Category probabilities; the last axis indexes categories.
        n: Number of trials, scalar or array; broadcast against
            ``p.shape[:-1]`` when the shapes differ.
        n_max: Number of categorical draws made per batch element (static
            under ``jit``).
        shape: Requested batch shape (static under ``jit``); defaults to
            ``p.shape[:-1]`` when empty.

    Returns:
        Array of shape ``shape + p.shape[-1:]`` holding per-category counts.
    """
    if jnp.shape(n) != jnp.shape(p)[:-1]:
        broadcast_shape = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, broadcast_shape)
        p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
    shape = shape or p.shape[:-1]
    if n_max == 0:
        # No draws requested: every count is zero.
        return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
    # get indices from categorical distribution then gather the result
    indices = _categorical(key, p, (n_max,) + shape)
    # mask out values when counts is heterogeneous
    if jnp.ndim(n) > 0:
        # mask is 1 for the first n draws of each batch element and 0 for the rest.
        mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
        mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
        # Masked draws become index 0 below, so category 0 receives
        # (n_max - n) surplus counts; `excess` removes them from the result.
        excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
                                  jnp.zeros(jnp.shape(n) + (p.shape[-1] - 1,))],
                                 -1)
    else:
        mask = 1
        excess = 0
    # NB: we transpose to move batch shape to the front
    indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
    # For each batch row, add one to the count of every drawn category index.
    samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
                                        jnp.expand_dims(indices_2D, axis=-1),
                                        jnp.ones(indices_2D.shape, dtype=indices.dtype))
    return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
- maxval=1.0, - ) - z = jnp.cos(jnp.pi * u) - w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done - y = concentration * (s - w) - v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype) - accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y) - return i + 1, key, accept | done, u, w + """Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal. + + Returns:: + + out: array_like + centered samples from von Mises + + References:: - init_done = jnp.zeros(shape, dtype=bool) - init_u = jnp.zeros(shape) - init_w = jnp.zeros(shape) + .. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986; + Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf - _, _, done, u, w = lax.while_loop( - cond_fun=cond_fn, - body_fun=body_fn, - init_val=(jnp.array(0), key, init_done, init_u, init_w), - ) + """ + shape = shape or jnp.shape(concentration) + dtype = jnp.result_type(dtype) + concentration = lax.convert_element_type(concentration, dtype) + concentration = jnp.broadcast_to(concentration, shape) + + s_cutoff_map = { + jnp.dtype(jnp.float16): 1.8e-1, + jnp.dtype(jnp.float32): 2e-2, + jnp.dtype(jnp.float64): 1.2e-4, + } + s_cutoff = s_cutoff_map.get(dtype) + + r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2) + rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration) + s_exact = (1.0 + rho ** 2) / (2.0 * rho) + + s_approximate = 1.0 / concentration + + s = jnp.where(concentration > s_cutoff, s_exact, s_approximate) + + def cond_fn(*args): + """check if all are done or reached max number of iterations""" + i, _, done, _, _ = args[0] + return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done))) + + def body_fn(*args): + i, key, done, _, w = args[0] + uni_ukey, uni_vkey, key = jr.split(key, 3) + u = jr.uniform( + key=uni_ukey, + shape=shape, + dtype=concentration.dtype, + minval=-1.0, + maxval=1.0, + ) + z = jnp.cos(jnp.pi * u) + w = jnp.where(done, w, 
def _loc_scale(loc, scale, value):
    """Apply an optional scale and then an optional shift to ``value``.

    Args:
        loc: Offset to add, or ``None`` to skip the shift.
        scale: Factor to multiply by, or ``None`` to skip the scaling.
        value: The value to transform.

    Returns:
        ``value``, ``value * scale``, ``value + loc`` or ``value * scale + loc``
        depending on which of ``loc`` / ``scale`` are given.
    """
    scaled = value if scale is None else value * scale
    return scaled if loc is None else scaled + loc
""" - - def __init__( - self, - seed_or_key: Optional[Union[int, Array, jax.Array, np.ndarray]] = None, - seed: Optional[int] = None, - ): - """RandomState constructor. + """RandomState that track the random generator state. """ + + def __init__( + self, + seed_or_key: Optional[Union[int, Array, jax.Array, np.ndarray]] = None, + seed: Optional[int] = None, + ): + """RandomState constructor. + + Parameters:: + + seed_or_key: int, Array, optional + It can be an integer for initial seed of the random number generator, + or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype. + + .. versionadded:: 2.2.3.4 + + seed : int, ArrayType, optional + Same as `seed_or_key`. + + .. deprecated:: 2.2.3.4 + Will be removed since version 2.4. + """ + if seed is not None: + if seed_or_key is not None: + raise ValueError('Please set "seed_or_key" or "seed", not both.') + seed_or_key = seed + warnings.warn('Please use `seed_or_key` instead. ' + 'seed will be removed since 2.4.0', UserWarning) + + with jax.ensure_compile_time_eval(): + if seed_or_key is None: + seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32) + if isinstance(seed_or_key, int): + key = jr.PRNGKey(seed_or_key) + else: + if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32: + raise ValueError('key must be an array with dtype uint32. ' + f'But we got {seed_or_key}') + key = seed_or_key + super(RandomState, self).__init__(key) + + def __repr__(self) -> str: + print_code = repr(self.value) + i = print_code.index('(') + name = self.__class__.__name__ + return f'{name}(key={print_code[i:]})' + + # ------------------- # + # seed and random key # + # ------------------- # + + def clone(self): + return type(self)(self.split_key()) + + def seed(self, seed_or_key=None, seed=None): + """Sets a new random seed. 
def split_key(self):
    """Advance the stored key and hand out a fresh sub-key.

    The stored value is converted with ``jnp.asarray`` first when it is not
    already a ``jnp.ndarray``. It is then split in two: the first half
    replaces the stored key and the second half is returned.

    Returns:
        The second key produced by ``jr.split(self.value, num=2)``.
    """
    needs_cast = not isinstance(self.value, jnp.ndarray)
    if needs_cast:
        self._value = jnp.asarray(self.value)
    pair = jr.split(self.value, num=2)
    self._value, fresh = pair[0], pair[1]
    return fresh
def random_integers(self,
                    low,
                    high=None,
                    size: Optional[Union[int, Sequence[int]]] = None,
                    key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw random integers from the closed interval ``[low, high]``.

    Args:
        low: Lower bound (inclusive). When ``high`` is ``None`` this value is
            used as the upper bound instead and the lower bound becomes 1.
        high: Upper bound (inclusive), or ``None``.
        size: Output shape; defaults to the broadcast shape of the bounds.
        key: Optional seed or PRNG key; when ``None`` a key is taken from
            ``self.split_key()``.

    Returns:
        The samples from ``jr.randint``, passed through ``_return``.
    """
    low = _as_jax_array(low)
    high = _as_jax_array(high)
    low = _check_py_seq(low)
    high = _check_py_seq(high)
    if high is None:
        high = low
        low = 1
    # jr.randint excludes `maxval`, so shift the inclusive upper bound by one.
    # Use an out-of-place add: `high += 1` would modify a caller-owned NumPy
    # array (or the caller's `low` array when `high` was None) in place.
    high = high + 1
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
    key = self.split_key() if key is None else _formalize_key(key)
    r = jr.randint(key,
                   shape=_size2shape(size),
                   minval=low,
                   maxval=high)
    return _return(r)
def random_sample(self,
                  size: Optional[Union[int, Sequence[int]]] = None,
                  key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Alias of ``self.random``.

    Args:
        size: Output shape, forwarded to ``self.random``.
        key: Optional seed or PRNG key, forwarded to ``self.random``.

    Returns:
        The result of ``self.random(size=size, key=key)`` passed through
        ``_return``.
    """
    return _return(self.random(size=size, key=key))
def exponential(self, scale=None,
                size: Optional[Union[int, Sequence[int]]] = None,
                key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from an exponential distribution.

    Follows the NumPy convention in which ``scale`` is the inverse of the
    rate, so the standard exponential draw is multiplied by ``scale``.

    Args:
        scale: Scale parameter (1 / rate); ``None`` yields standard
            exponential samples.
        size: Output shape; defaults to the shape of ``scale`` (``()`` when
            ``scale`` is ``None``).
        key: Optional seed or PRNG key; when ``None`` a key is taken from
            ``self.split_key()``.

    Returns:
        The samples, passed through ``_return``.
    """
    scale = _as_jax_array(scale)
    scale = _check_py_seq(scale)
    if size is None:
        # Avoid asking for the shape of None; a missing scale means a scalar draw.
        size = () if scale is None else jnp.shape(scale)
    key = self.split_key() if key is None else _formalize_key(key)
    r = jr.exponential(key, shape=_size2shape(size))
    if scale is not None:
        # Multiply, not divide: dividing treated `scale` as a rate, which
        # contradicts the NumPy-compatible meaning of the parameter.
        r = r * scale
    return _return(r)
def normal(self, loc=None, scale=None,
           size: Optional[Union[int, Sequence[int]]] = None,
           key: Optional[Union[int, JAX_RAND_KEY]] = None):
    """Draw samples from a normal distribution.

    Args:
        loc: Mean, or ``None`` for no shift.
        scale: Standard deviation, or ``None`` for no scaling.
        size: Output shape; defaults to the broadcast shape of ``scale`` and
            ``loc``.
        key: Optional seed or PRNG key; when ``None`` a key is taken from
            ``self.split_key()``.

    Returns:
        Standard normal draws transformed by ``_loc_scale`` and passed
        through ``_return``.
    """
    loc = _check_py_seq(_as_jax_array(loc))
    scale = _check_py_seq(_as_jax_array(scale))
    if size is None:
        size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc))
    if key is None:
        key = self.split_key()
    else:
        key = _formalize_key(key)
    noise = jr.normal(key, shape=_size2shape(size))
    return _return(_loc_scale(loc, scale, noise))
_as_jax_array(shape) + shape = _check_py_seq(shape) + if size is None: + size = jnp.shape(shape) + key = self.split_key() if key is None else _formalize_key(key) + r = jr.gamma(key, a=shape, shape=_size2shape(size)) + return _return(r) + + def standard_normal(self, + size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + r = jr.normal(key, shape=_size2shape(size)) + return _return(r) - with jax.ensure_compile_time_eval(): - if seed_or_key is None: - seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32) - if isinstance(seed_or_key, int): - key = jr.PRNGKey(seed_or_key) - else: - if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32: - raise ValueError('key must be an array with dtype uint32. ' - f'But we got {seed_or_key}') - key = seed_or_key - super(RandomState, self).__init__(key) - - def __repr__(self) -> str: - print_code = repr(self.value) - i = print_code.index('(') - name = self.__class__.__name__ - return f'{name}(key={print_code[i:]})' - - # ------------------- # - # seed and random key # - # ------------------- # - - def clone(self): - return type(self)(self.split_key()) - - def seed(self, seed_or_key=None, seed=None): - """Sets a new random seed. 
+ def standard_t(self, df, + size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + df = _as_jax_array(df) + df = _check_py_seq(df) + if size is None: + size = jnp.shape(size) + key = self.split_key() if key is None else _formalize_key(key) + r = jr.t(key, df=df, shape=_size2shape(size)) + return _return(r) + + def uniform(self, low=0.0, high=1.0, + size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + low = _as_jax_array(low) + high = _as_jax_array(high) + low = _check_py_seq(low) + high = _check_py_seq(high) + if size is None: + size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) + key = self.split_key() if key is None else _formalize_key(key) + r = jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high) + return _return(r) + + def __norm_cdf(self, x, sqrt2, dtype): + # Computes standard normal cumulative distribution function + return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype) + + def truncated_normal(self, + lower, + upper, + size: Optional[Union[int, Sequence[int]]] = None, + loc=0., + scale=1., + dtype=float, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + lower = _check_py_seq(_as_jax_array(lower)) + upper = _check_py_seq(_as_jax_array(upper)) + loc = _check_py_seq(_as_jax_array(loc)) + scale = _check_py_seq(_as_jax_array(scale)) + + lower = lax.convert_element_type(lower, dtype) + upper = lax.convert_element_type(upper, dtype) + loc = lax.convert_element_type(loc, dtype) + scale = lax.convert_element_type(scale, dtype) + + jit_error_checking_no_args( + jnp.any(jnp.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)), + ValueError("mean is more than 2 std from [lower, upper] in truncated_normal. 
" + "The distribution of values may be incorrect.") + ) + + if size is None: + size = lax.broadcast_shapes(jnp.shape(lower), + jnp.shape(upper), + jnp.shape(loc), + jnp.shape(scale)) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + sqrt2 = np.array(np.sqrt(2), dtype) + l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype) + u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + key = self.split_key() if key is None else _formalize_key(key) + out = jr.uniform(key, size, dtype, + minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)), + maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + out = lax.erf_inv(out) + + # Transform to proper mean, std + out = out * scale * sqrt2 + loc + + # Clamp to ensure it's in the proper range + out = jnp.clip(out, + lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)), + lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))) + return _return(out) + + def _check_p(self, p): + raise ValueError(f'Parameter p should be within [0, 1], but we got {p}') + + def bernoulli(self, p, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + p = _check_py_seq(_as_jax_array(p)) + jit_error_checking(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) + if size is None: + size = jnp.shape(p) + key = self.split_key() if key is None else _formalize_key(key) + r = jr.bernoulli(key, p=p, shape=_size2shape(size)) + return _return(r) + + def lognormal(self, mean=None, sigma=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + mean = _check_py_seq(_as_jax_array(mean)) + sigma = 
_check_py_seq(_as_jax_array(sigma)) + if size is None: + size = jnp.broadcast_shapes(jnp.shape(mean), + jnp.shape(sigma)) + key = self.split_key() if key is None else _formalize_key(key) + samples = jr.normal(key, shape=_size2shape(size)) + samples = _loc_scale(mean, sigma, samples) + samples = jnp.exp(samples) + return _return(samples) + + def binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + n = _check_py_seq(n.value if isinstance(n, BaseArray) else n) + p = _check_py_seq(p.value if isinstance(p, BaseArray) else p) + jit_error_checking(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) + if size is None: + size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)) + key = self.split_key() if key is None else _formalize_key(key) + r = _binomial(key, p, n, shape=_size2shape(size)) + return _return(r) + + def chisquare(self, df, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + df = _check_py_seq(_as_jax_array(df)) + key = self.split_key() if key is None else _formalize_key(key) + if size is None: + if jnp.ndim(df) == 0: + dist = jr.normal(key, (df,)) ** 2 + dist = dist.sum() + else: + raise NotImplementedError('Do not support non-scale "df" when "size" is None') + else: + dist = jr.normal(key, (df,) + _size2shape(size)) ** 2 + dist = dist.sum(axis=0) + return _return(dist) + + def dirichlet(self, alpha, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + alpha = _check_py_seq(_as_jax_array(alpha)) + r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size)) + return _return(r) - Parameters:: - - seed_or_key: int, ArrayType, optional - It can be an integer for initial seed of the random number generator, - or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype. 
+ def geometric(self, p, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + p = _as_jax_array(p) + p = _check_py_seq(p) + if size is None: + size = jnp.shape(p) + key = self.split_key() if key is None else _formalize_key(key) + u = jr.uniform(key, size) + r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p)) + return _return(r) + + def _check_p2(self, p): + raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}') + + def multinomial(self, n, pvals, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + n = _check_py_seq(_as_jax_array(n)) + pvals = _check_py_seq(_as_jax_array(pvals)) + jit_error_checking(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals) + if isinstance(n, jax.core.Tracer): + raise ValueError("The total count parameter `n` should not be a jax abstract array.") + size = _size2shape(size) + n_max = int(np.max(jax.device_get(n))) + batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n)) + r = _multinomial(key, pvals, n, n_max, batch_shape + size) + return _return(r) + + def multivariate_normal(self, mean, cov, size: Optional[Union[int, Sequence[int]]] = None, method: str = 'cholesky', + key: Optional[Union[int, JAX_RAND_KEY]] = None): + if method not in {'svd', 'eigh', 'cholesky'}: + raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}") + mean = _check_py_seq(_as_jax_array(mean)) + cov = _check_py_seq(_as_jax_array(cov)) + key = self.split_key() if key is None else _formalize_key(key) + + if not jnp.ndim(mean) >= 1: + raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}") + if not jnp.ndim(cov) >= 2: + raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}") + n = mean.shape[-1] + if jnp.shape(cov)[-2:] != (n, n): + raise ValueError(f"multivariate_normal requires cov.shape 
== (..., n, n) for n={n}, " + f"but got cov.shape == {jnp.shape(cov)}.") + if size is None: + size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + else: + size = _size2shape(size) + _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2]) + + if method == 'svd': + (u, s, _) = jnp.linalg.svd(cov) + factor = u * jnp.sqrt(s[..., None, :]) + elif method == 'eigh': + (w, v) = jnp.linalg.eigh(cov) + factor = v * jnp.sqrt(w[..., None, :]) + else: # 'cholesky' + factor = jnp.linalg.cholesky(cov) + normal_samples = jr.normal(key, size + mean.shape[-1:]) + r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) + return _return(r) + + def rayleigh(self, scale=1.0, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + scale = _check_py_seq(_as_jax_array(scale)) + if size is None: + size = jnp.shape(scale) + key = self.split_key() if key is None else _formalize_key(key) + x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), minval=0, maxval=1))) + r = x * scale + return _return(r) + + def triangular(self, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size)) + r = 2 * bernoulli_samples - 1 + return _return(r) - .. 
versionadded:: 2.2.3.4 + def vonmises(self, mu, kappa, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + mu = _check_py_seq(_as_jax_array(mu)) + kappa = _check_py_seq(_as_jax_array(kappa)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa)) + size = _size2shape(size) + samples = _von_mises_centered(key, kappa, size) + samples = samples + mu + samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi + return _return(samples) + + def weibull(self, a, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + a = _check_py_seq(_as_jax_array(a)) + if size is None: + size = jnp.shape(a) + else: + if jnp.size(a) > 1: + raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}') + size = _size2shape(size) + random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1) + r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) + return _return(r) + + def weibull_min(self, a, scale=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + """Sample from a Weibull minimum distribution. + + Parameters:: + + a: float, array_like + The concentration parameter of the distribution. + scale: float, array_like + The scale parameter of the distribution. + size: optional, int, tuple of int + The shape added to the parameters loc and scale broadcastable shape. + + Returns:: + + out: array_like + The sampling results. + """ + key = self.split_key() if key is None else _formalize_key(key) + a = _check_py_seq(_as_jax_array(a)) + scale = _check_py_seq(_as_jax_array(scale)) + if size is None: + size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale)) + else: + if jnp.size(a) > 1: + raise ValueError(f'"a" should be a scalar when "size" is provided. 
But we got {a}') + size = _size2shape(size) + random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1) + r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) + if scale is not None: + r /= scale + return _return(r) + + def maxwell(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + shape = canonicalize_shape(_size2shape(size)) + (3,) + norm_rvs = jr.normal(key=key, shape=shape) + r = jnp.linalg.norm(norm_rvs, axis=-1) + return _return(r) + + def negative_binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + n = _check_py_seq(_as_jax_array(n)) + p = _check_py_seq(_as_jax_array(p)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)) + size = _size2shape(size) + logits = jnp.log(p) - jnp.log1p(-p) + if key is None: + keys = self.split_keys(2) + else: + keys = jr.split(_formalize_key(key), 2) + rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0]) + r = self.poisson(lam=rate, key=keys[1]) + return _return(r) + + def wald(self, mean, scale, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + mean = _check_py_seq(_as_jax_array(mean)) + scale = _check_py_seq(_as_jax_array(scale)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale)) + size = _size2shape(size) + sampled_chi2 = jnp.square(_as_jax_array(self.randn(*size))) + sampled_uniform = _as_jax_array(self.uniform(size=size, key=key)) + # Wikipedia defines an intermediate x with the formula + # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2) + # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration. 
+ # Let us write + # w = loc * y / (2 * conc) + # Then we can extract the common factor in the last two terms to obtain + # x = loc + loc * w * (1 - sqrt(2 / w + 1)) + # Now we see that the Wikipedia formula suffers from catastrphic + # cancellation for large w (e.g., if conc << loc). + # + # Fortunately, we can fix this by multiplying both sides + # by 1 + sqrt(2 / w + 1). We get + # x * (1 + sqrt(2 / w + 1)) = + # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1)) + # = loc * (sqrt(2 / w + 1) - 1) + # The term sqrt(2 / w + 1) + 1 no longer presents numerical + # difficulties for large w, and sqrt(2 / w + 1) - 1 is just + # sqrt1pm1(2 / w), which we know how to compute accurately. + # This just leaves the matter of small w, where 2 / w may + # overflow. In the limit a w -> 0, x -> loc, so we just mask + # that case. + sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above + safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0) + denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0) + ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator + sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above + res = jnp.where(sampled_uniform <= mean / (mean + sampled), + sampled, + jnp.square(mean) / sampled) + return _return(res) + + def t(self, df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + df = _check_py_seq(_as_jax_array(df)) + if size is None: + size = np.shape(df) + else: + size = _size2shape(size) + _check_shape("t", size, np.shape(df)) + if key is None: + keys = self.split_keys(2) + else: + keys = jr.split(_formalize_key(key), 2) + n = jr.normal(keys[0], size) + two = _const(n, 2) + half_df = lax.div(df, two) + g = jr.gamma(keys[1], half_df, size) + r = n * jnp.sqrt(half_df / g) + return _return(r) + + def orthogonal(self, n: int, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = 
self.split_key() if key is None else _formalize_key(key) + size = _size2shape(size) + _check_shape("orthogonal", size) + n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") + z = jr.normal(key, size + (n, n)) + q, r = jnp.linalg.qr(z) + d = jnp.diagonal(r, 0, -2, -1) + r = q * jnp.expand_dims(d / abs(d), -2) + return _return(r) + + def noncentral_chisquare(self, df, nonc, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + df = _check_py_seq(_as_jax_array(df)) + nonc = _check_py_seq(_as_jax_array(nonc)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc)) + size = _size2shape(size) + if key is None: + keys = self.split_keys(3) + else: + keys = jr.split(_formalize_key(key), 3) + i = jr.poisson(keys[0], 0.5 * nonc, shape=size) + n = jr.normal(keys[1], shape=size) + jnp.sqrt(nonc) + cond = jnp.greater(df, 1.0) + df2 = jnp.where(cond, df - 1.0, df + 2.0 * i) + chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size) + r = jnp.where(cond, chi2 + n * n, chi2) + return _return(r) + + def loggamma(self, a, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + a = _check_py_seq(_as_jax_array(a)) + if size is None: + size = jnp.shape(a) + r = jr.loggamma(key, a, shape=_size2shape(size)) + return _return(r) + + def categorical(self, logits, axis: int = -1, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + key = self.split_key() if key is None else _formalize_key(key) + logits = _check_py_seq(_as_jax_array(logits)) + if size is None: + size = list(jnp.shape(logits)) + size.pop(axis) + r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size)) + return _return(r) + + def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + a = 
_check_py_seq(_as_jax_array(a)) + if size is None: + size = jnp.shape(a) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + a) + return _return(r) + + def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + a = _check_py_seq(_as_jax_array(a)) + if size is None: + size = jnp.shape(a) + size = _size2shape(size) + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) + r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + a) + return _return(r) + + def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + dfnum = _as_jax_array(dfnum) + dfden = _as_jax_array(dfden) + dfnum = _check_py_seq(dfnum) + dfden = _check_py_seq(dfden) + if size is None: + size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) + size = _size2shape(size) + d = {'dfnum': dfnum, 'dfden': dfden} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) + r = jax.pure_callback(lambda x: np.random.f(dfnum=x['dfnum'], + dfden=x['dfden'], + size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + d) + return _return(r) + + def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + ngood = _check_py_seq(_as_jax_array(ngood)) + nbad = _check_py_seq(_as_jax_array(nbad)) + nsample = _check_py_seq(_as_jax_array(nsample)) + + if size is None: + size = lax.broadcast_shapes(jnp.shape(ngood), + jnp.shape(nbad), + jnp.shape(nsample)) + size = _size2shape(size) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} + r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'], + nbad=d['nbad'], + nsample=d['nsample'], + size=size).astype(dtype), + 
jax.ShapeDtypeStruct(size, dtype), + d) + return _return(r) + + def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + p = _check_py_seq(_as_jax_array(p)) + if size is None: + size = jnp.shape(p) + size = _size2shape(size) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + p) + return _return(r) + + def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + dfnum = _check_py_seq(_as_jax_array(dfnum)) + dfden = _check_py_seq(_as_jax_array(dfden)) + nonc = _check_py_seq(_as_jax_array(nonc)) + if size is None: + size = lax.broadcast_shapes(jnp.shape(dfnum), + jnp.shape(dfden), + jnp.shape(nonc)) + size = _size2shape(size) + d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) + r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], + dfden=x['dfden'], + nonc=x['nonc'], + size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + d) + return _return(r) + + # PyTorch compatibility # + # --------------------- # + + def rand_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + """Returns a tensor with the same size as input that is filled with random + numbers from a uniform distribution on the interval ``[0, 1)``. + + Args: + input: the ``size`` of input will determine size of the output tensor. + dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. + key: the seed or key for the random. + + Returns: + The random data. 
+ """ + return self.random(shape(input), key=key).astype(dtype) + + def randn_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + """Returns a tensor with the same size as ``input`` that is filled with + random numbers from a normal distribution with mean 0 and variance 1. + + Args: + input: the ``size`` of input will determine size of the output tensor. + dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. + key: the seed or key for the random. + + Returns: + The random data. + """ + return self.randn(*shape(input), key=key).astype(dtype) + + def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + if high is None: + high = max(input) + return self.randint(low, high=high, size=shape(input), dtype=dtype, key=key) - seed : int, ArrayType, optional - Same as `seed_or_key`. - .. deprecated:: 2.2.3.4 - Will be removed since version 2.4. - """ - if seed is not None: - if seed_or_key is not None: - raise ValueError('Please set "seed_or_key" or "seed", not both.') - seed_or_key = seed - warnings.warn('Please use seed_or_key instead. ' - 'seed will be removed since 2.4.0', UserWarning) +# alias +Generator = RandomState - if seed_or_key is None: - seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32) - if isinstance(seed_or_key, int): - key = jr.PRNGKey(seed_or_key) - else: - if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32: - raise ValueError('key must be an array with dtype uint32. 
' - f'But we got {seed_or_key}') - key = seed_or_key - self._value = key - - - def _ensure_value_exists(self): - """Ensure that the random state has a valid value, regenerate if needed.""" - if not isinstance(self._value, np.ndarray): - with jax.ensure_compile_time_eval(): - if not isinstance(self._value, jax.core.Tracer): - if self._value.is_deleted(): - seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32) - self._value = seed_or_key +# default random generator +__a = Array(None) +__a._value = np.random.randint(0, 10000, size=2, dtype=np.uint32) +DEFAULT = RandomState(__a) +del __a - @property - def value(self): - self._ensure_value_exists() - record_state_value_read(self) - return self._read_value() +def split_key(): + """Create a new seed from the current seed. - # def check_if_deleted(self): - # if self.value.is_deleted(): - # self.seed() + This function is useful for the consistency with JAX's random paradigm.""" + return DEFAULT.split_key() - def split_key(self): - """Create a new seed from the current seed. - """ - if not isinstance(self.value, jnp.ndarray): - self._value = jnp.asarray(self.value) - keys = jr.split(self.value, num=2) - self._value = keys[0] - return keys[1] - def split_keys(self, n): +def split_keys(n): """Create multiple seeds from the current seed. This is used internally by `pmap` and `vmap` to ensure that random numbers are different in parallel threads. + .. versionadded:: 2.4.5 + Parameters:: - + n : int The number of seeds to generate. """ - keys = jr.split(self.value, n + 1) - self._value = keys[0] - return keys[1:] - - # ---------------- # - # random functions # - # ---------------- # - - def rand(self, *dn, key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - r = jr.uniform(key, shape=dn, minval=0., maxval=1.) 
- return _return(r) - - def randint(self, - low, - high=None, - size: Optional[Union[int, Sequence[int]]] = None, - dtype=int, key: Optional[Union[int, JAX_RAND_KEY]] = None): - dtype = get_int() if dtype is None else dtype - low = _as_jax_array(low) - high = _as_jax_array(high) - if high is None: - high = low - low = 0 - high = _check_py_seq(high) - low = _check_py_seq(low) - if size is None: - size = lax.broadcast_shapes(jnp.shape(low), - jnp.shape(high)) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.randint(key, - shape=_size2shape(size), - minval=low, maxval=high, dtype=dtype) - return _return(r) - - def random_integers(self, - low, - high=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - low = _as_jax_array(low) - high = _as_jax_array(high) - low = _check_py_seq(low) - high = _check_py_seq(high) - if high is None: - high = low - low = 1 - high += 1 - if size is None: - size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.randint(key, - shape=_size2shape(size), - minval=low, - maxval=high) - return _return(r) - - def randn(self, *dn, key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - r = jr.normal(key, shape=dn) - return _return(r) - - def random(self, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - r = jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1.) 
- return _return(r) + return DEFAULT.split_keys(n) - def random_sample(self, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r = self.random(size=size, key=key) - return _return(r) - def ranf(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r = self.random(size=size, key=key) - return _return(r) +def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState: + """Clone the random state according to the given setting. - def sample(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r = self.random(size=size, key=key) - return _return(r) + Args: + seed_or_key: The seed (an integer) or the random key. + clone: Bool. Whether clone the default random state. - def choice(self, a, size: Optional[Union[int, Sequence[int]]] = None, replace=True, p=None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - a = _as_jax_array(a) - p = _as_jax_array(p) - a = _check_py_seq(a) - p = _check_py_seq(p) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p) - return _return(r) - - def permutation(self, x, axis: int = 0, independent: bool = False, key: Optional[Union[int, JAX_RAND_KEY]] = None): - x = x.value if isinstance(x, BaseArray) else x - x = _check_py_seq(x) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.permutation(key, x, axis=axis, independent=independent) - return _return(r) - - def shuffle(self, x, axis=0, key: Optional[Union[int, JAX_RAND_KEY]] = None): - if not isinstance(x, BaseArray): - raise TypeError('This numpy operator needs in-place updating, therefore ' - 'inputs should be brainpy Array.') - key = self.split_key() if key is None else _formalize_key(key) - x.value = jr.permutation(key, x.value, axis=axis) - - def beta(self, a, b, - size: Optional[Union[int, Sequence[int]]] = None, - 
key: Optional[Union[int, JAX_RAND_KEY]] = None): - a = a.value if isinstance(a, BaseArray) else a - b = b.value if isinstance(b, BaseArray) else b - a = _check_py_seq(a) - b = _check_py_seq(b) - if size is None: - size = lax.broadcast_shapes(jnp.shape(a), jnp.shape(b)) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.beta(key, a=a, b=b, shape=_size2shape(size)) - return _return(r) + Returns: + The random state. + """ + if seed_or_key is None: + return DEFAULT.clone() if clone else DEFAULT + else: + return RandomState(seed_or_key) - def exponential(self, scale=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - scale = _as_jax_array(scale) - scale = _check_py_seq(scale) - if size is None: - size = jnp.shape(scale) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.exponential(key, shape=_size2shape(size)) - if scale is not None: - r = r / scale - return _return(r) - - def gamma(self, shape, scale=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - shape = _as_jax_array(shape) - scale = _as_jax_array(scale) - shape = _check_py_seq(shape) - scale = _check_py_seq(scale) - if size is None: - size = lax.broadcast_shapes(jnp.shape(shape), jnp.shape(scale)) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.gamma(key, a=shape, shape=_size2shape(size)) - if scale is not None: - r = r * scale - return _return(r) - - def gumbel(self, loc=None, scale=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - loc = _as_jax_array(loc) - scale = _as_jax_array(scale) - loc = _check_py_seq(loc) - scale = _check_py_seq(scale) - if size is None: - size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - key = self.split_key() if key is None else _formalize_key(key) - r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size))) - 
return _return(r) - def laplace(self, loc=None, scale=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - loc = _as_jax_array(loc) - scale = _as_jax_array(scale) - loc = _check_py_seq(loc) - scale = _check_py_seq(scale) - if size is None: - size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - key = self.split_key() if key is None else _formalize_key(key) - r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size))) - return _return(r) +def default_rng(seed_or_key=None, clone: bool = True) -> RandomState: + if seed_or_key is None: + return DEFAULT.clone() if clone else DEFAULT + else: + return RandomState(seed_or_key) - def logistic(self, loc=None, scale=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - loc = _as_jax_array(loc) - scale = _as_jax_array(scale) - loc = _check_py_seq(loc) - scale = _check_py_seq(scale) - if size is None: - size = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)) - key = self.split_key() if key is None else _formalize_key(key) - r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size))) - return _return(r) - def normal(self, loc=None, scale=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - loc = _as_jax_array(loc) - scale = _as_jax_array(scale) - loc = _check_py_seq(loc) - scale = _check_py_seq(scale) - if size is None: - size = lax.broadcast_shapes(jnp.shape(scale), jnp.shape(loc)) - key = self.split_key() if key is None else _formalize_key(key) - r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size))) - return _return(r) +def seed(seed: int = None): + """Sets a new random seed. 
- def pareto(self, a, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - a = _as_jax_array(a) - a = _check_py_seq(a) - if size is None: - size = jnp.shape(a) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.pareto(key, b=a, shape=_size2shape(size)) - return _return(r) + Parameters:: - def poisson(self, lam=1.0, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - lam = _check_py_seq(_as_jax_array(lam)) - if size is None: - size = jnp.shape(lam) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.poisson(key, lam=lam, shape=_size2shape(size)) - return _return(r) + seed: int, optional + The random seed. + """ + with jax.ensure_compile_time_eval(): + if seed is None: + seed = np.random.randint(0, 100000) + np.random.seed(seed) + DEFAULT.seed(seed) - def standard_cauchy(self, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - r = jr.cauchy(key, shape=_size2shape(size)) - return _return(r) - - def standard_exponential(self, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - r = jr.exponential(key, shape=_size2shape(size)) - return _return(r) - - def standard_gamma(self, - shape, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - shape = _as_jax_array(shape) - shape = _check_py_seq(shape) - if size is None: - size = jnp.shape(shape) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.gamma(key, a=shape, shape=_size2shape(size)) - return _return(r) - def standard_normal(self, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is 
None else _formalize_key(key) - r = jr.normal(key, shape=_size2shape(size)) - return _return(r) +def rand(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r"""Random values in a given shape. - def standard_t(self, df, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - df = _as_jax_array(df) - df = _check_py_seq(df) - if size is None: - size = jnp.shape(size) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.t(key, df=df, shape=_size2shape(size)) - return _return(r) + .. note:: + This is a convenience function for users porting code from Matlab, + and wraps `random_sample`. That function takes a + tuple to specify the size of the output, which is consistent with + other NumPy functions like `numpy.zeros` and `numpy.ones`. - def uniform(self, low=0.0, high=1.0, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - low = _as_jax_array(low) - high = _as_jax_array(high) - low = _check_py_seq(low) - high = _check_py_seq(high) - if size is None: - size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high)) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high) - return _return(r) - - def __norm_cdf(self, x, sqrt2, dtype): - # Computes standard normal cumulative distribution function - return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype) - - def truncated_normal(self, - lower, - upper, - size: Optional[Union[int, Sequence[int]]] = None, - loc=0., - scale=1., - dtype=float, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - lower = _check_py_seq(_as_jax_array(lower)) - upper = _check_py_seq(_as_jax_array(upper)) - loc = _check_py_seq(_as_jax_array(loc)) - scale = _check_py_seq(_as_jax_array(scale)) - - lower = lax.convert_element_type(lower, dtype) - upper = lax.convert_element_type(upper, dtype) - loc = lax.convert_element_type(loc, 
dtype) - scale = lax.convert_element_type(scale, dtype) - - jit_error_checking_no_args( - jnp.any(jnp.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)), - ValueError("mean is more than 2 std from [lower, upper] in truncated_normal. " - "The distribution of values may be incorrect.") - ) + Create an array of the given shape and populate it with + random samples from a uniform distribution + over ``[0, 1)``. - if size is None: - size = lax.broadcast_shapes(jnp.shape(lower), - jnp.shape(upper), - jnp.shape(loc), - jnp.shape(scale)) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - sqrt2 = np.array(np.sqrt(2), dtype) - l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype) - u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - key = self.split_key() if key is None else _formalize_key(key) - out = jr.uniform(key, size, dtype, - minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)), - maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - out = lax.erf_inv(out) - - # Transform to proper mean, std - out = out * scale * sqrt2 + loc - - # Clamp to ensure it's in the proper range - out = jnp.clip(out, - lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)), - lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))) - return _return(out) - - def _check_p(self, p): - raise ValueError(f'Parameter p should be within [0, 1], but we got {p}') - - def bernoulli(self, p, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - p = _check_py_seq(_as_jax_array(p)) - jit_error_checking(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) - if size is None: - size = 
jnp.shape(p) - key = self.split_key() if key is None else _formalize_key(key) - r = jr.bernoulli(key, p=p, shape=_size2shape(size)) - return _return(r) + Parameters:: - def lognormal(self, mean=None, sigma=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - mean = _check_py_seq(_as_jax_array(mean)) - sigma = _check_py_seq(_as_jax_array(sigma)) - if size is None: - size = jnp.broadcast_shapes(jnp.shape(mean), - jnp.shape(sigma)) - key = self.split_key() if key is None else _formalize_key(key) - samples = jr.normal(key, shape=_size2shape(size)) - samples = _loc_scale(mean, sigma, samples) - samples = jnp.exp(samples) - return _return(samples) - - def binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - n = _check_py_seq(n.value if isinstance(n, BaseArray) else n) - p = _check_py_seq(p.value if isinstance(p, BaseArray) else p) - jit_error_checking(jnp.any(jnp.logical_and(p < 0, p > 1)), self._check_p, p) - if size is None: - size = jnp.broadcast_shapes(jnp.shape(n), jnp.shape(p)) - key = self.split_key() if key is None else _formalize_key(key) - r = _binomial(key, p, n, shape=_size2shape(size)) - return _return(r) + d0, d1, ..., dn : int, optional + The dimensions of the returned array, must be non-negative. + If no argument is given a single Python float is returned. 
- def chisquare(self, df, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - df = _check_py_seq(_as_jax_array(df)) - key = self.split_key() if key is None else _formalize_key(key) - if size is None: - if jnp.ndim(df) == 0: - dist = jr.normal(key, (df,)) ** 2 - dist = dist.sum() - else: - raise NotImplementedError('Do not support non-scale "df" when "size" is None') - else: - dist = jr.normal(key, (df,) + _size2shape(size)) ** 2 - dist = dist.sum(axis=0) - return _return(dist) + Returns:: - def dirichlet(self, alpha, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - alpha = _check_py_seq(_as_jax_array(alpha)) - r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size)) - return _return(r) + out : ndarray, shape ``(d0, d1, ..., dn)`` + Random values. - def geometric(self, p, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - p = _as_jax_array(p) - p = _check_py_seq(p) - if size is None: - size = jnp.shape(p) - key = self.split_key() if key is None else _formalize_key(key) - u = jr.uniform(key, size) - r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p)) - return _return(r) + See Also:: - def _check_p2(self, p): - raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. 
But we got {p}') + random - def multinomial(self, n, pvals, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - n = _check_py_seq(_as_jax_array(n)) - pvals = _check_py_seq(_as_jax_array(pvals)) - jit_error_checking(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals) - if isinstance(n, jax.core.Tracer): - raise ValueError("The total count parameter `n` should not be a jax abstract array.") - size = _size2shape(size) - n_max = int(np.max(jax.device_get(n))) - batch_shape = lax.broadcast_shapes(jnp.shape(pvals)[:-1], jnp.shape(n)) - r = _multinomial(key, pvals, n, n_max, batch_shape + size) - return _return(r) - - def multivariate_normal(self, mean, cov, size: Optional[Union[int, Sequence[int]]] = None, method: str = 'cholesky', - key: Optional[Union[int, JAX_RAND_KEY]] = None): - if method not in {'svd', 'eigh', 'cholesky'}: - raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}") - mean = _check_py_seq(_as_jax_array(mean)) - cov = _check_py_seq(_as_jax_array(cov)) - key = self.split_key() if key is None else _formalize_key(key) - - if not jnp.ndim(mean) >= 1: - raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}") - if not jnp.ndim(cov) >= 2: - raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}") - n = mean.shape[-1] - if jnp.shape(cov)[-2:] != (n, n): - raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, " - f"but got cov.shape == {jnp.shape(cov)}.") - if size is None: - size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) - else: - size = _size2shape(size) - _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2]) - - if method == 'svd': - (u, s, _) = jnp.linalg.svd(cov) - factor = u * jnp.sqrt(s[..., None, :]) - elif method == 'eigh': - (w, v) = jnp.linalg.eigh(cov) - factor = v * jnp.sqrt(w[..., None, 
:]) - else: # 'cholesky' - factor = jnp.linalg.cholesky(cov) - normal_samples = jr.normal(key, size + mean.shape[-1:]) - r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples) - return _return(r) - - def rayleigh(self, scale=1.0, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - scale = _check_py_seq(_as_jax_array(scale)) - if size is None: - size = jnp.shape(scale) - key = self.split_key() if key is None else _formalize_key(key) - x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), minval=0, maxval=1))) - r = x * scale - return _return(r) + Examples:: - def triangular(self, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size)) - r = 2 * bernoulli_samples - 1 - return _return(r) + >>> brainpy.math.random.rand(3,2) + array([[ 0.14022471, 0.96360618], #random + [ 0.37601032, 0.25528411], #random + [ 0.49313049, 0.94909878]]) #random + """ + return DEFAULT.rand(*dn, key=key) - def vonmises(self, mu, kappa, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - mu = _check_py_seq(_as_jax_array(mu)) - kappa = _check_py_seq(_as_jax_array(kappa)) - if size is None: - size = lax.broadcast_shapes(jnp.shape(mu), jnp.shape(kappa)) - size = _size2shape(size) - samples = _von_mises_centered(key, kappa, size) - samples = samples + mu - samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi - return _return(samples) - - def weibull(self, a, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - a = _check_py_seq(_as_jax_array(a)) - if size is None: - size = jnp.shape(a) - else: - if jnp.size(a) > 1: 
- raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}') - size = _size2shape(size) - random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1) - r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) - return _return(r) - - def weibull_min(self, a, scale=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Sample from a Weibull minimum distribution. + +def randint(low, high=None, size: Optional[Union[int, Sequence[int]]] = None, dtype=int, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r"""Return random integers from `low` (inclusive) to `high` (exclusive). + + Return random integers from the "discrete uniform" distribution of + the specified dtype in the "half-open" interval [`low`, `high`). If + `high` is None (the default), then results are from [0, `low`). Parameters:: - - a: float, array_like - The concentration parameter of the distribution. - scale: float, array_like - The scale parameter of the distribution. - size: optional, int, tuple of int - The shape added to the parameters loc and scale broadcastable shape. + + low : int or array-like of ints + Lowest (signed) integers to be drawn from the distribution (unless + ``high=None``, in which case this parameter is one above the + *highest* such integer). + high : int or array-like of ints, optional + If provided, one above the largest (signed) integer to be drawn + from the distribution (see above for behavior if ``high=None``). + If array-like, must contain integer values + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + dtype : dtype, optional + Desired dtype of the result. Byteorder must be native. + The default value is int. Returns:: - - out: array_like - The sampling results. 
- """ - key = self.split_key() if key is None else _formalize_key(key) - a = _check_py_seq(_as_jax_array(a)) - scale = _check_py_seq(_as_jax_array(scale)) - if size is None: - size = jnp.broadcast_shapes(jnp.shape(a), jnp.shape(scale)) - else: - if jnp.size(a) > 1: - raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}') - size = _size2shape(size) - random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1) - r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a) - if scale is not None: - r /= scale - return _return(r) - - def maxwell(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - shape = canonicalize_shape(_size2shape(size)) + (3,) - norm_rvs = jr.normal(key=key, shape=shape) - r = jnp.linalg.norm(norm_rvs, axis=-1) - return _return(r) - - def negative_binomial(self, n, p, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - n = _check_py_seq(_as_jax_array(n)) - p = _check_py_seq(_as_jax_array(p)) - if size is None: - size = lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)) - size = _size2shape(size) - logits = jnp.log(p) - jnp.log1p(-p) - if key is None: - keys = self.split_keys(2) - else: - keys = jr.split(_formalize_key(key), 2) - rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0]) - r = self.poisson(lam=rate, key=keys[1]) - return _return(r) - def wald(self, mean, scale, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - mean = _check_py_seq(_as_jax_array(mean)) - scale = _check_py_seq(_as_jax_array(scale)) - if size is None: - size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale)) - size = _size2shape(size) - sampled_chi2 = jnp.square(_as_jax_array(self.randn(*size))) - sampled_uniform = 
_as_jax_array(self.uniform(size=size, key=key)) - # Wikipedia defines an intermediate x with the formula - # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2) - # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration. - # Let us write - # w = loc * y / (2 * conc) - # Then we can extract the common factor in the last two terms to obtain - # x = loc + loc * w * (1 - sqrt(2 / w + 1)) - # Now we see that the Wikipedia formula suffers from catastrphic - # cancellation for large w (e.g., if conc << loc). - # - # Fortunately, we can fix this by multiplying both sides - # by 1 + sqrt(2 / w + 1). We get - # x * (1 + sqrt(2 / w + 1)) = - # = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1)) - # = loc * (sqrt(2 / w + 1) - 1) - # The term sqrt(2 / w + 1) + 1 no longer presents numerical - # difficulties for large w, and sqrt(2 / w + 1) - 1 is just - # sqrt1pm1(2 / w), which we know how to compute accurately. - # This just leaves the matter of small w, where 2 / w may - # overflow. In the limit a w -> 0, x -> loc, so we just mask - # that case. 
- sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2) # 2 / w above - safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0) - denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0) - ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator - sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0) # x above - res = jnp.where(sampled_uniform <= mean / (mean + sampled), - sampled, - jnp.square(mean) / sampled) - return _return(res) - - def t(self, df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - df = _check_py_seq(_as_jax_array(df)) - if size is None: - size = np.shape(df) - else: - size = _size2shape(size) - _check_shape("t", size, np.shape(df)) - if key is None: - keys = self.split_keys(2) - else: - keys = jr.split(_formalize_key(key), 2) - n = jr.normal(keys[0], size) - two = _const(n, 2) - half_df = lax.div(df, two) - g = jr.gamma(keys[1], half_df, size) - r = n * jnp.sqrt(half_df / g) - return _return(r) - - def orthogonal(self, n: int, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - size = _size2shape(size) - _check_shape("orthogonal", size) - n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") - z = jr.normal(key, size + (n, n)) - q, r = jnp.linalg.qr(z) - d = jnp.diagonal(r, 0, -2, -1) - r = q * jnp.expand_dims(d / abs(d), -2) - return _return(r) - - def noncentral_chisquare(self, df, nonc, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - df = _check_py_seq(_as_jax_array(df)) - nonc = _check_py_seq(_as_jax_array(nonc)) - if size is None: - size = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nonc)) - size = _size2shape(size) - if key is None: - keys = self.split_keys(3) - else: - keys = jr.split(_formalize_key(key), 3) - i = jr.poisson(keys[0], 0.5 * nonc, shape=size) - 
n = jr.normal(keys[1], shape=size) + jnp.sqrt(nonc) - cond = jnp.greater(df, 1.0) - df2 = jnp.where(cond, df - 1.0, df + 2.0 * i) - chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size) - r = jnp.where(cond, chi2 + n * n, chi2) - return _return(r) - - def loggamma(self, a, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - a = _check_py_seq(_as_jax_array(a)) - if size is None: - size = jnp.shape(a) - r = jr.loggamma(key, a, shape=_size2shape(size)) - return _return(r) + out : int or ndarray of ints + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. - def categorical(self, logits, axis: int = -1, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - key = self.split_key() if key is None else _formalize_key(key) - logits = _check_py_seq(_as_jax_array(logits)) - if size is None: - size = list(jnp.shape(logits)) - size.pop(axis) - r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size)) - return _return(r) + See Also:: - def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - a = _check_py_seq(_as_jax_array(a)) - if size is None: - size = jnp.shape(a) - dtype = jax.dtypes.canonicalize_dtype(jnp.int_) - r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype), - jax.ShapeDtypeStruct(size, dtype), - a) - return _return(r) - - def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - a = _check_py_seq(_as_jax_array(a)) - if size is None: - size = jnp.shape(a) - size = _size2shape(size) - dtype = jax.dtypes.canonicalize_dtype(jnp.float_) - r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype), - jax.ShapeDtypeStruct(size, dtype), - a) - return _return(r) - - def 
f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - dfnum = _as_jax_array(dfnum) - dfden = _as_jax_array(dfden) - dfnum = _check_py_seq(dfnum) - dfden = _check_py_seq(dfden) - if size is None: - size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) - size = _size2shape(size) - d = {'dfnum': dfnum, 'dfden': dfden} - dtype = jax.dtypes.canonicalize_dtype(jnp.float_) - r = jax.pure_callback(lambda x: np.random.f(dfnum=x['dfnum'], - dfden=x['dfden'], - size=size).astype(dtype), - jax.ShapeDtypeStruct(size, dtype), - d) - return _return(r) - - def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - ngood = _check_py_seq(_as_jax_array(ngood)) - nbad = _check_py_seq(_as_jax_array(nbad)) - nsample = _check_py_seq(_as_jax_array(nsample)) + random_integers : similar to `randint`, only for the closed + interval [`low`, `high`], and 1 is the lowest value if `high` is + omitted. + Generator.integers: which should be used for new code. 
- if size is None: - size = lax.broadcast_shapes(jnp.shape(ngood), - jnp.shape(nbad), - jnp.shape(nsample)) - size = _size2shape(size) - dtype = jax.dtypes.canonicalize_dtype(jnp.int_) - d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'], - nbad=d['nbad'], - nsample=d['nsample'], - size=size).astype(dtype), - jax.ShapeDtypeStruct(size, dtype), - d) - return _return(r) - - def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - p = _check_py_seq(_as_jax_array(p)) - if size is None: - size = jnp.shape(p) - size = _size2shape(size) - dtype = jax.dtypes.canonicalize_dtype(jnp.int_) - r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype), - jax.ShapeDtypeStruct(size, dtype), - p) - return _return(r) - - def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - dfnum = _check_py_seq(_as_jax_array(dfnum)) - dfden = _check_py_seq(_as_jax_array(dfden)) - nonc = _check_py_seq(_as_jax_array(nonc)) - if size is None: - size = lax.broadcast_shapes(jnp.shape(dfnum), - jnp.shape(dfden), - jnp.shape(nonc)) - size = _size2shape(size) - d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} - dtype = jax.dtypes.canonicalize_dtype(jnp.float_) - r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], - dfden=x['dfden'], - nonc=x['nonc'], - size=size).astype(dtype), - jax.ShapeDtypeStruct(size, dtype), - d) - return _return(r) - - # PyTorch compatibility # - # --------------------- # - - def rand_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Returns a tensor with the same size as input that is filled with random - numbers from a uniform distribution on the interval ``[0, 1)``. + Examples:: - Args: - input: the ``size`` of input will determine size of the output tensor. 
- dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. - key: the seed or key for the random. + >>> import brainpy.math as bm + >>> bm.random.randint(2, size=10) + array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random + >>> bm.random.randint(1, size=10) + array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - Returns: - The random data. - """ - return self.random(shape(input), key=key).astype(dtype) + Generate a 2 x 4 array of ints between 0 and 4, inclusive: - def randn_like(self, input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Returns a tensor with the same size as ``input`` that is filled with - random numbers from a normal distribution with mean 0 and variance 1. + >>> bm.random.randint(5, size=(2, 4)) + array([[4, 0, 2, 1], # random + [3, 2, 2, 0]]) - Args: - input: the ``size`` of input will determine size of the output tensor. - dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. - key: the seed or key for the random. + Generate a 1 x 3 array with 3 different upper bounds + + >>> bm.random.randint(1, [3, 5, 10]) + array([2, 2, 9]) # random + + Generate a 1 by 3 array with 3 different lower bounds + + >>> bm.random.randint([1, 5, 7], 10) + array([9, 8, 7]) # random + + Generate a 2 by 4 array using broadcasting with dtype of uint8 + + >>> bm.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8) + array([[ 8, 6, 9, 7], # random + [ 1, 16, 9, 12]], dtype=uint8) + """ + + return DEFAULT.randint(low, high=high, size=size, dtype=dtype, key=key) + + +def random_integers(low, + high=None, + size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Random integers of type `np.int_` between `low` and `high`, inclusive. + + Return random integers of type `np.int_` from the "discrete uniform" + distribution in the closed interval [`low`, `high`]. 
If `high` is + None (the default), then results are from [1, `low`]. The `np.int_` + type translates to the C long integer type and its precision + is platform dependent. + + Parameters:: + + low : int + Lowest (signed) integer to be drawn from the distribution (unless + ``high=None``, in which case this parameter is the *highest* such + integer). + high : int, optional + If provided, the largest (signed) integer to be drawn from the + distribution (see above for behavior if ``high=None``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + + Returns:: + + out : int or ndarray of ints + `size`-shaped array of random integers from the appropriate + distribution, or a single such random int if `size` not provided. + + See Also:: + + randint : Similar to `random_integers`, only for the half-open + interval [`low`, `high`), and 0 is the lowest value if `high` is + omitted. + + Notes:: + + To sample from N evenly spaced floating-point numbers between a and b, + use:: + + a + (b - a) * (bm.random.random_integers(N) - 1) / (N - 1.) + + Examples:: + + >>> import brainpy.math as bm + >>> bm.random.random_integers(5) + 4 # random + >>> type(bm.random.random_integers(5)) + + >>> bm.random.random_integers(5, size=(3,2)) + array([[5, 4], # random + [3, 3], + [4, 5]]) + + Choose five random numbers from the set of five evenly-spaced + numbers between 0 and 2.5, inclusive (*i.e.*, from the set + :math:`{0, 5/8, 10/8, 15/8, 20/8}`): + + >>> 2.5 * (bm.random.random_integers(5, size=(5,)) - 1) / 4. 
+ array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random + + Roll two six sided dice 1000 times and sum the results: + + >>> d1 = bm.random.random_integers(1, 6, 1000) + >>> d2 = bm.random.random_integers(1, 6, 1000) + >>> dsums = d1 + d2 + + Display results as a histogram: + + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(dsums, 11, density=True) + >>> plt.show() + """ + + return DEFAULT.random_integers(low, high=high, size=size, key=key) + + +def randn(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Return a sample (or samples) from the "standard normal" distribution. + + .. note:: + This is a convenience function for users porting code from Matlab, + and wraps `standard_normal`. That function takes a + tuple to specify the size of the output, which is consistent with + other NumPy functions like `numpy.zeros` and `numpy.ones`. + + .. note:: + New code should use the ``standard_normal`` method of a ``default_rng()`` + instance instead; please see the :ref:`random-quick-start`. + + If positive int_like arguments are provided, `randn` generates an array + of shape ``(d0, d1, ..., dn)``, filled + with random floats sampled from a univariate "normal" (Gaussian) + distribution of mean 0 and variance 1. A single float randomly sampled + from the distribution is returned if no argument is provided. + + Parameters:: + + d0, d1, ..., dn : int, optional + The dimensions of the returned array, must be non-negative. + If no argument is given a single Python float is returned. + + Returns:: + + Z : ndarray or float + A ``(d0, d1, ..., dn)``-shaped array of floating-point samples from + the standard normal distribution, or a single such float if + no parameters were supplied. + + See Also:: + + standard_normal : Similar, but takes a tuple as its argument. + normal : Also accepts mu and sigma arguments. + + Notes:: + + For random samples from :math:`N(\mu, \sigma^2)`, use: + + ``sigma * bm.random.randn(...) 
+ mu`` + + Examples:: + + >>> import brainpy.math as bm + >>> bm.random.randn() + 2.1923875335537315 # random + + Two-by-four array of samples from N(3, 6.25): + + >>> 3 + 2.5 * bm.random.randn(2, 4) + array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random + [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random + """ + + return DEFAULT.randn(*dn, key=key) + + +def random(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Return random floats in the half-open interval [0.0, 1.0). Alias for + `random_sample` to ease forward-porting to the new random API. + """ + return DEFAULT.random(size, key=key) + + +def random_sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Return random floats in the half-open interval [0.0, 1.0). + + Results are from the "continuous uniform" distribution over the + stated interval. To sample :math:`Unif[a, b), b > a` multiply + the output of `random_sample` by `(b-a)` and add `a`:: + + (b - a) * random_sample() + a + + .. note:: + New code should use the ``random`` method of a ``default_rng()`` + instance instead; please see the :ref:`random-quick-start`. + + Parameters:: + + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + + Returns:: + + out : float or ndarray of floats + Array of random floats of shape `size` (unless ``size=None``, in which + case a single float is returned). + + See Also:: + + Generator.random: which should be used for new code. 
+ + Examples:: + + >>> import brainpy.math as bm + >>> bm.random.random_sample() + 0.47108547995356098 # random + >>> type(bm.random.random_sample()) + + >>> bm.random.random_sample((5,)) + array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random + + Three-by-two array of random numbers from [-5, 0): + + >>> 5 * bm.random.random_sample((3, 2)) - 5 + array([[-3.99149989, -0.52338984], # random + [-2.99091858, -0.79479508], + [-1.23204345, -1.75224494]]) + """ + return DEFAULT.random_sample(size, key=key) + + +def ranf(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + This is an alias of `random_sample`. See `random_sample` for the complete + documentation. + """ + return DEFAULT.ranf(size, key=key) + + +def sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + """ + This is an alias of `random_sample`. See `random_sample` for the complete + documentation. + """ + return DEFAULT.sample(size, key=key) + + +def choice(a, size: Optional[Union[int, Sequence[int]]] = None, replace=True, p=None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Generates a random sample from a given 1-D array + + Parameters:: + + a : 1-D array-like or int + If an ndarray, a random sample is generated from its elements. + If an int, the random sample is generated as if it were ``np.arange(a)`` + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + replace : boolean, optional + Whether the sample is with or without replacement. Default is True, + meaning that a value of ``a`` can be selected multiple times. + p : 1-D array-like, optional + The probabilities associated with each entry in a. + If not given, the sample assumes a uniform distribution over all + entries in ``a``. 
+ + Returns:: + + samples : single item or ndarray + The generated random samples + + Raises:: + + ValueError + If a is an int and less than zero, if a or p are not 1-dimensional, + if a is an array-like of size 0, if p is not a vector of + probabilities, if a and p have different lengths, or if + replace=False and the sample size is greater than the population + size + + See Also:: + + randint, shuffle, permutation + Generator.choice: which should be used in new code + + Notes:: + + Setting user-specified probabilities through ``p`` uses a more general but less + efficient sampler than the default. The general sampler produces a different sample + than the optimized sampler even if each element of ``p`` is 1 / len(a). + + Sampling random rows from a 2-D array is not possible with this function, + but is possible with `Generator.choice` through its ``axis`` keyword. + + Examples:: + + Generate a uniform random sample from np.arange(5) of size 3: + + >>> import brainpy.math as bm + >>> bm.random.choice(5, 3) + array([0, 3, 4]) # random + >>> #This is equivalent to brainpy.math.random.randint(0,5,3) + + Generate a non-uniform random sample from np.arange(5) of size 3: + + >>> bm.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0]) + array([3, 3, 0]) # random + + Generate a uniform random sample from np.arange(5) of size 3 without + replacement: + + >>> bm.random.choice(5, 3, replace=False) + array([3,1,0]) # random + >>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3] + + Generate a non-uniform random sample from np.arange(5) of size + 3 without replacement: + + >>> bm.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0]) + array([2, 3, 0]) # random + + Any of the above can be repeated with an arbitrary array-like + instead of just integers. 
For instance: + + >>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher'] + >>> bm.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3]) + array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random + dtype='>> import brainpy.math as bm + >>> bm.random.permutation(10) + array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random + + >>> bm.random.permutation([1, 4, 9, 12, 15]) + array([15, 1, 9, 4, 12]) # random + + >>> arr = np.arange(9).reshape((3, 3)) + >>> bm.random.permutation(arr) + array([[6, 7, 8], # random + [0, 1, 2], + [3, 4, 5]]) + """ + return DEFAULT.permutation(x, axis=axis, independent=independent, key=key) + + +def shuffle(x, axis=0, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Modify a sequence in-place by shuffling its contents. + + This function only shuffles the array along the first axis of a + multi-dimensional array. The order of sub-arrays is changed but + their contents remains the same. + + Parameters:: + + x : ndarray or MutableSequence + The array, list or mutable sequence to be shuffled. + + Returns:: + + None + + Examples:: + + >>> import brainpy.math as bm + >>> arr = np.arange(10) + >>> bm.random.shuffle(arr) + >>> arr + [1 7 5 2 9 4 3 6 0 8] # random + + Multi-dimensional arrays are only shuffled along the first axis: + + >>> arr = np.arange(9).reshape((3, 3)) + >>> bm.random.shuffle(arr) + >>> arr + array([[3, 4, 5], # random + [6, 7, 8], + [0, 1, 2]]) + """ + DEFAULT.shuffle(x, axis, key=key) + + +def beta(a, b, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a Beta distribution. + + The Beta distribution is a special case of the Dirichlet distribution, + and is related to the Gamma distribution. It has the probability + distribution function + + .. math:: f(x; a,b) = \frac{1}{B(\alpha, \beta)} x^{\alpha - 1} + (1 - x)^{\beta - 1}, + + where the normalization, B, is the beta function, + + .. 
math:: B(\alpha, \beta) = \int_0^1 t^{\alpha - 1} + (1 - t)^{\beta - 1} dt. + + It is often seen in Bayesian inference and order statistics. + + Parameters:: + + a : float or array_like of floats + Alpha, positive (>0). + b : float or array_like of floats + Beta, positive (>0). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` and ``b`` are both scalars. + Otherwise, ``np.broadcast(a, b).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized beta distribution. + """ + return DEFAULT.beta(a, b, size=size, key=key) + + +def exponential(scale=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from an exponential distribution. + + Its probability density function is + + .. math:: f(x; \frac{1}{\beta}) = \frac{1}{\beta} \exp(-\frac{x}{\beta}), + + for ``x > 0`` and 0 elsewhere. :math:`\beta` is the scale parameter, + which is the inverse of the rate parameter :math:`\lambda = 1/\beta`. + The rate parameter is an alternative, widely used parameterization + of the exponential distribution [3]_. + + The exponential distribution is a continuous analogue of the + geometric distribution. It describes many common situations, such as + the size of raindrops measured over many rainstorms [1]_, or the time + between page requests to Wikipedia [2]_. + + Parameters:: + + scale : float or array_like of floats + The scale parameter, :math:`\beta = 1/\lambda`. Must be + non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``scale`` is a scalar. Otherwise, + ``np.array(scale).size`` samples are drawn. 
+ + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized exponential distribution. + + References:: + + .. [1] Peyton Z. Peebles Jr., "Probability, Random Variables and + Random Signal Principles", 4th ed, 2001, p. 57. + .. [2] Wikipedia, "Poisson process", + https://en.wikipedia.org/wiki/Poisson_process + .. [3] Wikipedia, "Exponential distribution", + https://en.wikipedia.org/wiki/Exponential_distribution + """ + return DEFAULT.exponential(scale, size, key=key) + + +def gamma(shape, scale=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a Gamma distribution. + + Samples are drawn from a Gamma distribution with specified parameters, + `shape` (sometimes designated "k") and `scale` (sometimes designated + "theta"), where both parameters are > 0. + + Parameters:: + + shape : float or array_like of floats + The shape of the gamma distribution. Must be non-negative. + scale : float or array_like of floats, optional + The scale of the gamma distribution. Must be non-negative. + Default is equal to 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``shape`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized gamma distribution. + + + Notes:: + + The probability density for the Gamma distribution is + + .. math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)}, + + where :math:`k` is the shape and :math:`\theta` the scale, + and :math:`\Gamma` is the Gamma function. + + The Gamma distribution is often used to model the times to failure of + electronic components, and arises naturally in processes for which the + waiting times between Poisson distributed events are relevant. 
+ + References:: + + .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A + Wolfram Web Resource. + http://mathworld.wolfram.com/GammaDistribution.html + .. [2] Wikipedia, "Gamma distribution", + https://en.wikipedia.org/wiki/Gamma_distribution + + """ + return DEFAULT.gamma(shape, scale, size=size, key=key) + + +def gumbel(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a Gumbel distribution. + + Draw samples from a Gumbel distribution with specified location and + scale. For more information on the Gumbel distribution, see + Notes and References below. + + Parameters:: + + loc : float or array_like of floats, optional + The location of the mode of the distribution. Default is 0. + scale : float or array_like of floats, optional + The scale parameter of the distribution. Default is 1. Must be non- + negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``loc`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized Gumbel distribution. + + Notes:: + + The Gumbel (or Smallest Extreme Value (SEV) or the Smallest Extreme + Value Type I) distribution is one of a class of Generalized Extreme + Value (GEV) distributions used in modeling extreme value problems. + The Gumbel is a special case of the Extreme Value Type I distribution + for maximums from distributions with "exponential-like" tails. + + The probability density for the Gumbel distribution is + + .. math:: p(x) = \frac{e^{-(x - \mu)/ \beta}}{\beta} e^{ -e^{-(x - \mu)/ + \beta}}, + + where :math:`\mu` is the mode, a location parameter, and + :math:`\beta` is the scale parameter. 
+ + The Gumbel (named for German mathematician Emil Julius Gumbel) was used + very early in the hydrology literature, for modeling the occurrence of + flood events. It is also used for modeling maximum wind speed and + rainfall rates. It is a "fat-tailed" distribution - the probability of + an event in the tail of the distribution is larger than if one used a + Gaussian, hence the surprisingly frequent occurrence of 100-year + floods. Floods were initially modeled as a Gaussian process, which + underestimated the frequency of extreme events. + + It is one of a class of extreme value distributions, the Generalized + Extreme Value (GEV) distributions, which also includes the Weibull and + Frechet. + + The function has a mean of :math:`\mu + 0.57721\beta` and a variance + of :math:`\frac{\pi^2}{6}\beta^2`. + + References:: + + .. [1] Gumbel, E. J., "Statistics of Extremes," + New York: Columbia University Press, 1958. + .. [2] Reiss, R.-D. and Thomas, M., "Statistical Analysis of Extreme + Values from Insurance, Finance, Hydrology and Other Fields," + Basel: Birkhauser Verlag, 2001. + """ + return DEFAULT.gumbel(loc, scale, size=size, key=key) + + +def laplace(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from the Laplace or double exponential distribution with + specified location (or mean) and scale (decay). + + The Laplace distribution is similar to the Gaussian/normal distribution, + but is sharper at the peak and has fatter tails. It represents the + difference between two independent, identically distributed exponential + random variables. + + Parameters:: + + loc : float or array_like of floats, optional + The position, :math:`\mu`, of the distribution peak. Default is 0. + scale : float or array_like of floats, optional + :math:`\lambda`, the exponential decay. Default is 1. Must be non- + negative. + size : int or tuple of ints, optional + Output shape. 
If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``loc`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized Laplace distribution. + + Notes:: + + It has the probability density function + + .. math:: f(x; \mu, \lambda) = \frac{1}{2\lambda} + \exp\left(-\frac{|x - \mu|}{\lambda}\right). + + The first law of Laplace, from 1774, states that the frequency + of an error can be expressed as an exponential function of the + absolute magnitude of the error, which leads to the Laplace + distribution. For many problems in economics and health + sciences, this distribution seems to model the data better + than the standard Gaussian distribution. + + References:: + + .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of + Mathematical Functions with Formulas, Graphs, and Mathematical + Tables, 9th printing," New York: Dover, 1972. + .. [2] Kotz, Samuel, et. al. "The Laplace Distribution and + Generalizations, " Birkhauser, 2001. + .. [3] Weisstein, Eric W. "Laplace Distribution." + From MathWorld--A Wolfram Web Resource. + http://mathworld.wolfram.com/LaplaceDistribution.html + .. [4] Wikipedia, "Laplace distribution", + https://en.wikipedia.org/wiki/Laplace_distribution + + Examples:: + + Draw samples from the distribution + + >>> loc, scale = 0., 1. + >>> s = bm.random.laplace(loc, scale, 1000) + + Display the histogram of the samples, along with + the probability density function: + + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s, 30, density=True) + >>> x = np.arange(-8., 8., .01) + >>> pdf = np.exp(-abs(x-loc)/scale)/(2.*scale) + >>> plt.plot(x, pdf) + + Plot Gaussian for comparison: + + >>> g = (1/(scale * np.sqrt(2 * np.pi)) * + ... 
np.exp(-(x - loc)**2 / (2 * scale**2))) + >>> plt.plot(x,g) + """ + return DEFAULT.laplace(loc, scale, size, key=key) + + +def logistic(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a logistic distribution. + + Samples are drawn from a logistic distribution with specified + parameters, loc (location or mean, also median), and scale (>0). + + Parameters:: + + loc : float or array_like of floats, optional + Parameter of the distribution. Default is 0. + scale : float or array_like of floats, optional + Parameter of the distribution. Must be non-negative. + Default is 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``loc`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized logistic distribution. + + Notes:: + + The probability density for the Logistic distribution is + + .. math:: P(x) = \frac{e^{-(x-\mu)/s}}{s(1+e^{-(x-\mu)/s})^2}, + + where :math:`\mu` = location and :math:`s` = scale. + + The Logistic distribution is used in Extreme Value problems where it + can act as a mixture of Gumbel distributions, in Epidemiology, and by + the World Chess Federation (FIDE) where it is used in the Elo ranking + system, assuming the performance of each player is a logistically + distributed random variable. + + References:: + + .. [1] Reiss, R.-D. and Thomas M. (2001), "Statistical Analysis of + Extreme Values, from Insurance, Finance, Hydrology and Other + Fields," Birkhauser Verlag, Basel, pp 132-133. + .. [2] Weisstein, Eric W. "Logistic Distribution." From + MathWorld--A Wolfram Web Resource. + http://mathworld.wolfram.com/LogisticDistribution.html + ..
[3] Wikipedia, "Logistic-distribution", + https://en.wikipedia.org/wiki/Logistic_distribution + + Examples:: + + Draw samples from the distribution: + + >>> loc, scale = 10, 1 + >>> s = bm.random.logistic(loc, scale, 10000) + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s, bins=50) + + # plot against distribution + + >>> def logist(x, loc, scale): + ... return np.exp((loc-x)/scale)/(scale*(1+np.exp((loc-x)/scale))**2) + >>> lgst_val = logist(bins, loc, scale) + >>> plt.plot(bins, lgst_val * count.max() / lgst_val.max()) + >>> plt.show() + """ + return DEFAULT.logistic(loc, scale, size, key=key) + + +def normal(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw random samples from a normal (Gaussian) distribution. + + The probability density function of the normal distribution, first + derived by De Moivre and 200 years later by both Gauss and Laplace + independently [2]_, is often called the bell curve because of + its characteristic shape (see the example below). + + The normal distribution occurs often in nature. For example, it + describes the commonly occurring distribution of samples influenced + by a large number of tiny, random disturbances, each with its own + unique distribution [2]_. + + Parameters:: + + loc : float or array_like of floats + Mean ("centre") of the distribution. + scale : float or array_like of floats + Standard deviation (spread or "width") of the distribution. Must be + non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``loc`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized normal distribution.
+ + Notes:: + + The probability density for the Gaussian distribution is + + .. math:: p(x) = \frac{1}{\sqrt{ 2 \pi \sigma^2 }} + e^{ - \frac{ (x - \mu)^2 } {2 \sigma^2} }, + + where :math:`\mu` is the mean and :math:`\sigma` the standard + deviation. The square of the standard deviation, :math:`\sigma^2`, + is called the variance. + + The function has its peak at the mean, and its "spread" increases with + the standard deviation (the function reaches 0.607 times its maximum at + :math:`\mu + \sigma` and :math:`\mu - \sigma` [2]_). This implies that + normal is more likely to return samples lying close to the mean, rather + than those far away. + + References:: + + .. [1] Wikipedia, "Normal distribution", + https://en.wikipedia.org/wiki/Normal_distribution + .. [2] P. R. Peebles Jr., "Central Limit Theorem" in "Probability, + Random Variables and Random Signal Principles", 4th ed., 2001, + pp. 51, 51, 125. + + Examples:: + + Draw samples from the distribution: + + >>> mu, sigma = 0, 0.1 # mean and standard deviation + >>> s = bm.random.normal(mu, sigma, 1000) + + Verify the mean and the standard deviation: + + >>> abs(mu - np.mean(s)) + 0.0 # may vary + + >>> abs(sigma - np.std(s, ddof=1)) + 0.1 # may vary + + Display the histogram of the samples, along with + the probability density function: + + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s, 30, density=True) + >>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) * + ... np.exp( - (bins - mu)**2 / (2 * sigma**2) ), + ...
linewidth=2, color='r') + >>> plt.show() + + Two-by-four array of samples from the normal distribution with + mean 3 and standard deviation 2.5: + + >>> bm.random.normal(3, 2.5, size=(2, 4)) + array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random + [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random + """ + return DEFAULT.normal(loc, scale, size, key=key) + + +def pareto(a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a Pareto II or Lomax distribution with + specified shape. + + The Lomax or Pareto II distribution is a shifted Pareto + distribution. The classical Pareto distribution can be + obtained from the Lomax distribution by adding 1 and + multiplying by the scale parameter ``m`` (see Notes). The + smallest value of the Lomax distribution is zero while for the + classical Pareto distribution it is ``mu``, where the standard + Pareto distribution has location ``mu = 1``. Lomax can also + be considered as a simplified version of the Generalized + Pareto distribution (available in SciPy), with the scale set + to one and the location set to zero. + + The Pareto distribution must be greater than zero, and is + unbounded above. It is also known as the "80-20 rule". In + this distribution, 80 percent of the weights are in the lowest + 20 percent of the range, while the other 20 percent fill the + remaining 80 percent of the range. + + Parameters:: + + a : float or array_like of floats + Shape of the distribution. Must be positive. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized Pareto distribution. 
+ + See Also:: + + scipy.stats.lomax : probability density function, distribution or + cumulative density function, etc. + scipy.stats.genpareto : probability density function, distribution or + cumulative density function, etc. + + Notes:: + + The probability density for the Pareto distribution is + + .. math:: p(x) = \frac{am^a}{x^{a+1}} + + where :math:`a` is the shape and :math:`m` the scale. + + The Pareto distribution, named after the Italian economist + Vilfredo Pareto, is a power law probability distribution + useful in many real world problems. Outside the field of + economics it is generally referred to as the Bradford + distribution. Pareto developed the distribution to describe + the distribution of wealth in an economy. It has also found + use in insurance, web page access statistics, oil field sizes, + and many other problems, including the download frequency for + projects in Sourceforge [1]_. It is one of the so-called + "fat-tailed" distributions. + + References:: + + .. [1] Francis Hunt and Paul Johnson, On the Pareto Distribution of + Sourceforge projects. + .. [2] Pareto, V. (1896). Course of Political Economy. Lausanne. + .. [3] Reiss, R.D., Thomas, M.(2001), Statistical Analysis of Extreme + Values, Birkhauser Verlag, Basel, pp 23-30. + .. [4] Wikipedia, "Pareto distribution", + https://en.wikipedia.org/wiki/Pareto_distribution + + Examples:: + + Draw samples from the distribution: + + >>> a, m = 3., 2. 
# shape and mode + >>> s = (bm.random.pareto(a, 1000) + 1) * m + + Display the histogram of the samples, along with the probability + density function: + + >>> import matplotlib.pyplot as plt + >>> count, bins, _ = plt.hist(s, 100, density=True) + >>> fit = a*m**a / bins**(a+1) + >>> plt.plot(bins, max(count)*fit/max(fit), linewidth=2, color='r') + >>> plt.show() + """ + return DEFAULT.pareto(a, size, key=key) + + +def poisson(lam=1.0, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a Poisson distribution. + + The Poisson distribution is the limit of the binomial distribution + for large N. + + Parameters:: + + lam : float or array_like of floats + Expected number of events occurring in a fixed-time interval, + must be >= 0. A sequence must be broadcastable over the requested + size. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``lam`` is a scalar. Otherwise, + ``np.array(lam).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized Poisson distribution. + + Notes:: + + The Poisson distribution + + .. math:: f(k; \lambda)=\frac{\lambda^k e^{-\lambda}}{k!} + + For events with an expected separation :math:`\lambda` the Poisson + distribution :math:`f(k; \lambda)` describes the probability of + :math:`k` events occurring within the observed + interval :math:`\lambda`. + + Because the output is limited to the range of the C int64 type, a + ValueError is raised when `lam` is within 10 sigma of the maximum + representable value. + + References:: + + .. [1] Weisstein, Eric W. "Poisson Distribution." + From MathWorld--A Wolfram Web Resource. + http://mathworld.wolfram.com/PoissonDistribution.html + .. 
[2] Wikipedia, "Poisson distribution", + https://en.wikipedia.org/wiki/Poisson_distribution + + Examples:: + + Draw samples from the distribution: + + >>> import numpy as np + >>> s = bm.random.poisson(5, 10000) + + Display histogram of the sample: + + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s, 14, density=True) + >>> plt.show() + + Draw each 100 values for lambda 100 and 500: + + >>> s = bm.random.poisson(lam=(100., 500.), size=(100, 2)) + """ + return DEFAULT.poisson(lam, size, key=key) + + +def standard_cauchy(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a standard Cauchy distribution with mode = 0. + + Also known as the Lorentz distribution. + + Parameters:: + + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + + Returns:: + + samples : ndarray or scalar + The drawn samples. + + Notes:: + + The probability density function for the full Cauchy distribution is + + .. math:: P(x; x_0, \gamma) = \frac{1}{\pi \gamma \bigl[ 1+ + (\frac{x-x_0}{\gamma})^2 \bigr] } + + and the Standard Cauchy distribution just sets :math:`x_0=0` and + :math:`\gamma=1` + + The Cauchy distribution arises in the solution to the driven harmonic + oscillator problem, and also describes spectral line broadening. It + also describes the distribution of values at which a line tilted at + a random angle will cut the x axis. + + When studying hypothesis tests that assume normality, seeing how the + tests perform on data from a Cauchy distribution is a good indicator of + their sensitivity to a heavy-tailed distribution, since the Cauchy looks + very much like a Gaussian distribution, but with heavier tails. + + References:: + + .. 
[1] NIST/SEMATECH e-Handbook of Statistical Methods, "Cauchy + Distribution", + https://www.itl.nist.gov/div898/handbook/eda/section3/eda3663.htm + .. [2] Weisstein, Eric W. "Cauchy Distribution." From MathWorld--A + Wolfram Web Resource. + http://mathworld.wolfram.com/CauchyDistribution.html + .. [3] Wikipedia, "Cauchy distribution" + https://en.wikipedia.org/wiki/Cauchy_distribution + + Examples:: + + Draw samples and plot the distribution: + + >>> import matplotlib.pyplot as plt + >>> s = bm.random.standard_cauchy(1000000) + >>> s = s[(s>-25) & (s<25)] # truncate distribution so it plots well + >>> plt.hist(s, bins=100) + >>> plt.show() + """ + return DEFAULT.standard_cauchy(size, key=key) + + +def standard_exponential(size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from the standard exponential distribution. + + `standard_exponential` is identical to the exponential distribution + with a scale parameter of 1. + + Parameters:: + + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + + Returns:: + + out : float or ndarray + Drawn samples. + + Examples:: + + Output a 3x8000 array: + + >>> n = bm.random.standard_exponential((3, 8000)) + """ + return DEFAULT.standard_exponential(size, key=key) + + +def standard_gamma(shape, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a standard Gamma distribution. + + Samples are drawn from a Gamma distribution with specified parameters, + shape (sometimes designated "k") and scale=1. + + Parameters:: + + shape : float or array_like of floats + Parameter, must be non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. 
If size is ``None`` (default), + a single value is returned if ``shape`` is a scalar. Otherwise, + ``np.array(shape).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized standard gamma distribution. + + See Also:: + + scipy.stats.gamma : probability density function, distribution or + cumulative density function, etc. + + Notes:: + + The probability density for the Gamma distribution is + + .. math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)}, + + where :math:`k` is the shape and :math:`\theta` the scale, + and :math:`\Gamma` is the Gamma function. + + The Gamma distribution is often used to model the times to failure of + electronic components, and arises naturally in processes for which the + waiting times between Poisson distributed events are relevant. + + References:: + + .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A + Wolfram Web Resource. + http://mathworld.wolfram.com/GammaDistribution.html + .. [2] Wikipedia, "Gamma distribution", + https://en.wikipedia.org/wiki/Gamma_distribution + + Examples:: + + Draw samples from the distribution: + + >>> shape, scale = 2., 1. # mean and width + >>> s = bm.random.standard_gamma(shape, 1000000) + + Display the histogram of the samples, along with + the probability density function: + + >>> import matplotlib.pyplot as plt + >>> import scipy.special as sps # doctest: +SKIP + >>> count, bins, ignored = plt.hist(s, 50, density=True) + >>> y = bins**(shape-1) * ((np.exp(-bins/scale))/ # doctest: +SKIP + ... (sps.gamma(shape) * scale**shape)) + >>> plt.plot(bins, y, linewidth=2, color='r') # doctest: +SKIP + >>> plt.show() + """ + return DEFAULT.standard_gamma(shape, size, key=key) + + +def standard_normal(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a standard Normal distribution (mean=0, stdev=1). 
+ + Parameters:: + + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + + Returns:: + + out : float or ndarray + A floating-point array of shape ``size`` of drawn samples, or a + single sample if ``size`` was not specified. + + See Also:: + + normal : + Equivalent function with additional ``loc`` and ``scale`` arguments + for setting the mean and standard deviation. + + Notes:: + + For random samples from the normal distribution with mean ``mu`` and + standard deviation ``sigma``, use one of:: + + mu + sigma * bm.random.standard_normal(size=...) + bm.random.normal(mu, sigma, size=...) + + Examples:: + + >>> bm.random.standard_normal() + 2.1923875335537315 #random + + >>> s = bm.random.standard_normal(8000) + >>> s + array([ 0.6888893 , 0.78096262, -0.89086505, ..., 0.49876311, # random + -0.38672696, -0.4685006 ]) # random + >>> s.shape + (8000,) + >>> s = bm.random.standard_normal(size=(3, 4, 2)) + >>> s.shape + (3, 4, 2) + + Two-by-four array of samples from the normal distribution with + mean 3 and standard deviation 2.5: + + >>> 3 + 2.5 * bm.random.standard_normal(size=(2, 4)) + array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random + [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random + """ + return DEFAULT.standard_normal(size, key=key) + + +def standard_t(df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a standard Student's t distribution with `df` degrees + of freedom. + + A special case of the hyperbolic distribution. As `df` gets + large, the result resembles that of the standard normal + distribution (`standard_normal`). + + Parameters:: + + df : float or array_like of floats + Degrees of freedom, must be > 0. + size : int or tuple of ints, optional + Output shape. 
If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``df`` is a scalar. Otherwise, + ``np.array(df).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized standard Student's t distribution. + + Notes:: + + The probability density function for the t distribution is + + .. math:: P(x, df) = \frac{\Gamma(\frac{df+1}{2})}{\sqrt{\pi df} + \Gamma(\frac{df}{2})}\Bigl( 1+\frac{x^2}{df} \Bigr)^{-(df+1)/2} + + The t test is based on an assumption that the data come from a + Normal distribution. The t test provides a way to test whether + the sample mean (that is the mean calculated from the data) is + a good estimate of the true mean. + + The derivation of the t-distribution was first published in + 1908 by William Gosset while working for the Guinness Brewery + in Dublin. Due to proprietary issues, he had to publish under + a pseudonym, and so he used the name Student. + + References:: + + .. [1] Dalgaard, Peter, "Introductory Statistics With R", + Springer, 2002. + .. [2] Wikipedia, "Student's t-distribution" + https://en.wikipedia.org/wiki/Student's_t-distribution + + Examples:: + + From Dalgaard page 83 [1]_, suppose the daily energy intake for 11 + women in kilojoules (kJ) is: + + >>> intake = np.array([5260., 5470, 5640, 6180, 6390, 6515, 6805, 7515, \ + ... 7515, 8230, 8770]) + + Does their energy intake deviate systematically from the recommended + value of 7725 kJ? Our null hypothesis will be the absence of deviation, + and the alternate hypothesis will be the presence of an effect that could be + either positive or negative, hence making our test 2-tailed. + + Because we are estimating the mean and we have N=11 values in our sample, + we have N-1=10 degrees of freedom. We set our significance level to 95% and + compute the t statistic using the empirical mean and empirical standard + deviation of our intake. 
We use a ddof of 1 to base the computation of our + empirical standard deviation on an unbiased estimate of the variance (note: + the final estimate is not unbiased due to the concave nature of the square + root). + + >>> np.mean(intake) + 6753.636363636364 + >>> intake.std(ddof=1) + 1142.1232221373727 + >>> t = (np.mean(intake)-7725)/(intake.std(ddof=1)/np.sqrt(len(intake))) + >>> t + -2.8207540608310198 + + We draw 1000000 samples from Student's t distribution with the adequate + degrees of freedom. + + >>> import matplotlib.pyplot as plt + >>> s = bm.random.standard_t(10, size=1000000) + >>> h = plt.hist(s, bins=100, density=True) + + Does our t statistic land in one of the two critical regions found at + both tails of the distribution? + + >>> np.sum(np.abs(t) < np.abs(s)) / float(len(s)) + 0.018318 #random < 0.05, statistic is in critical region + + The probability value for this 2-tailed test is about 1.83%, which is + lower than the 5% pre-determined significance threshold. + + Therefore, the probability of observing values as extreme as our intake + conditionally on the null hypothesis being true is too low, and we reject + the null hypothesis of no deviation. + """ + return DEFAULT.standard_t(df, size, key=key) + + +def uniform(low=0.0, high=1.0, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a uniform distribution. + + Samples are uniformly distributed over the half-open interval + ``[low, high)`` (includes low, but excludes high). In other words, + any value within the given interval is equally likely to be drawn + by `uniform`. + + Parameters:: + + low : float or array_like of floats, optional + Lower boundary of the output interval. All values generated will be + greater than or equal to low. The default value is 0. + high : float or array_like of floats + Upper boundary of the output interval. All values generated will be + less than or equal to high. 
The high limit may be included in the + returned array of floats due to floating-point rounding in the + equation ``low + (high-low) * random_sample()``. The default value + is 1.0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``low`` and ``high`` are both scalars. + Otherwise, ``np.broadcast(low, high).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized uniform distribution. + + See Also:: + + randint : Discrete uniform distribution, yielding integers. + random_integers : Discrete uniform distribution over the closed + interval ``[low, high]``. + random_sample : Floats uniformly distributed over ``[0, 1)``. + random : Alias for `random_sample`. + rand : Convenience function that accepts dimensions as input, e.g., + ``rand(2,2)`` would generate a 2-by-2 array of floats, + uniformly distributed over ``[0, 1)``. + + Notes:: + + The probability density function of the uniform distribution is + + .. math:: p(x) = \frac{1}{b - a} + + anywhere within the interval ``[a, b)``, and zero elsewhere. + + When ``high`` == ``low``, values of ``low`` will be returned. + If ``high`` < ``low``, the results are officially undefined + and may eventually raise an error, i.e. do not rely on this + function to behave when passed arguments satisfying that + inequality condition. The ``high`` limit may be included in the + returned array of floats due to floating-point rounding in the + equation ``low + (high-low) * random_sample()``. 
For example: + + >>> x = np.float32(5*0.99999999) + >>> x + 5.0 + + + Examples:: + + Draw samples from the distribution: + + >>> s = bm.random.uniform(-1,0,1000) + + All values are within the given interval: + + >>> np.all(s >= -1) + True + >>> np.all(s < 0) + True + + Display the histogram of the samples, along with the + probability density function: + + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s, 15, density=True) + >>> plt.plot(bins, np.ones_like(bins), linewidth=2, color='r') + >>> plt.show() + """ + return DEFAULT.uniform(low, high, size, key=key) + + +def truncated_normal(lower, upper, size: Optional[Union[int, Sequence[int]]] = None, loc=0., scale=1., dtype=float, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r"""Sample truncated standard normal random values with given shape and dtype. + + Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + + + Notes:: + + This distribution is the normal distribution centered on ``loc`` (default + 0), with standard deviation ``scale`` (default 1), and clipped at ``a``, + ``b`` standard deviations to the left, right (respectively) from ``loc``. + If ``myclip_a`` and ``myclip_b`` are clip values in the sample space (as + opposed to the number of standard deviations) then they can be converted + to the required form according to:: + + a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale + + + Parameters:: + + lower : float, ndarray + A float or array of floats representing the lower bound for + truncation. Must be broadcast-compatible with ``upper``. + upper : float, ndarray + A float or array of floats representing the upper bound for + truncation. Must be broadcast-compatible with ``lower``. + loc : float, ndarray + Mean ("centre") of the distribution before truncating. Note that + the mean of the truncated distribution will not be exactly equal + to ``loc``. 
+ size : optional, list of int, tuple of int + A tuple of nonnegative integers specifying the result + shape. Must be broadcast-compatible with ``lower`` and ``upper``. The + default (None) produces a result shape by broadcasting ``lower`` and + ``upper``. + loc: optional, float, ndarray + A float or array of floats representing the mean of the + distribution. Default is 0. + scale : float, ndarray + Standard deviation (spread or "width") of the distribution. Must be + non-negative. Default is 1. + dtype: optional + The float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + key: jax.Array + The key for random generator. Consistent with the jax's random + paradigm. + + Returns:: + + out : Array + A random array with the specified dtype and shape given by ``size`` if + ``size`` is not None, or else by broadcasting ``lower`` and ``upper``. + Returns values in the open interval ``(lower, upper)``. + """ + return DEFAULT.truncated_normal(lower, upper, size, loc, scale, dtype=dtype, key=key) + + +RandomState.truncated_normal.__doc__ = truncated_normal.__doc__ + + +def bernoulli(p=0.5, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r"""Sample Bernoulli random values with given shape and mean. + + Parameters:: + + p: float, array_like, optional + A float or array of floats for the mean of the random + variables. Must be broadcast-compatible with ``size`` and the values + should be within [0, 1]. Default 0.5. + size: optional, tuple of int, int + A tuple of nonnegative integers representing the result + shape. Must be broadcast-compatible with ``p.shape``. The default (None) + produces a result shape equal to ``p.shape``. + + Returns:: + + out: array_like + A random array with boolean dtype and shape given by ``size`` if ``size`` + is not None, or else ``p.shape``. 
+ """ + return DEFAULT.bernoulli(p, size, key=key) + + +def lognormal(mean=None, sigma=None, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a log-normal distribution. + + Draw samples from a log-normal distribution with specified mean, + standard deviation, and array shape. Note that the mean and standard + deviation are not the values for the distribution itself, but of the + underlying normal distribution it is derived from. + + Parameters:: + + mean : float or array_like of floats, optional + Mean value of the underlying normal distribution. Default is 0. + sigma : float or array_like of floats, optional + Standard deviation of the underlying normal distribution. Must be + non-negative. Default is 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``mean`` and ``sigma`` are both scalars. + Otherwise, ``np.broadcast(mean, sigma).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized log-normal distribution. + + See Also:: + + scipy.stats.lognorm : probability density function, distribution, + cumulative density function, etc. + + Notes:: + + A variable `x` has a log-normal distribution if `log(x)` is normally + distributed. The probability density function for the log-normal + distribution is: + + .. math:: p(x) = \frac{1}{\sigma x \sqrt{2\pi}} + e^{(-\frac{(ln(x)-\mu)^2}{2\sigma^2})} + + where :math:`\mu` is the mean and :math:`\sigma` is the standard + deviation of the normally distributed logarithm of the variable. 
+ A log-normal distribution results if a random variable is the *product* + of a large number of independent, identically-distributed variables in + the same way that a normal distribution results if the variable is the + *sum* of a large number of independent, identically-distributed + variables. + + References:: + + .. [1] Limpert, E., Stahel, W. A., and Abbt, M., "Log-normal + Distributions across the Sciences: Keys and Clues," + BioScience, Vol. 51, No. 5, May, 2001. + https://stat.ethz.ch/~stahel/lognormal/bioscience.pdf + .. [2] Reiss, R.D. and Thomas, M., "Statistical Analysis of Extreme + Values," Basel: Birkhauser Verlag, 2001, pp. 31-32. + + Examples:: + + Draw samples from the distribution: + + >>> mu, sigma = 3., 1. # mean and standard deviation + >>> s = bm.random.lognormal(mu, sigma, 1000) + + Display the histogram of the samples, along with + the probability density function: + + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s, 100, density=True, align='mid') + + >>> x = np.linspace(min(bins), max(bins), 10000) + >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2)) + ... / (x * sigma * np.sqrt(2 * np.pi))) + + >>> plt.plot(x, pdf, linewidth=2, color='r') + >>> plt.axis('tight') + >>> plt.show() + + Demonstrate that taking the products of random samples from a normal + distribution can be fit well by a log-normal probability density + function. + + >>> # Generate a thousand samples: each is the product of 100 random + >>> # values, drawn from a normal distribution. + >>> b = [] + >>> for i in range(1000): + ... a = 10. + bm.random.standard_normal(100) + ... b.append(np.prod(a)) + + >>> b = np.array(b) / np.min(b) # scale values to be positive + >>> count, bins, ignored = plt.hist(b, 100, density=True, align='mid') + >>> sigma = np.std(np.log(b)) + >>> mu = np.mean(np.log(b)) + + >>> x = np.linspace(min(bins), max(bins), 10000) + >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2)) + ... 
/ (x * sigma * np.sqrt(2 * np.pi))) + + >>> plt.plot(x, pdf, color='r', linewidth=2) + >>> plt.show() + """ + return DEFAULT.lognormal(mean, sigma, size, key=key) + + +def binomial(n, p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a binomial distribution. + + Samples are drawn from a binomial distribution with specified + parameters, n trials and p probability of success where + n an integer >= 0 and p is in the interval [0,1]. (n may be + input as a float, but it is truncated to an integer in use) + + Parameters:: + + n : int or array_like of ints + Parameter of the distribution, >= 0. Floats are also accepted, + but they will be truncated to integers. + p : float or array_like of floats + Parameter of the distribution, >= 0 and <=1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``n`` and ``p`` are both scalars. + Otherwise, ``np.broadcast(n, p).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized binomial distribution, where + each sample is equal to the number of successes over the n trials. + + See Also:: + + scipy.stats.binom : probability density function, distribution or + cumulative density function, etc. + + Notes:: + + The probability density for the binomial distribution is + + .. math:: P(N) = \binom{n}{N}p^N(1-p)^{n-N}, + + where :math:`n` is the number of trials, :math:`p` is the probability + of success, and :math:`N` is the number of successes. + + When estimating the standard error of a proportion in a population by + using a random sample, the normal distribution works well unless the + product p*n <=5, where p = population proportion estimate, and n = + number of samples, in which case the binomial distribution is used + instead. 
For example, a sample of 15 people shows 4 who are left + handed, and 11 who are right handed. Then p = 4/15 = 27%. 0.27*15 = 4, + so the binomial distribution should be used in this case. + + References:: + + .. [1] Dalgaard, Peter, "Introductory Statistics with R", + Springer-Verlag, 2002. + .. [2] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill, + Fifth Edition, 2002. + .. [3] Lentner, Marvin, "Elementary Applied Statistics", Bogden + and Quigley, 1972. + .. [4] Weisstein, Eric W. "Binomial Distribution." From MathWorld--A + Wolfram Web Resource. + http://mathworld.wolfram.com/BinomialDistribution.html + .. [5] Wikipedia, "Binomial distribution", + https://en.wikipedia.org/wiki/Binomial_distribution + + Examples:: + + Draw samples from the distribution: + + >>> n, p = 10, .5 # number of trials, probability of each trial + >>> s = bm.random.binomial(n, p, 1000) + # result of flipping a coin 10 times, tested 1000 times. + + A real world example. A company drills 9 wild-cat oil exploration + wells, each with an estimated probability of success of 0.1. All nine + wells fail. What is the probability of that happening? + + Let's do 20,000 trials of the model, and count the number that + generate zero positive results. + + >>> sum(bm.random.binomial(9, 0.1, 20000) == 0)/20000. + # answer = 0.38885, or 38%. + """ + return DEFAULT.binomial(n, p, size, key=key) + + +def chisquare(df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a chi-square distribution. + + When `df` independent random variables, each with standard normal + distributions (mean 0, variance 1), are squared and summed, the + resulting distribution is chi-square (see Notes). This distribution + is often used in hypothesis testing. + + Parameters:: + + df : float or array_like of floats + Number of degrees of freedom, must be > 0. + size : int or tuple of ints, optional + Output shape. 
If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``df`` is a scalar. Otherwise, + ``np.array(df).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized chi-square distribution. + + Raises:: + + ValueError + When `df` <= 0 or when an inappropriate `size` (e.g. ``size=-1``) + is given. + + Notes:: + + The variable obtained by summing the squares of `df` independent, + standard normally distributed random variables: + + .. math:: Q = \sum_{i=1}^{\mathtt{df}} X^2_i + + is chi-square distributed, denoted + + .. math:: Q \sim \chi^2_k. + + The probability density function of the chi-squared distribution is + + .. math:: p(x) = \frac{(1/2)^{k/2}}{\Gamma(k/2)} + x^{k/2 - 1} e^{-x/2}, + + where :math:`\Gamma` is the gamma function, + + .. math:: \Gamma(x) = \int_0^{\infty} t^{x - 1} e^{-t} dt. + + References:: + + .. [1] NIST "Engineering Statistics Handbook" + https://www.itl.nist.gov/div898/handbook/eda/section3/eda3666.htm + + Examples:: + + >>> bm.random.chisquare(2,4) + array([ 1.89920014, 9.00867716, 3.13710533, 5.62318272]) # random + """ + return DEFAULT.chisquare(df, size, key=key) + + +def dirichlet(alpha, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from the Dirichlet distribution. + + Draw `size` samples of dimension k from a Dirichlet distribution. A + Dirichlet-distributed random variable can be seen as a multivariate + generalization of a Beta distribution. The Dirichlet distribution + is a conjugate prior of a multinomial distribution in Bayesian + inference. + + Parameters:: + + alpha : sequence of floats, length k + Parameter of the distribution (length ``k`` for sample of + length ``k``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n)``, then + ``m * n * k`` samples are drawn. 
Default is None, in which case a + vector of length ``k`` is returned. + + Returns:: + + samples : ndarray, + The drawn samples, of shape ``(size, k)``. + + Raises:: + + ValueError + If any value in ``alpha`` is less than or equal to zero + + Notes:: + + The Dirichlet distribution is a distribution over vectors + :math:`x` that fulfil the conditions :math:`x_i>0` and + :math:`\sum_{i=1}^k x_i = 1`. + + The probability density function :math:`p` of a + Dirichlet-distributed random vector :math:`X` is + proportional to + + .. math:: p(x) \propto \prod_{i=1}^{k}{x^{\alpha_i-1}_i}, + + where :math:`\alpha` is a vector containing the positive + concentration parameters. + + The method uses the following property for computation: let :math:`Y` + be a random vector which has components that follow a standard gamma + distribution, then :math:`X = \frac{1}{\sum_{i=1}^k{Y_i}} Y` + is Dirichlet-distributed + + References:: + + .. [1] David McKay, "Information Theory, Inference and Learning + Algorithms," chapter 23, + http://www.inference.org.uk/mackay/itila/ + .. [2] Wikipedia, "Dirichlet distribution", + https://en.wikipedia.org/wiki/Dirichlet_distribution + + Examples:: + + Taking an example cited in Wikipedia, this distribution can be used if + one wanted to cut strings (each of initial length 1.0) into K pieces + with different lengths, where each piece had, on average, a designated + average length, but allowing some variation in the relative sizes of + the pieces. + + >>> s = bm.random.dirichlet((10, 5, 3), 20).transpose() + + >>> import matplotlib.pyplot as plt + >>> plt.barh(range(20), s[0]) + >>> plt.barh(range(20), s[1], left=s[0], color='g') + >>> plt.barh(range(20), s[2], left=s[0]+s[1], color='r') + >>> plt.title("Lengths of Strings") + """ + return DEFAULT.dirichlet(alpha, size, key=key) + + +def geometric(p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from the geometric distribution. 
+ + Bernoulli trials are experiments with one of two outcomes: + success or failure (an example of such an experiment is flipping + a coin). The geometric distribution models the number of trials + that must be run in order to achieve success. It is therefore + supported on the positive integers, ``k = 1, 2, ...``. + + The probability mass function of the geometric distribution is + + .. math:: f(k) = (1 - p)^{k - 1} p + + where `p` is the probability of success of an individual trial. + + Parameters:: + + p : float or array_like of floats + The probability of success of an individual trial. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``p`` is a scalar. Otherwise, + ``np.array(p).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized geometric distribution. + + Examples:: + + Draw ten thousand values from the geometric distribution, + with the probability of an individual success equal to 0.35: + + >>> z = bm.random.geometric(p=0.35, size=10000) + + How many trials succeeded after a single run? + + >>> (z == 1).sum() / 10000. + 0.34889999999999999 #random + """ + return DEFAULT.geometric(p, size, key=key) + + +def f(dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from an F distribution. + + Samples are drawn from an F distribution with specified parameters, + `dfnum` (degrees of freedom in numerator) and `dfden` (degrees of + freedom in denominator), where both parameters must be greater than + zero. + + The random variate of the F distribution (also known as the + Fisher distribution) is a continuous probability distribution + that arises in ANOVA tests, and is the ratio of two chi-square + variates. 
+ + Parameters:: + + dfnum : float or array_like of floats + Degrees of freedom in numerator, must be > 0. + dfden : float or array_like of float + Degrees of freedom in denominator, must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``dfnum`` and ``dfden`` are both scalars. + Otherwise, ``np.broadcast(dfnum, dfden).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized Fisher distribution. + + See Also:: + + scipy.stats.f : probability density function, distribution or + cumulative density function, etc. + + Notes:: + + The F statistic is used to compare in-group variances to between-group + variances. Calculating the distribution depends on the sampling, and + so it is a function of the respective degrees of freedom in the + problem. The variable `dfnum` is the number of samples minus one, the + between-groups degrees of freedom, while `dfden` is the within-groups + degrees of freedom, the sum of the number of samples in each group + minus the number of groups. + + References:: + + .. [1] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill, + Fifth Edition, 2002. + .. [2] Wikipedia, "F-distribution", + https://en.wikipedia.org/wiki/F-distribution + + Examples:: + + An example from Glantz[1], pp 47-40: + + Two groups, children of diabetics (25 people) and children from people + without diabetes (25 controls). Fasting blood glucose was measured, + case group had a mean value of 86.1, controls had a mean value of + 82.2. Standard deviations were 2.09 and 2.49 respectively. Are these + data consistent with the null hypothesis that the parents diabetic + status does not affect their children's blood glucose levels? + Calculating the F statistic from the data gives a value of 36.01. + + Draw samples from the distribution: + + >>> dfnum = 1. 
# between group degrees of freedom + >>> dfden = 48. # within groups degrees of freedom + >>> s = bm.random.f(dfnum, dfden, 1000) + + The lower bound for the top 1% of the samples is : + + >>> np.sort(s)[-10] + 7.61988120985 # random + + So there is about a 1% chance that the F statistic will exceed 7.62, + the measured value is 36, so the null hypothesis is rejected at the 1% + level. + """ + return DEFAULT.f(dfnum, dfden, size, key=key) + + +def hypergeometric(ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a Hypergeometric distribution. + + Samples are drawn from a hypergeometric distribution with specified + parameters, `ngood` (ways to make a good selection), `nbad` (ways to make + a bad selection), and `nsample` (number of items sampled, which is less + than or equal to the sum ``ngood + nbad``). + + Parameters:: - Returns: - The random data. - """ - return self.randn(*shape(input), key=key).astype(dtype) + ngood : int or array_like of ints + Number of ways to make a good selection. Must be nonnegative. + nbad : int or array_like of ints + Number of ways to make a bad selection. Must be nonnegative. + nsample : int or array_like of ints + Number of items sampled. Must be at least 1 and at most + ``ngood + nbad``. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if `ngood`, `nbad`, and `nsample` + are all scalars. Otherwise, ``np.broadcast(ngood, nbad, nsample).size`` + samples are drawn. 
- def randint_like(self, input, low=0, high=None, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - if high is None: - high = max(input) - return self.randint(low, high=high, size=shape(input), dtype=dtype, key=key) + Returns:: + out : ndarray or scalar + Drawn samples from the parameterized hypergeometric distribution. Each + sample is the number of good items within a randomly selected subset of + size `nsample` taken from a set of `ngood` good items and `nbad` bad items. -# alias -Generator = RandomState + See Also:: -# default random generator -__a = Array(None) -__a._value = np.random.randint(0, 10000, size=2, dtype=np.uint32) -DEFAULT = RandomState(__a) -del __a + scipy.stats.hypergeom : probability density function, distribution or + cumulative density function, etc. + Notes:: -def split_key(): - """Create a new seed from the current seed. + The probability density for the Hypergeometric distribution is - This function is useful for the consistency with JAX's random paradigm.""" - return DEFAULT.split_key() + .. math:: P(x) = \frac{\binom{g}{x}\binom{b}{n-x}}{\binom{g+b}{n}}, + where :math:`0 \le x \le n` and :math:`n-b \le x \le g` -def split_keys(n): - """Create multiple seeds from the current seed. This is used - internally by `pmap` and `vmap` to ensure that random numbers - are different in parallel threads. + for P(x) the probability of ``x`` good results in the drawn sample, + g = `ngood`, b = `nbad`, and n = `nsample`. - .. versionadded:: 2.4.5 + Consider an urn with black and white marbles in it, `ngood` of them + are black and `nbad` are white. If you draw `nsample` balls without + replacement, then the hypergeometric distribution describes the + distribution of black balls in the drawn sample. - Parameters:: - - n : int - The number of seeds to generate. 
- """ - return DEFAULT.split_keys(n) + Note that this distribution is very similar to the binomial + distribution, except that in this case, samples are drawn without + replacement, whereas in the Binomial case samples are drawn with + replacement (or the sample space is infinite). As the sample space + becomes large, this distribution approaches the binomial. + References:: -def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState: - """Clone the random state according to the given setting. + .. [1] Lentner, Marvin, "Elementary Applied Statistics", Bogden + and Quigley, 1972. + .. [2] Weisstein, Eric W. "Hypergeometric Distribution." From + MathWorld--A Wolfram Web Resource. + http://mathworld.wolfram.com/HypergeometricDistribution.html + .. [3] Wikipedia, "Hypergeometric distribution", + https://en.wikipedia.org/wiki/Hypergeometric_distribution - Args: - seed_or_key: The seed (an integer) or the random key. - clone: Bool. Whether clone the default random state. + Examples:: - Returns: - The random state. - """ - if seed_or_key is None: - return DEFAULT.clone() if clone else DEFAULT - else: - return RandomState(seed_or_key) + Draw samples from the distribution: + >>> ngood, nbad, nsamp = 100, 2, 10 + # number of good, number of bad, and number of samples + >>> s = bm.random.hypergeometric(ngood, nbad, nsamp, 1000) + >>> from matplotlib.pyplot import hist + >>> hist(s) + # note that it is very unlikely to grab both bad items -def default_rng(seed_or_key=None, clone: bool = True) -> RandomState: - if seed_or_key is None: - return DEFAULT.clone() if clone else DEFAULT - else: - return RandomState(seed_or_key) + Suppose you have an urn with 15 white and 15 black marbles. + If you pull 15 marbles at random, how likely is it that + 12 or more of them are one color? + >>> s = bm.random.hypergeometric(15, 15, 15, 100000) + >>> sum(s>=12)/100000. + sum(s<=3)/100000. + # answer = 0.003 ... pretty unlikely! 
+ """ + return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key) -def seed(seed: int = None): - """Sets a new random seed. - Parameters:: - - seed: int, optional - The random seed. - """ - with jax.ensure_compile_time_eval(): - if seed is None: - seed = np.random.randint(0, 100000) - np.random.seed(seed) - DEFAULT.seed(seed) +def logseries(p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a logarithmic series distribution. + Samples are drawn from a log series distribution with specified + shape parameter, 0 <= ``p`` < 1. -def rand(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r"""Random values in a given shape. - - .. note:: - This is a convenience function for users porting code from Matlab, - and wraps `random_sample`. That function takes a - tuple to specify the size of the output, which is consistent with - other NumPy functions like `numpy.zeros` and `numpy.ones`. - - Create an array of the given shape and populate it with - random samples from a uniform distribution - over ``[0, 1)``. - - Parameters:: - - d0, d1, ..., dn : int, optional - The dimensions of the returned array, must be non-negative. - If no argument is given a single Python float is returned. - - Returns:: - - out : ndarray, shape ``(d0, d1, ..., dn)`` - Random values. - - See Also:: - - random - - Examples:: - - >>> brainpy.math.random.rand(3,2) - array([[ 0.14022471, 0.96360618], #random - [ 0.37601032, 0.25528411], #random - [ 0.49313049, 0.94909878]]) #random - """ - return DEFAULT.rand(*dn, key=key) + Parameters:: + p : float or array_like of floats + Shape parameter for the distribution. Must be in the range [0, 1). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``p`` is a scalar. Otherwise, + ``np.array(p).size`` samples are drawn. 
-def randint(low, high=None, size: Optional[Union[int, Sequence[int]]] = None, dtype=int, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r"""Return random integers from `low` (inclusive) to `high` (exclusive). - - Return random integers from the "discrete uniform" distribution of - the specified dtype in the "half-open" interval [`low`, `high`). If - `high` is None (the default), then results are from [0, `low`). - - Parameters:: - - low : int or array-like of ints - Lowest (signed) integers to be drawn from the distribution (unless - ``high=None``, in which case this parameter is one above the - *highest* such integer). - high : int or array-like of ints, optional - If provided, one above the largest (signed) integer to be drawn - from the distribution (see above for behavior if ``high=None``). - If array-like, must contain integer values - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. - dtype : dtype, optional - Desired dtype of the result. Byteorder must be native. - The default value is int. - - Returns:: - - out : int or ndarray of ints - `size`-shaped array of random integers from the appropriate - distribution, or a single such random int if `size` not provided. - - See Also:: - - random_integers : similar to `randint`, only for the closed - interval [`low`, `high`], and 1 is the lowest value if `high` is - omitted. - Generator.integers: which should be used for new code. 
- - Examples:: - - >>> import brainpy.math as bm - >>> bm.random.randint(2, size=10) - array([1, 0, 0, 0, 1, 1, 0, 0, 1, 0]) # random - >>> bm.random.randint(1, size=10) - array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - - Generate a 2 x 4 array of ints between 0 and 4, inclusive: - - >>> bm.random.randint(5, size=(2, 4)) - array([[4, 0, 2, 1], # random - [3, 2, 2, 0]]) - - Generate a 1 x 3 array with 3 different upper bounds - - >>> bm.random.randint(1, [3, 5, 10]) - array([2, 2, 9]) # random - - Generate a 1 by 3 array with 3 different lower bounds - - >>> bm.random.randint([1, 5, 7], 10) - array([9, 8, 7]) # random - - Generate a 2 by 4 array using broadcasting with dtype of uint8 - - >>> bm.random.randint([1, 3, 5, 7], [[10], [20]], dtype=np.uint8) - array([[ 8, 6, 9, 7], # random - [ 1, 16, 9, 12]], dtype=uint8) - """ - - return DEFAULT.randint(low, high=high, size=size, dtype=dtype, key=key) + Returns:: + out : ndarray or scalar + Drawn samples from the parameterized logarithmic series distribution. -def random_integers(low, - high=None, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Random integers of type `np.int_` between `low` and `high`, inclusive. - - Return random integers of type `np.int_` from the "discrete uniform" - distribution in the closed interval [`low`, `high`]. If `high` is - None (the default), then results are from [1, `low`]. The `np.int_` - type translates to the C long integer type and its precision - is platform dependent. - - Parameters:: - - low : int - Lowest (signed) integer to be drawn from the distribution (unless - ``high=None``, in which case this parameter is the *highest* such - integer). - high : int, optional - If provided, the largest (signed) integer to be drawn from the - distribution (see above for behavior if ``high=None``). - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. 
Default is None, in which case a - single value is returned. - - Returns:: - - out : int or ndarray of ints - `size`-shaped array of random integers from the appropriate - distribution, or a single such random int if `size` not provided. - - See Also:: - - randint : Similar to `random_integers`, only for the half-open - interval [`low`, `high`), and 0 is the lowest value if `high` is - omitted. - - Notes:: - - To sample from N evenly spaced floating-point numbers between a and b, - use:: - - a + (b - a) * (bm.random.random_integers(N) - 1) / (N - 1.) - - Examples:: - - >>> import brainpy.math as bm - >>> bm.random.random_integers(5) - 4 # random - >>> type(bm.random.random_integers(5)) - - >>> bm.random.random_integers(5, size=(3,2)) - array([[5, 4], # random - [3, 3], - [4, 5]]) - - Choose five random numbers from the set of five evenly-spaced - numbers between 0 and 2.5, inclusive (*i.e.*, from the set - :math:`{0, 5/8, 10/8, 15/8, 20/8}`): - - >>> 2.5 * (bm.random.random_integers(5, size=(5,)) - 1) / 4. - array([ 0.625, 1.25 , 0.625, 0.625, 2.5 ]) # random - - Roll two six sided dice 1000 times and sum the results: - - >>> d1 = bm.random.random_integers(1, 6, 1000) - >>> d2 = bm.random.random_integers(1, 6, 1000) - >>> dsums = d1 + d2 - - Display results as a histogram: - - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(dsums, 11, density=True) - >>> plt.show() - """ - - return DEFAULT.random_integers(low, high=high, size=size, key=key) + See Also:: + scipy.stats.logser : probability density function, distribution or + cumulative density function, etc. -def randn(*dn, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Return a sample (or samples) from the "standard normal" distribution. + Notes:: - .. note:: - This is a convenience function for users porting code from Matlab, - and wraps `standard_normal`. 
That function takes a - tuple to specify the size of the output, which is consistent with - other NumPy functions like `numpy.zeros` and `numpy.ones`. + The probability density for the Log Series distribution is - .. note:: - New code should use the ``standard_normal`` method of a ``default_rng()`` - instance instead; please see the :ref:`random-quick-start`. + .. math:: P(k) = \frac{-p^k}{k \ln(1-p)}, - If positive int_like arguments are provided, `randn` generates an array - of shape ``(d0, d1, ..., dn)``, filled - with random floats sampled from a univariate "normal" (Gaussian) - distribution of mean 0 and variance 1. A single float randomly sampled - from the distribution is returned if no argument is provided. + where p = probability. - Parameters:: - - d0, d1, ..., dn : int, optional - The dimensions of the returned array, must be non-negative. - If no argument is given a single Python float is returned. + The log series distribution is frequently used to represent species + richness and occurrence, first proposed by Fisher, Corbet, and + Williams in 1943 [2]. It may also be used to model the numbers of + occupants seen in cars [3]. - Returns:: - - Z : ndarray or float - A ``(d0, d1, ..., dn)``-shaped array of floating-point samples from - the standard normal distribution, or a single such float if - no parameters were supplied. + References:: - See Also:: - - standard_normal : Similar, but takes a tuple as its argument. - normal : Also accepts mu and sigma arguments. + .. [1] Buzas, Martin A.; Culver, Stephen J., Understanding regional + species diversity through the log series distribution of + occurrences: BIODIVERSITY RESEARCH Diversity & Distributions, + Volume 5, Number 5, September 1999 , pp. 187-195(9). + .. [2] Fisher, R.A,, A.S. Corbet, and C.B. Williams. 1943. The + relation between the number of species and the number of + individuals in a random sample of an animal population. + Journal of Animal Ecology, 12:42-58. + .. [3] D. J. Hand, F. 
Daly, D. Lunn, E. Ostrowski, A Handbook of Small + Data Sets, CRC Press, 1994. + .. [4] Wikipedia, "Logarithmic distribution", + https://en.wikipedia.org/wiki/Logarithmic_distribution - Notes:: - - For random samples from :math:`N(\mu, \sigma^2)`, use: + Examples:: - ``sigma * bm.random.randn(...) + mu`` + Draw samples from the distribution: - Examples:: - - >>> import brainpy.math as bm - >>> bm.random.randn() - 2.1923875335537315 # random + >>> a = .6 + >>> s = bm.random.logseries(a, 10000) + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s) - Two-by-four array of samples from N(3, 6.25): + # plot against distribution - >>> 3 + 2.5 * bm.random.randn(2, 4) - array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random - [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random - """ + >>> def logseries(k, p): + ... return -p**k/(k*np.log(1-p)) + >>> plt.plot(bins, logseries(bins, a)*count.max()/ + ... logseries(bins, a).max(), 'r') + >>> plt.show() + """ + return DEFAULT.logseries(p, size, key=key) - return DEFAULT.randn(*dn, key=key) +def multinomial(n, pvals, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a multinomial distribution. -def random(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Return random floats in the half-open interval [0.0, 1.0). Alias for - `random_sample` to ease forward-porting to the new random API. - """ - return DEFAULT.random(size, key=key) + The multinomial distribution is a multivariate generalization of the + binomial distribution. Take an experiment with one of ``p`` + possible outcomes. An example of such an experiment is throwing a dice, + where the outcome can be 1 through 6. Each sample drawn from the + distribution represents `n` such experiments. Its values, + ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the + outcome was ``i``. 
+ Parameters:: -def random_sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Return random floats in the half-open interval [0.0, 1.0). - - Results are from the "continuous uniform" distribution over the - stated interval. To sample :math:`Unif[a, b), b > a` multiply - the output of `random_sample` by `(b-a)` and add `a`:: - - (b - a) * random_sample() + a - - .. note:: - New code should use the ``random`` method of a ``default_rng()`` - instance instead; please see the :ref:`random-quick-start`. - - Parameters:: - - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. - - Returns:: - - out : float or ndarray of floats - Array of random floats of shape `size` (unless ``size=None``, in which - case a single float is returned). - - See Also:: - - Generator.random: which should be used for new code. - - Examples:: - - >>> import brainpy.math as bm - >>> bm.random.random_sample() - 0.47108547995356098 # random - >>> type(bm.random.random_sample()) - - >>> bm.random.random_sample((5,)) - array([ 0.30220482, 0.86820401, 0.1654503 , 0.11659149, 0.54323428]) # random - - Three-by-two array of random numbers from [-5, 0): - - >>> 5 * bm.random.random_sample((3, 2)) - 5 - array([[-3.99149989, -0.52338984], # random - [-2.99091858, -0.79479508], - [-1.23204345, -1.75224494]]) - """ - return DEFAULT.random_sample(size, key=key) + n : int + Number of experiments. + pvals : sequence of floats, length p + Probabilities of each of the ``p`` different outcomes. These + must sum to 1 (however, the last element is always assumed to + account for the remaining probability, as long as + ``sum(pvals[:-1]) <= 1)``. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. 
Default is None, in which case a + single value is returned. + Returns:: -def ranf(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - This is an alias of `random_sample`. See `random_sample` for the complete - documentation. - """ - return DEFAULT.ranf(size, key=key) + out : ndarray + The drawn samples, of shape *size*, if that was provided. If not, + the shape is ``(N,)``. + In other words, each entry ``out[i,j,...,:]`` is an N-dimensional + value drawn from the distribution. -def sample(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """ - This is an alias of `random_sample`. See `random_sample` for the complete - documentation. - """ - return DEFAULT.sample(size, key=key) + Examples:: + Throw a dice 20 times: -def choice(a, size: Optional[Union[int, Sequence[int]]] = None, replace=True, p=None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Generates a random sample from a given 1-D array - - Parameters:: - - a : 1-D array-like or int - If an ndarray, a random sample is generated from its elements. - If an int, the random sample is generated as if it were ``np.arange(a)`` - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. - replace : boolean, optional - Whether the sample is with or without replacement. Default is True, - meaning that a value of ``a`` can be selected multiple times. - p : 1-D array-like, optional - The probabilities associated with each entry in a. - If not given, the sample assumes a uniform distribution over all - entries in ``a``. 
- - Returns:: - - samples : single item or ndarray - The generated random samples - - Raises:: - - ValueError - If a is an int and less than zero, if a or p are not 1-dimensional, - if a is an array-like of size 0, if p is not a vector of - probabilities, if a and p have different lengths, or if - replace=False and the sample size is greater than the population - size - - See Also:: - - randint, shuffle, permutation - Generator.choice: which should be used in new code - - Notes:: - - Setting user-specified probabilities through ``p`` uses a more general but less - efficient sampler than the default. The general sampler produces a different sample - than the optimized sampler even if each element of ``p`` is 1 / len(a). - - Sampling random rows from a 2-D array is not possible with this function, - but is possible with `Generator.choice` through its ``axis`` keyword. - - Examples:: - - Generate a uniform random sample from np.arange(5) of size 3: - - >>> import brainpy.math as bm - >>> bm.random.choice(5, 3) - array([0, 3, 4]) # random - >>> #This is equivalent to brainpy.math.random.randint(0,5,3) - - Generate a non-uniform random sample from np.arange(5) of size 3: - - >>> bm.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0]) - array([3, 3, 0]) # random - - Generate a uniform random sample from np.arange(5) of size 3 without - replacement: - - >>> bm.random.choice(5, 3, replace=False) - array([3,1,0]) # random - >>> #This is equivalent to brainpy.math.random.permutation(np.arange(5))[:3] - - Generate a non-uniform random sample from np.arange(5) of size - 3 without replacement: - - >>> bm.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0]) - array([2, 3, 0]) # random - - Any of the above can be repeated with an arbitrary array-like - instead of just integers. 
For instance: - - >>> aa_milne_arr = ['pooh', 'rabbit', 'piglet', 'Christopher'] - >>> bm.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3]) - array(['pooh', 'pooh', 'pooh', 'Christopher', 'piglet'], # random - dtype='>> bm.random.multinomial(20, [1/6.]*6, size=1) + array([[4, 1, 7, 5, 2, 1]]) # random + It landed 4 times on 1, once on 2, etc. -def permutation(x, - axis: int = 0, - independent: bool = False, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Randomly permute a sequence, or return a permuted range. - - If `x` is a multi-dimensional array, it is only shuffled along its - first index. - - Parameters:: - - x : int or array_like - If `x` is an integer, randomly permute ``np.arange(x)``. - If `x` is an array, make a copy and shuffle the elements - randomly. - - Returns:: - - out : ndarray - Permuted sequence or array range. - - Examples:: - - >>> import brainpy.math as bm - >>> bm.random.permutation(10) - array([1, 7, 4, 3, 0, 9, 2, 5, 8, 6]) # random - - >>> bm.random.permutation([1, 4, 9, 12, 15]) - array([15, 1, 9, 4, 12]) # random - - >>> arr = np.arange(9).reshape((3, 3)) - >>> bm.random.permutation(arr) - array([[6, 7, 8], # random - [0, 1, 2], - [3, 4, 5]]) - """ - return DEFAULT.permutation(x, axis=axis, independent=independent, key=key) + Now, throw the dice 20 times, and 20 times again: + >>> bm.random.multinomial(20, [1/6.]*6, size=2) + array([[3, 4, 3, 3, 4, 3], # random + [2, 4, 3, 4, 0, 7]]) -def shuffle(x, axis=0, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Modify a sequence in-place by shuffling its contents. - - This function only shuffles the array along the first axis of a - multi-dimensional array. The order of sub-arrays is changed but - their contents remains the same. - - Parameters:: - - x : ndarray or MutableSequence - The array, list or mutable sequence to be shuffled. 
- - Returns:: - - None - - Examples:: - - >>> import brainpy.math as bm - >>> arr = np.arange(10) - >>> bm.random.shuffle(arr) - >>> arr - [1 7 5 2 9 4 3 6 0 8] # random - - Multi-dimensional arrays are only shuffled along the first axis: - - >>> arr = np.arange(9).reshape((3, 3)) - >>> bm.random.shuffle(arr) - >>> arr - array([[3, 4, 5], # random - [6, 7, 8], - [0, 1, 2]]) - """ - DEFAULT.shuffle(x, axis, key=key) + For the first run, we threw 3 times 1, 4 times 2, etc. For the second, + we threw 2 times 1, 4 times 2, etc. + A loaded die is more likely to land on number 6: -def beta(a, b, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Beta distribution. + >>> bm.random.multinomial(100, [1/7.]*5 + [2/7.]) + array([11, 16, 14, 17, 16, 26]) # random - The Beta distribution is a special case of the Dirichlet distribution, - and is related to the Gamma distribution. It has the probability - distribution function + The probability inputs should be normalized. As an implementation + detail, the value of the last entry is ignored and assumed to take + up any leftover probability mass, but this should not be relied on. + A biased coin which has twice as much weight on one side as on the + other should be sampled like so: - .. math:: f(x; a,b) = \frac{1}{B(\alpha, \beta)} x^{\alpha - 1} - (1 - x)^{\beta - 1}, + >>> bm.random.multinomial(100, [1.0 / 3, 2.0 / 3]) # RIGHT + array([38, 62]) # random - where the normalization, B, is the beta function, + not like: - .. math:: B(\alpha, \beta) = \int_0^1 t^{\alpha - 1} - (1 - t)^{\beta - 1} dt. + >>> bm.random.multinomial(100, [1.0, 2.0]) # WRONG + Traceback (most recent call last): + ValueError: pvals < 0, pvals > 1 or pvals contains NaNs + """ + return DEFAULT.multinomial(n, pvals, size, key=key) - It is often seen in Bayesian inference and order statistics. - Parameters:: - - a : float or array_like of floats - Alpha, positive (>0). 
- b : float or array_like of floats - Beta, positive (>0). - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``a`` and ``b`` are both scalars. - Otherwise, ``np.broadcast(a, b).size`` samples are drawn. +def multivariate_normal(mean, cov, size: Optional[Union[int, Sequence[int]]] = None, method: str = 'cholesky', + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw random samples from a multivariate normal distribution. - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized beta distribution. - """ - return DEFAULT.beta(a, b, size=size, key=key) + The multivariate normal, multinormal or Gaussian distribution is a + generalization of the one-dimensional normal distribution to higher + dimensions. Such a distribution is specified by its mean and + covariance matrix. These parameters are analogous to the mean + (average or "center") and variance (standard deviation, or "width," + squared) of the one-dimensional normal distribution. + Parameters:: -def exponential(scale=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from an exponential distribution. - - Its probability density function is - - .. math:: f(x; \frac{1}{\beta}) = \frac{1}{\beta} \exp(-\frac{x}{\beta}), - - for ``x > 0`` and 0 elsewhere. :math:`\beta` is the scale parameter, - which is the inverse of the rate parameter :math:`\lambda = 1/\beta`. - The rate parameter is an alternative, widely used parameterization - of the exponential distribution [3]_. - - The exponential distribution is a continuous analogue of the - geometric distribution. It describes many common situations, such as - the size of raindrops measured over many rainstorms [1]_, or the time - between page requests to Wikipedia [2]_. 
- - Parameters:: - - scale : float or array_like of floats - The scale parameter, :math:`\beta = 1/\lambda`. Must be - non-negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``scale`` is a scalar. Otherwise, - ``np.array(scale).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized exponential distribution. - - References:: - - .. [1] Peyton Z. Peebles Jr., "Probability, Random Variables and - Random Signal Principles", 4th ed, 2001, p. 57. - .. [2] Wikipedia, "Poisson process", - https://en.wikipedia.org/wiki/Poisson_process - .. [3] Wikipedia, "Exponential distribution", - https://en.wikipedia.org/wiki/Exponential_distribution - """ - return DEFAULT.exponential(scale, size, key=key) + mean : 1-D array_like, of length N + Mean of the N-dimensional distribution. + cov : 2-D array_like, of shape (N, N) + Covariance matrix of the distribution. It must be symmetric and + positive-semidefinite for proper sampling. + size : int or tuple of ints, optional + Given a shape of, for example, ``(m,n,k)``, ``m*n*k`` samples are + generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because + each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``. + If no shape is specified, a single (`N`-D) sample is returned. + check_valid : { 'warn', 'raise', 'ignore' }, optional + Behavior when the covariance matrix is not positive semidefinite. + tol : float, optional + Tolerance when checking the singular values in covariance matrix. + cov is cast to double before the check. + Returns:: -def gamma(shape, scale=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Gamma distribution. + out : ndarray + The drawn samples, of shape *size*, if that was provided. 
If not, + the shape is ``(N,)``. - Samples are drawn from a Gamma distribution with specified parameters, - `shape` (sometimes designated "k") and `scale` (sometimes designated - "theta"), where both parameters are > 0. + In other words, each entry ``out[i,j,...,:]`` is an N-dimensional + value drawn from the distribution. - Parameters:: - - shape : float or array_like of floats - The shape of the gamma distribution. Must be non-negative. - scale : float or array_like of floats, optional - The scale of the gamma distribution. Must be non-negative. - Default is equal to 1. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``shape`` and ``scale`` are both scalars. - Otherwise, ``np.broadcast(shape, scale).size`` samples are drawn. + Notes:: - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized gamma distribution. + The mean is a coordinate in N-dimensional space, which represents the + location where samples are most likely to be generated. This is + analogous to the peak of the bell curve for the one-dimensional or + univariate normal distribution. + Covariance indicates the level to which two variables vary together. + From the multivariate normal distribution, we draw N-dimensional + samples, :math:`X = [x_1, x_2, ... x_N]`. The covariance matrix + element :math:`C_{ij}` is the covariance of :math:`x_i` and :math:`x_j`. + The element :math:`C_{ii}` is the variance of :math:`x_i` (i.e. its + "spread"). - Notes:: - - The probability density for the Gamma distribution is + Instead of specifying the full covariance matrix, popular + approximations include: - .. 
math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)}, + - Spherical covariance (`cov` is a multiple of the identity matrix) + - Diagonal covariance (`cov` has non-negative elements, and only on + the diagonal) - where :math:`k` is the shape and :math:`\theta` the scale, - and :math:`\Gamma` is the Gamma function. + This geometrical property can be seen in two dimensions by plotting + generated data-points: - The Gamma distribution is often used to model the times to failure of - electronic components, and arises naturally in processes for which the - waiting times between Poisson distributed events are relevant. + >>> mean = [0, 0] + >>> cov = [[1, 0], [0, 100]] # diagonal covariance - References:: - - .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A - Wolfram Web Resource. - http://mathworld.wolfram.com/GammaDistribution.html - .. [2] Wikipedia, "Gamma distribution", - https://en.wikipedia.org/wiki/Gamma_distribution + Diagonal covariance means that points are oriented along x or y-axis: - """ - return DEFAULT.gamma(shape, scale, size=size, key=key) + >>> import matplotlib.pyplot as plt + >>> x, y = bm.random.multivariate_normal(mean, cov, 5000).T + >>> plt.plot(x, y, 'x') + >>> plt.axis('equal') + >>> plt.show() + Note that the covariance matrix must be positive semidefinite (a.k.a. + nonnegative-definite). Otherwise, the behavior of this method is + undefined and backwards compatibility is not guaranteed. -def gumbel(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Gumbel distribution. - - Draw samples from a Gumbel distribution with specified location and - scale. For more information on the Gumbel distribution, see - Notes and References below. - - Parameters:: - - loc : float or array_like of floats, optional - The location of the mode of the distribution. Default is 0. 
- scale : float or array_like of floats, optional - The scale parameter of the distribution. Default is 1. Must be non- - negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``loc`` and ``scale`` are both scalars. - Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Gumbel distribution. - - Notes:: - - The Gumbel (or Smallest Extreme Value (SEV) or the Smallest Extreme - Value Type I) distribution is one of a class of Generalized Extreme - Value (GEV) distributions used in modeling extreme value problems. - The Gumbel is a special case of the Extreme Value Type I distribution - for maximums from distributions with "exponential-like" tails. - - The probability density for the Gumbel distribution is - - .. math:: p(x) = \frac{e^{-(x - \mu)/ \beta}}{\beta} e^{ -e^{-(x - \mu)/ - \beta}}, - - where :math:`\mu` is the mode, a location parameter, and - :math:`\beta` is the scale parameter. - - The Gumbel (named for German mathematician Emil Julius Gumbel) was used - very early in the hydrology literature, for modeling the occurrence of - flood events. It is also used for modeling maximum wind speed and - rainfall rates. It is a "fat-tailed" distribution - the probability of - an event in the tail of the distribution is larger than if one used a - Gaussian, hence the surprisingly frequent occurrence of 100-year - floods. Floods were initially modeled as a Gaussian process, which - underestimated the frequency of extreme events. - - It is one of a class of extreme value distributions, the Generalized - Extreme Value (GEV) distributions, which also includes the Weibull and - Frechet. - - The function has a mean of :math:`\mu + 0.57721\beta` and a variance - of :math:`\frac{\pi^2}{6}\beta^2`. - - References:: - - .. 
[1] Gumbel, E. J., "Statistics of Extremes," - New York: Columbia University Press, 1958. - .. [2] Reiss, R.-D. and Thomas, M., "Statistical Analysis of Extreme - Values from Insurance, Finance, Hydrology and Other Fields," - Basel: Birkhauser Verlag, 2001. - """ - return DEFAULT.gumbel(loc, scale, size=size, key=key) + References:: + .. [1] Papoulis, A., "Probability, Random Variables, and Stochastic + Processes," 3rd ed., New York: McGraw-Hill, 1991. + .. [2] Duda, R. O., Hart, P. E., and Stork, D. G., "Pattern + Classification," 2nd ed., New York: Wiley, 2001. -def laplace(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from the Laplace or double exponential distribution with - specified location (or mean) and scale (decay). - - The Laplace distribution is similar to the Gaussian/normal distribution, - but is sharper at the peak and has fatter tails. It represents the - difference between two independent, identically distributed exponential - random variables. - - Parameters:: - - loc : float or array_like of floats, optional - The position, :math:`\mu`, of the distribution peak. Default is 0. - scale : float or array_like of floats, optional - :math:`\lambda`, the exponential decay. Default is 1. Must be non- - negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``loc`` and ``scale`` are both scalars. - Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Laplace distribution. - - Notes:: - - It has the probability density function - - .. math:: f(x; \mu, \lambda) = \frac{1}{2\lambda} - \exp\left(-\frac{|x - \mu|}{\lambda}\right). 
- - The first law of Laplace, from 1774, states that the frequency - of an error can be expressed as an exponential function of the - absolute magnitude of the error, which leads to the Laplace - distribution. For many problems in economics and health - sciences, this distribution seems to model the data better - than the standard Gaussian distribution. - - References:: - - .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of - Mathematical Functions with Formulas, Graphs, and Mathematical - Tables, 9th printing," New York: Dover, 1972. - .. [2] Kotz, Samuel, et. al. "The Laplace Distribution and - Generalizations, " Birkhauser, 2001. - .. [3] Weisstein, Eric W. "Laplace Distribution." - From MathWorld--A Wolfram Web Resource. - http://mathworld.wolfram.com/LaplaceDistribution.html - .. [4] Wikipedia, "Laplace distribution", - https://en.wikipedia.org/wiki/Laplace_distribution - - Examples:: - - Draw samples from the distribution - - >>> loc, scale = 0., 1. - >>> s = bm.random.laplace(loc, scale, 1000) - - Display the histogram of the samples, along with - the probability density function: - - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s, 30, density=True) - >>> x = np.arange(-8., 8., .01) - >>> pdf = np.exp(-abs(x-loc)/scale)/(2.*scale) - >>> plt.plot(x, pdf) - - Plot Gaussian for comparison: - - >>> g = (1/(scale * np.sqrt(2 * np.pi)) * - ... np.exp(-(x - loc)**2 / (2 * scale**2))) - >>> plt.plot(x,g) - """ - return DEFAULT.laplace(loc, scale, size, key=key) + Examples:: + >>> mean = (1, 2) + >>> cov = [[1, 0], [0, 1]] + >>> x = bm.random.multivariate_normal(mean, cov, (3, 3)) + >>> x.shape + (3, 3, 2) -def logistic(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a logistic distribution. - - Samples are drawn from a logistic distribution with specified - parameters, loc (location or mean, also median), and scale (>0). 
- - Parameters:: - - loc : float or array_like of floats, optional - Parameter of the distribution. Default is 0. - scale : float or array_like of floats, optional - Parameter of the distribution. Must be non-negative. - Default is 1. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``loc`` and ``scale`` are both scalars. - Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized logistic distribution. - - Notes:: - - The probability density for the Logistic distribution is - - .. math:: P(x) = P(x) = \frac{e^{-(x-\mu)/s}}{s(1+e^{-(x-\mu)/s})^2}, - - where :math:`\mu` = location and :math:`s` = scale. - - The Logistic distribution is used in Extreme Value problems where it - can act as a mixture of Gumbel distributions, in Epidemiology, and by - the World Chess Federation (FIDE) where it is used in the Elo ranking - system, assuming the performance of each player is a logistically - distributed random variable. - - References:: - - .. [1] Reiss, R.-D. and Thomas M. (2001), "Statistical Analysis of - Extreme Values, from Insurance, Finance, Hydrology and Other - Fields," Birkhauser Verlag, Basel, pp 132-133. - .. [2] Weisstein, Eric W. "Logistic Distribution." From - MathWorld--A Wolfram Web Resource. - http://mathworld.wolfram.com/LogisticDistribution.html - .. [3] Wikipedia, "Logistic-distribution", - https://en.wikipedia.org/wiki/Logistic_distribution - - Examples:: - - Draw samples from the distribution: - - >>> loc, scale = 10, 1 - >>> s = bm.random.logistic(loc, scale, 10000) - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s, bins=50) - - # plot against distribution - - >>> def logist(x, loc, scale): - ... 
return np.exp((loc-x)/scale)/(scale*(1+np.exp((loc-x)/scale))**2) - >>> lgst_val = logist(bins, loc, scale) - >>> plt.plot(bins, lgst_val * count.max() / lgst_val.max()) - >>> plt.show() - """ - return DEFAULT.logistic(loc, scale, size, key=key) + Here we generate 800 samples from the bivariate normal distribution + with mean [0, 0] and covariance matrix [[6, -3], [-3, 3.5]]. The + expected variances of the first and second components of the sample + are 6 and 3.5, respectively, and the expected correlation + coefficient is -3/sqrt(6*3.5) ≈ -0.65465. + >>> cov = np.array([[6, -3], [-3, 3.5]]) + >>> pts = bm.random.multivariate_normal([0, 0], cov, size=800) -def normal(loc=None, scale=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw random samples from a normal (Gaussian) distribution. - - The probability density function of the normal distribution, first - derived by De Moivre and 200 years later by both Gauss and Laplace - independently [2]_, is often called the bell curve because of - its characteristic shape (see the example below). - - The normal distributions occurs often in nature. For example, it - describes the commonly occurring distribution of samples influenced - by a large number of tiny, random disturbances, each with its own - unique distribution [2]_. - - Parameters:: - - loc : float or array_like of floats - Mean ("centre") of the distribution. - scale : float or array_like of floats - Standard deviation (spread or "width") of the distribution. Must be - non-negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``loc`` and ``scale`` are both scalars. - Otherwise, ``np.broadcast(loc, scale).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized normal distribution. 
- - Notes:: - - The probability density for the Gaussian distribution is - - .. math:: p(x) = \frac{1}{\sqrt{ 2 \pi \sigma^2 }} - e^{ - \frac{ (x - \mu)^2 } {2 \sigma^2} }, - - where :math:`\mu` is the mean and :math:`\sigma` the standard - deviation. The square of the standard deviation, :math:`\sigma^2`, - is called the variance. - - The function has its peak at the mean, and its "spread" increases with - the standard deviation (the function reaches 0.607 times its maximum at - :math:`x + \sigma` and :math:`x - \sigma` [2]_). This implies that - normal is more likely to return samples lying close to the mean, rather - than those far away. - - References:: - - .. [1] Wikipedia, "Normal distribution", - https://en.wikipedia.org/wiki/Normal_distribution - .. [2] P. R. Peebles Jr., "Central Limit Theorem" in "Probability, - Random Variables and Random Signal Principles", 4th ed., 2001, - pp. 51, 51, 125. - - Examples:: - - Draw samples from the distribution: - - >>> mu, sigma = 0, 0.1 # mean and standard deviation - >>> s = bm.random.normal(mu, sigma, 1000) - - Verify the mean and the variance: - - >>> abs(mu - np.mean(s)) - 0.0 # may vary - - >>> abs(sigma - np.std(s, ddof=1)) - 0.1 # may vary - - Display the histogram of the samples, along with - the probability density function: - - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s, 30, density=True) - >>> plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) * - ... np.exp( - (bins - mu)**2 / (2 * sigma**2) ), - ... 
linewidth=2, color='r') - >>> plt.show() - - Two-by-four array of samples from the normal distribution with - mean 3 and standard deviation 2.5: - - >>> bm.random.normal(3, 2.5, size=(2, 4)) - array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random - [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random - """ - return DEFAULT.normal(loc, scale, size, key=key) + Check that the mean, covariance, and correlation coefficient of the + sample are close to the expected values: + >>> pts.mean(axis=0) + array([ 0.0326911 , -0.01280782]) # may vary + >>> np.cov(pts.T) + array([[ 5.96202397, -2.85602287], + [-2.85602287, 3.47613949]]) # may vary + >>> np.corrcoef(pts.T)[0, 1] + -0.6273591314603949 # may vary -def pareto(a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Pareto II or Lomax distribution with - specified shape. - - The Lomax or Pareto II distribution is a shifted Pareto - distribution. The classical Pareto distribution can be - obtained from the Lomax distribution by adding 1 and - multiplying by the scale parameter ``m`` (see Notes). The - smallest value of the Lomax distribution is zero while for the - classical Pareto distribution it is ``mu``, where the standard - Pareto distribution has location ``mu = 1``. Lomax can also - be considered as a simplified version of the Generalized - Pareto distribution (available in SciPy), with the scale set - to one and the location set to zero. - - The Pareto distribution must be greater than zero, and is - unbounded above. It is also known as the "80-20 rule". In - this distribution, 80 percent of the weights are in the lowest - 20 percent of the range, while the other 20 percent fill the - remaining 80 percent of the range. - - Parameters:: - - a : float or array_like of floats - Shape of the distribution. Must be positive. - size : int or tuple of ints, optional - Output shape. 
If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``a`` is a scalar. Otherwise, - ``np.array(a).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Pareto distribution. - - See Also:: - - scipy.stats.lomax : probability density function, distribution or - cumulative density function, etc. - scipy.stats.genpareto : probability density function, distribution or - cumulative density function, etc. - - Notes:: - - The probability density for the Pareto distribution is - - .. math:: p(x) = \frac{am^a}{x^{a+1}} - - where :math:`a` is the shape and :math:`m` the scale. - - The Pareto distribution, named after the Italian economist - Vilfredo Pareto, is a power law probability distribution - useful in many real world problems. Outside the field of - economics it is generally referred to as the Bradford - distribution. Pareto developed the distribution to describe - the distribution of wealth in an economy. It has also found - use in insurance, web page access statistics, oil field sizes, - and many other problems, including the download frequency for - projects in Sourceforge [1]_. It is one of the so-called - "fat-tailed" distributions. - - References:: - - .. [1] Francis Hunt and Paul Johnson, On the Pareto Distribution of - Sourceforge projects. - .. [2] Pareto, V. (1896). Course of Political Economy. Lausanne. - .. [3] Reiss, R.D., Thomas, M.(2001), Statistical Analysis of Extreme - Values, Birkhauser Verlag, Basel, pp 23-30. - .. [4] Wikipedia, "Pareto distribution", - https://en.wikipedia.org/wiki/Pareto_distribution - - Examples:: - - Draw samples from the distribution: - - >>> a, m = 3., 2. 
# shape and mode - >>> s = (bm.random.pareto(a, 1000) + 1) * m - - Display the histogram of the samples, along with the probability - density function: - - >>> import matplotlib.pyplot as plt - >>> count, bins, _ = plt.hist(s, 100, density=True) - >>> fit = a*m**a / bins**(a+1) - >>> plt.plot(bins, max(count)*fit/max(fit), linewidth=2, color='r') - >>> plt.show() - """ - return DEFAULT.pareto(a, size, key=key) + We can visualize this data with a scatter plot. The orientation + of the point cloud illustrates the negative correlation of the + components of this sample. + >>> import matplotlib.pyplot as plt + >>> plt.plot(pts[:, 0], pts[:, 1], '.', alpha=0.5) + >>> plt.axis('equal') + >>> plt.grid() + >>> plt.show() + """ + return DEFAULT.multivariate_normal(mean, cov, size, method, key=key) -def poisson(lam=1.0, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Poisson distribution. - The Poisson distribution is the limit of the binomial distribution - for large N. +def negative_binomial(n, p, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a negative binomial distribution. - Parameters:: - - lam : float or array_like of floats - Expected number of events occurring in a fixed-time interval, - must be >= 0. A sequence must be broadcastable over the requested - size. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``lam`` is a scalar. Otherwise, - ``np.array(lam).size`` samples are drawn. + Samples are drawn from a negative binomial distribution with specified + parameters, `n` successes and `p` probability of success where `n` + is > 0 and `p` is in the interval [0, 1]. 
+ + Parameters:: - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Poisson distribution. + n : float or array_like of floats + Parameter of the distribution, > 0. + p : float or array_like of floats + Parameter of the distribution, >= 0 and <=1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``n`` and ``p`` are both scalars. + Otherwise, ``np.broadcast(n, p).size`` samples are drawn. - Notes:: - - The Poisson distribution + Returns:: - .. math:: f(k; \lambda)=\frac{\lambda^k e^{-\lambda}}{k!} + out : ndarray or scalar + Drawn samples from the parameterized negative binomial distribution, + where each sample is equal to N, the number of failures that + occurred before a total of n successes was reached. - For events with an expected separation :math:`\lambda` the Poisson - distribution :math:`f(k; \lambda)` describes the probability of - :math:`k` events occurring within the observed - interval :math:`\lambda`. + Notes:: - Because the output is limited to the range of the C int64 type, a - ValueError is raised when `lam` is within 10 sigma of the maximum - representable value. + The probability mass function of the negative binomial distribution is - References:: - - .. [1] Weisstein, Eric W. "Poisson Distribution." - From MathWorld--A Wolfram Web Resource. - http://mathworld.wolfram.com/PoissonDistribution.html - .. [2] Wikipedia, "Poisson distribution", - https://en.wikipedia.org/wiki/Poisson_distribution + .. math:: P(N;n,p) = \frac{\Gamma(N+n)}{N!\Gamma(n)}p^{n}(1-p)^{N}, - Examples:: - - Draw samples from the distribution: + where :math:`n` is the number of successes, :math:`p` is the + probability of success, :math:`N+n` is the number of trials, and + :math:`\Gamma` is the gamma function. 
When :math:`n` is an integer, + :math:`\frac{\Gamma(N+n)}{N!\Gamma(n)} = \binom{N+n-1}{N}`, which is + the more common form of this term in the pmf. The negative + binomial distribution gives the probability of N failures given n + successes, with a success on the last trial. - >>> import numpy as np - >>> s = bm.random.poisson(5, 10000) + If one throws a die repeatedly until the third time a "1" appears, + then the probability distribution of the number of non-"1"s that + appear before the third "1" is a negative binomial distribution. - Display histogram of the sample: + References:: - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s, 14, density=True) - >>> plt.show() + .. [1] Weisstein, Eric W. "Negative Binomial Distribution." From + MathWorld--A Wolfram Web Resource. + http://mathworld.wolfram.com/NegativeBinomialDistribution.html + .. [2] Wikipedia, "Negative binomial distribution", + https://en.wikipedia.org/wiki/Negative_binomial_distribution - Draw each 100 values for lambda 100 and 500: + Examples:: - >>> s = bm.random.poisson(lam=(100., 500.), size=(100, 2)) - """ - return DEFAULT.poisson(lam, size, key=key) + Draw samples from the distribution: + A real world example. A company drills wild-cat oil + exploration wells, each with an estimated probability of + success of 0.1. What is the probability of having one success + for each successive well, that is what is the probability of a + single success after drilling 5 wells, after 6 wells, etc.? -def standard_cauchy(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a standard Cauchy distribution with mode = 0. - - Also known as the Lorentz distribution. - - Parameters:: - - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. 
- - Returns:: - - samples : ndarray or scalar - The drawn samples. - - Notes:: - - The probability density function for the full Cauchy distribution is - - .. math:: P(x; x_0, \gamma) = \frac{1}{\pi \gamma \bigl[ 1+ - (\frac{x-x_0}{\gamma})^2 \bigr] } - - and the Standard Cauchy distribution just sets :math:`x_0=0` and - :math:`\gamma=1` - - The Cauchy distribution arises in the solution to the driven harmonic - oscillator problem, and also describes spectral line broadening. It - also describes the distribution of values at which a line tilted at - a random angle will cut the x axis. - - When studying hypothesis tests that assume normality, seeing how the - tests perform on data from a Cauchy distribution is a good indicator of - their sensitivity to a heavy-tailed distribution, since the Cauchy looks - very much like a Gaussian distribution, but with heavier tails. - - References:: - - .. [1] NIST/SEMATECH e-Handbook of Statistical Methods, "Cauchy - Distribution", - https://www.itl.nist.gov/div898/handbook/eda/section3/eda3663.htm - .. [2] Weisstein, Eric W. "Cauchy Distribution." From MathWorld--A - Wolfram Web Resource. - http://mathworld.wolfram.com/CauchyDistribution.html - .. [3] Wikipedia, "Cauchy distribution" - https://en.wikipedia.org/wiki/Cauchy_distribution - - Examples:: - - Draw samples and plot the distribution: - - >>> import matplotlib.pyplot as plt - >>> s = bm.random.standard_cauchy(1000000) - >>> s = s[(s>-25) & (s<25)] # truncate distribution so it plots well - >>> plt.hist(s, bins=100) - >>> plt.show() - """ - return DEFAULT.standard_cauchy(size, key=key) + >>> s = bm.random.negative_binomial(1, 0.1, 100000) + >>> for i in range(1, 11): # doctest: +SKIP + ... probability = sum(s 0. + nonc : float or array_like of floats + Non-centrality, must be non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. 
If size is ``None`` (default), + a single value is returned if ``df`` and ``nonc`` are both scalars. + Otherwise, ``np.broadcast(df, nonc).size`` samples are drawn. - Examples:: - - Output a 3x8000 array: + Returns:: - >>> n = bm.random.standard_exponential((3, 8000)) - """ - return DEFAULT.standard_exponential(size, key=key) + out : ndarray or scalar + Drawn samples from the parameterized noncentral chi-square distribution. + Notes:: -def standard_gamma(shape, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a standard Gamma distribution. - - Samples are drawn from a Gamma distribution with specified parameters, - shape (sometimes designated "k") and scale=1. - - Parameters:: - - shape : float or array_like of floats - Parameter, must be non-negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``shape`` is a scalar. Otherwise, - ``np.array(shape).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized standard gamma distribution. - - See Also:: - - scipy.stats.gamma : probability density function, distribution or - cumulative density function, etc. - - Notes:: - - The probability density for the Gamma distribution is - - .. math:: p(x) = x^{k-1}\frac{e^{-x/\theta}}{\theta^k\Gamma(k)}, - - where :math:`k` is the shape and :math:`\theta` the scale, - and :math:`\Gamma` is the Gamma function. - - The Gamma distribution is often used to model the times to failure of - electronic components, and arises naturally in processes for which the - waiting times between Poisson distributed events are relevant. - - References:: - - .. [1] Weisstein, Eric W. "Gamma Distribution." From MathWorld--A - Wolfram Web Resource. - http://mathworld.wolfram.com/GammaDistribution.html - .. 
[2] Wikipedia, "Gamma distribution", - https://en.wikipedia.org/wiki/Gamma_distribution - - Examples:: - - Draw samples from the distribution: - - >>> shape, scale = 2., 1. # mean and width - >>> s = bm.random.standard_gamma(shape, 1000000) - - Display the histogram of the samples, along with - the probability density function: - - >>> import matplotlib.pyplot as plt - >>> import scipy.special as sps # doctest: +SKIP - >>> count, bins, ignored = plt.hist(s, 50, density=True) - >>> y = bins**(shape-1) * ((np.exp(-bins/scale))/ # doctest: +SKIP - ... (sps.gamma(shape) * scale**shape)) - >>> plt.plot(bins, y, linewidth=2, color='r') # doctest: +SKIP - >>> plt.show() - """ - return DEFAULT.standard_gamma(shape, size, key=key) + The probability density function for the noncentral Chi-square + distribution is + .. math:: P(x;df,nonc) = \sum^{\infty}_{i=0} + \frac{e^{-nonc/2}(nonc/2)^{i}}{i!} + P_{Y_{df+2i}}(x), -def standard_normal(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a standard Normal distribution (mean=0, stdev=1). - - Parameters:: - - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. - - Returns:: - - out : float or ndarray - A floating-point array of shape ``size`` of drawn samples, or a - single sample if ``size`` was not specified. - - See Also:: - - normal : - Equivalent function with additional ``loc`` and ``scale`` arguments - for setting the mean and standard deviation. - - Notes:: - - For random samples from the normal distribution with mean ``mu`` and - standard deviation ``sigma``, use one of:: - - mu + sigma * bm.random.standard_normal(size=...) - bm.random.normal(mu, sigma, size=...) 
- - Examples:: - - >>> bm.random.standard_normal() - 2.1923875335537315 #random - - >>> s = bm.random.standard_normal(8000) - >>> s - array([ 0.6888893 , 0.78096262, -0.89086505, ..., 0.49876311, # random - -0.38672696, -0.4685006 ]) # random - >>> s.shape - (8000,) - >>> s = bm.random.standard_normal(size=(3, 4, 2)) - >>> s.shape - (3, 4, 2) - - Two-by-four array of samples from the normal distribution with - mean 3 and standard deviation 2.5: - - >>> 3 + 2.5 * bm.random.standard_normal(size=(2, 4)) - array([[-4.49401501, 4.00950034, -1.81814867, 7.29718677], # random - [ 0.39924804, 4.68456316, 4.99394529, 4.84057254]]) # random - """ - return DEFAULT.standard_normal(size, key=key) + where :math:`Y_{q}` is the Chi-square with q degrees of freedom. + References:: -def standard_t(df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a standard Student's t distribution with `df` degrees - of freedom. - - A special case of the hyperbolic distribution. As `df` gets - large, the result resembles that of the standard normal - distribution (`standard_normal`). - - Parameters:: - - df : float or array_like of floats - Degrees of freedom, must be > 0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``df`` is a scalar. Otherwise, - ``np.array(df).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized standard Student's t distribution. - - Notes:: - - The probability density function for the t distribution is - - .. math:: P(x, df) = \frac{\Gamma(\frac{df+1}{2})}{\sqrt{\pi df} - \Gamma(\frac{df}{2})}\Bigl( 1+\frac{x^2}{df} \Bigr)^{-(df+1)/2} - - The t test is based on an assumption that the data come from a - Normal distribution. 
The t test provides a way to test whether - the sample mean (that is the mean calculated from the data) is - a good estimate of the true mean. - - The derivation of the t-distribution was first published in - 1908 by William Gosset while working for the Guinness Brewery - in Dublin. Due to proprietary issues, he had to publish under - a pseudonym, and so he used the name Student. - - References:: - - .. [1] Dalgaard, Peter, "Introductory Statistics With R", - Springer, 2002. - .. [2] Wikipedia, "Student's t-distribution" - https://en.wikipedia.org/wiki/Student's_t-distribution - - Examples:: - - From Dalgaard page 83 [1]_, suppose the daily energy intake for 11 - women in kilojoules (kJ) is: - - >>> intake = np.array([5260., 5470, 5640, 6180, 6390, 6515, 6805, 7515, \ - ... 7515, 8230, 8770]) - - Does their energy intake deviate systematically from the recommended - value of 7725 kJ? Our null hypothesis will be the absence of deviation, - and the alternate hypothesis will be the presence of an effect that could be - either positive or negative, hence making our test 2-tailed. - - Because we are estimating the mean and we have N=11 values in our sample, - we have N-1=10 degrees of freedom. We set our significance level to 95% and - compute the t statistic using the empirical mean and empirical standard - deviation of our intake. We use a ddof of 1 to base the computation of our - empirical standard deviation on an unbiased estimate of the variance (note: - the final estimate is not unbiased due to the concave nature of the square - root). - - >>> np.mean(intake) - 6753.636363636364 - >>> intake.std(ddof=1) - 1142.1232221373727 - >>> t = (np.mean(intake)-7725)/(intake.std(ddof=1)/np.sqrt(len(intake))) - >>> t - -2.8207540608310198 - - We draw 1000000 samples from Student's t distribution with the adequate - degrees of freedom. 
- - >>> import matplotlib.pyplot as plt - >>> s = bm.random.standard_t(10, size=1000000) - >>> h = plt.hist(s, bins=100, density=True) - - Does our t statistic land in one of the two critical regions found at - both tails of the distribution? - - >>> np.sum(np.abs(t) < np.abs(s)) / float(len(s)) - 0.018318 #random < 0.05, statistic is in critical region - - The probability value for this 2-tailed test is about 1.83%, which is - lower than the 5% pre-determined significance threshold. - - Therefore, the probability of observing values as extreme as our intake - conditionally on the null hypothesis being true is too low, and we reject - the null hypothesis of no deviation. - """ - return DEFAULT.standard_t(df, size, key=key) + .. [1] Wikipedia, "Noncentral chi-squared distribution" + https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution + Examples:: -def uniform(low=0.0, high=1.0, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a uniform distribution. - - Samples are uniformly distributed over the half-open interval - ``[low, high)`` (includes low, but excludes high). In other words, - any value within the given interval is equally likely to be drawn - by `uniform`. - - Parameters:: - - low : float or array_like of floats, optional - Lower boundary of the output interval. All values generated will be - greater than or equal to low. The default value is 0. - high : float or array_like of floats - Upper boundary of the output interval. All values generated will be - less than or equal to high. The high limit may be included in the - returned array of floats due to floating-point rounding in the - equation ``low + (high-low) * random_sample()``. The default value - is 1.0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. 
If size is ``None`` (default), - a single value is returned if ``low`` and ``high`` are both scalars. - Otherwise, ``np.broadcast(low, high).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized uniform distribution. - - See Also:: - - randint : Discrete uniform distribution, yielding integers. - random_integers : Discrete uniform distribution over the closed - interval ``[low, high]``. - random_sample : Floats uniformly distributed over ``[0, 1)``. - random : Alias for `random_sample`. - rand : Convenience function that accepts dimensions as input, e.g., - ``rand(2,2)`` would generate a 2-by-2 array of floats, - uniformly distributed over ``[0, 1)``. - - Notes:: - - The probability density function of the uniform distribution is - - .. math:: p(x) = \frac{1}{b - a} - - anywhere within the interval ``[a, b)``, and zero elsewhere. - - When ``high`` == ``low``, values of ``low`` will be returned. - If ``high`` < ``low``, the results are officially undefined - and may eventually raise an error, i.e. do not rely on this - function to behave when passed arguments satisfying that - inequality condition. The ``high`` limit may be included in the - returned array of floats due to floating-point rounding in the - equation ``low + (high-low) * random_sample()``. 
For example: - - >>> x = np.float32(5*0.99999999) - >>> x - 5.0 - - - Examples:: - - Draw samples from the distribution: - - >>> s = bm.random.uniform(-1,0,1000) - - All values are within the given interval: - - >>> np.all(s >= -1) - True - >>> np.all(s < 0) - True - - Display the histogram of the samples, along with the - probability density function: - - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s, 15, density=True) - >>> plt.plot(bins, np.ones_like(bins), linewidth=2, color='r') - >>> plt.show() - """ - return DEFAULT.uniform(low, high, size, key=key) + Draw values from the distribution and plot the histogram + >>> import matplotlib.pyplot as plt + >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000), + ... bins=200, density=True) + >>> plt.show() -def truncated_normal(lower, upper, size: Optional[Union[int, Sequence[int]]] = None, loc=0., scale=1., dtype=float, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r"""Sample truncated standard normal random values with given shape and dtype. - - Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf - - - Notes:: - - This distribution is the normal distribution centered on ``loc`` (default - 0), with standard deviation ``scale`` (default 1), and clipped at ``a``, - ``b`` standard deviations to the left, right (respectively) from ``loc``. - If ``myclip_a`` and ``myclip_b`` are clip values in the sample space (as - opposed to the number of standard deviations) then they can be converted - to the required form according to:: - - a, b = (myclip_a - loc) / scale, (myclip_b - loc) / scale - - - Parameters:: - - lower : float, ndarray - A float or array of floats representing the lower bound for - truncation. Must be broadcast-compatible with ``upper``. - upper : float, ndarray - A float or array of floats representing the upper bound for - truncation. Must be broadcast-compatible with ``lower``. 
- loc : float, ndarray - Mean ("centre") of the distribution before truncating. Note that - the mean of the truncated distribution will not be exactly equal - to ``loc``. - size : optional, list of int, tuple of int - A tuple of nonnegative integers specifying the result - shape. Must be broadcast-compatible with ``lower`` and ``upper``. The - default (None) produces a result shape by broadcasting ``lower`` and - ``upper``. - loc: optional, float, ndarray - A float or array of floats representing the mean of the - distribution. Default is 0. - scale : float, ndarray - Standard deviation (spread or "width") of the distribution. Must be - non-negative. Default is 1. - dtype: optional - The float dtype for the returned values (default float64 if - jax_enable_x64 is true, otherwise float32). - key: jax.Array - The key for random generator. Consistent with the jax's random - paradigm. - - Returns:: - - out : Array - A random array with the specified dtype and shape given by ``shape`` if - ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``. - Returns values in the open interval ``(lower, upper)``. - """ - return DEFAULT.truncated_normal(lower, upper, size, loc, scale, dtype=dtype, key=key) + Draw values from a noncentral chisquare with very small noncentrality, + and compare to a chisquare. + >>> plt.figure() + >>> values = plt.hist(bm.random.noncentral_chisquare(3, .0000001, 100000), + ... bins=np.arange(0., 25, .1), density=True) + >>> values2 = plt.hist(bm.random.chisquare(3, 100000), + ... bins=np.arange(0., 25, .1), density=True) + >>> plt.plot(values[1][0:-1], values[0]-values2[0], 'ob') + >>> plt.show() -RandomState.truncated_normal.__doc__ = truncated_normal.__doc__ + Demonstrate how large values of non-centrality lead to a more symmetric + distribution. + >>> plt.figure() + >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000), + ... 
bins=200, density=True) + >>> plt.show() + """ + return DEFAULT.noncentral_chisquare(df, nonc, size, key=key) -def bernoulli(p=0.5, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r"""Sample Bernoulli random values with given shape and mean. - - Parameters:: - - p: float, array_like, optional - A float or array of floats for the mean of the random - variables. Must be broadcast-compatible with ``shape`` and the values - should be within [0, 1]. Default 0.5. - size: optional, tuple of int, int - A tuple of nonnegative integers representing the result - shape. Must be broadcast-compatible with ``p.shape``. The default (None) - produces a result shape equal to ``p.shape``. - - Returns:: - - out: array_like - A random array with boolean dtype and shape given by ``shape`` if ``shape`` - is not None, or else ``p.shape``. - """ - return DEFAULT.bernoulli(p, size, key=key) +def noncentral_f(dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from the noncentral F distribution. -def lognormal(mean=None, sigma=None, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a log-normal distribution. - - Draw samples from a log-normal distribution with specified mean, - standard deviation, and array shape. Note that the mean and standard - deviation are not the values for the distribution itself, but of the - underlying normal distribution it is derived from. - - Parameters:: - - mean : float or array_like of floats, optional - Mean value of the underlying normal distribution. Default is 0. - sigma : float or array_like of floats, optional - Standard deviation of the underlying normal distribution. Must be - non-negative. Default is 1. - size : int or tuple of ints, optional - Output shape. 
If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``mean`` and ``sigma`` are both scalars. - Otherwise, ``np.broadcast(mean, sigma).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized log-normal distribution. - - See Also:: - - scipy.stats.lognorm : probability density function, distribution, - cumulative density function, etc. - - Notes:: - - A variable `x` has a log-normal distribution if `log(x)` is normally - distributed. The probability density function for the log-normal - distribution is: - - .. math:: p(x) = \frac{1}{\sigma x \sqrt{2\pi}} - e^{(-\frac{(ln(x)-\mu)^2}{2\sigma^2})} - - where :math:`\mu` is the mean and :math:`\sigma` is the standard - deviation of the normally distributed logarithm of the variable. - A log-normal distribution results if a random variable is the *product* - of a large number of independent, identically-distributed variables in - the same way that a normal distribution results if the variable is the - *sum* of a large number of independent, identically-distributed - variables. - - References:: - - .. [1] Limpert, E., Stahel, W. A., and Abbt, M., "Log-normal - Distributions across the Sciences: Keys and Clues," - BioScience, Vol. 51, No. 5, May, 2001. - https://stat.ethz.ch/~stahel/lognormal/bioscience.pdf - .. [2] Reiss, R.D. and Thomas, M., "Statistical Analysis of Extreme - Values," Basel: Birkhauser Verlag, 2001, pp. 31-32. - - Examples:: - - Draw samples from the distribution: - - >>> mu, sigma = 3., 1. 
# mean and standard deviation - >>> s = bm.random.lognormal(mu, sigma, 1000) - - Display the histogram of the samples, along with - the probability density function: - - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s, 100, density=True, align='mid') - - >>> x = np.linspace(min(bins), max(bins), 10000) - >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2)) - ... / (x * sigma * np.sqrt(2 * np.pi))) - - >>> plt.plot(x, pdf, linewidth=2, color='r') - >>> plt.axis('tight') - >>> plt.show() - - Demonstrate that taking the products of random samples from a uniform - distribution can be fit well by a log-normal probability density - function. - - >>> # Generate a thousand samples: each is the product of 100 random - >>> # values, drawn from a normal distribution. - >>> b = [] - >>> for i in range(1000): - ... a = 10. + bm.random.standard_normal(100) - ... b.append(np.product(a)) - - >>> b = np.array(b) / np.min(b) # scale values to be positive - >>> count, bins, ignored = plt.hist(b, 100, density=True, align='mid') - >>> sigma = np.std(np.log(b)) - >>> mu = np.mean(np.log(b)) - - >>> x = np.linspace(min(bins), max(bins), 10000) - >>> pdf = (np.exp(-(np.log(x) - mu)**2 / (2 * sigma**2)) - ... / (x * sigma * np.sqrt(2 * np.pi))) - - >>> plt.plot(x, pdf, color='r', linewidth=2) - >>> plt.show() - """ - return DEFAULT.lognormal(mean, sigma, size, key=key) + Samples are drawn from an F distribution with specified parameters, + `dfnum` (degrees of freedom in numerator) and `dfden` (degrees of + freedom in denominator), where both parameters > 1. + `nonc` is the non-centrality parameter. + Parameters:: -def binomial(n, p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a binomial distribution. - - Samples are drawn from a binomial distribution with specified - parameters, n trials and p probability of success where - n an integer >= 0 and p is in the interval [0,1]. 
(n may be - input as a float, but it is truncated to an integer in use) - - Parameters:: - - n : int or array_like of ints - Parameter of the distribution, >= 0. Floats are also accepted, - but they will be truncated to integers. - p : float or array_like of floats - Parameter of the distribution, >= 0 and <=1. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``n`` and ``p`` are both scalars. - Otherwise, ``np.broadcast(n, p).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized binomial distribution, where - each sample is equal to the number of successes over the n trials. - - See Also:: - - scipy.stats.binom : probability density function, distribution or - cumulative density function, etc. - - Notes:: - - The probability density for the binomial distribution is - - .. math:: P(N) = \binom{n}{N}p^N(1-p)^{n-N}, - - where :math:`n` is the number of trials, :math:`p` is the probability - of success, and :math:`N` is the number of successes. - - When estimating the standard error of a proportion in a population by - using a random sample, the normal distribution works well unless the - product p*n <=5, where p = population proportion estimate, and n = - number of samples, in which case the binomial distribution is used - instead. For example, a sample of 15 people shows 4 who are left - handed, and 11 who are right handed. Then p = 4/15 = 27%. 0.27*15 = 4, - so the binomial distribution should be used in this case. - - References:: - - .. [1] Dalgaard, Peter, "Introductory Statistics with R", - Springer-Verlag, 2002. - .. [2] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill, - Fifth Edition, 2002. - .. [3] Lentner, Marvin, "Elementary Applied Statistics", Bogden - and Quigley, 1972. - .. [4] Weisstein, Eric W. "Binomial Distribution." 
From MathWorld--A - Wolfram Web Resource. - http://mathworld.wolfram.com/BinomialDistribution.html - .. [5] Wikipedia, "Binomial distribution", - https://en.wikipedia.org/wiki/Binomial_distribution - - Examples:: - - Draw samples from the distribution: - - >>> n, p = 10, .5 # number of trials, probability of each trial - >>> s = bm.random.binomial(n, p, 1000) - # result of flipping a coin 10 times, tested 1000 times. - - A real world example. A company drills 9 wild-cat oil exploration - wells, each with an estimated probability of success of 0.1. All nine - wells fail. What is the probability of that happening? - - Let's do 20,000 trials of the model, and count the number that - generate zero positive results. - - >>> sum(bm.random.binomial(9, 0.1, 20000) == 0)/20000. - # answer = 0.38885, or 38%. - """ - return DEFAULT.binomial(n, p, size, key=key) + dfnum : float or array_like of floats + Numerator degrees of freedom, must be > 0. + dfden : float or array_like of floats + Denominator degrees of freedom, must be > 0. + nonc : float or array_like of floats + Non-centrality parameter, the sum of the squares of the numerator + means, must be >= 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``dfnum``, ``dfden``, and ``nonc`` + are all scalars. Otherwise, ``np.broadcast(dfnum, dfden, nonc).size`` + samples are drawn. + Returns:: -def chisquare(df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a chi-square distribution. + out : ndarray or scalar + Drawn samples from the parameterized noncentral Fisher distribution. + + Notes:: + + When calculating the power of an experiment (power = probability of + rejecting the null hypothesis when a specific alternative is true) the + non-central F statistic becomes important. 
When the null hypothesis is + true, the F statistic follows a central F distribution. When the null + hypothesis is not true, then it follows a non-central F statistic. + + References:: + + .. [1] Weisstein, Eric W. "Noncentral F-Distribution." + From MathWorld--A Wolfram Web Resource. + http://mathworld.wolfram.com/NoncentralF-Distribution.html + .. [2] Wikipedia, "Noncentral F-distribution", + https://en.wikipedia.org/wiki/Noncentral_F-distribution + + Examples:: + + In a study, testing for a specific alternative to the null hypothesis + requires use of the Noncentral F distribution. We need to calculate the + area in the tail of the distribution that exceeds the value of the F + distribution for the null hypothesis. We'll plot the two probability + distributions for comparison. + + >>> dfnum = 3 # between group deg of freedom + >>> dfden = 20 # within groups degrees of freedom + >>> nonc = 3.0 + >>> nc_vals = bm.random.noncentral_f(dfnum, dfden, nonc, 1000000) + >>> NF = np.histogram(nc_vals, bins=50, density=True) + >>> c_vals = bm.random.f(dfnum, dfden, 1000000) + >>> F = np.histogram(c_vals, bins=50, density=True) + >>> import matplotlib.pyplot as plt + >>> plt.plot(F[1][1:], F[0]) + >>> plt.plot(NF[1][1:], NF[0]) + >>> plt.show() + """ + return DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key) - When `df` independent random variables, each with standard normal - distributions (mean 0, variance 1), are squared and summed, the - resulting distribution is chi-square (see Notes). This distribution - is often used in hypothesis testing. - Parameters:: - - df : float or array_like of floats - Number of degrees of freedom, must be > 0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``df`` is a scalar. Otherwise, - ``np.array(df).size`` samples are drawn. 
+def power(a, + size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draws samples in [0, 1] from a power distribution with positive + exponent a - 1. - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized chi-square distribution. + Also known as the power function distribution. - Raises:: - - ValueError - When `df` <= 0 or when an inappropriate `size` (e.g. ``size=-1``) - is given. + Parameters:: - Notes:: - - The variable obtained by summing the squares of `df` independent, - standard normally distributed random variables: + a : float or array_like of floats + Parameter of the distribution. Must be non-negative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. - .. math:: Q = \sum_{i=0}^{\mathtt{df}} X^2_i + Returns:: - is chi-square distributed, denoted + out : ndarray or scalar + Drawn samples from the parameterized power distribution. - .. math:: Q \sim \chi^2_k. + Raises:: - The probability density function of the chi-squared distribution is + ValueError + If a <= 0. - .. math:: p(x) = \frac{(1/2)^{k/2}}{\Gamma(k/2)} - x^{k/2 - 1} e^{-x/2}, + Notes:: - where :math:`\Gamma` is the gamma function, + The probability density function is - .. math:: \Gamma(x) = \int_0^{-\infty} t^{x - 1} e^{-t} dt. + .. math:: P(x; a) = ax^{a-1}, 0 \le x \le 1, a>0. - References:: - - .. [1] NIST "Engineering Statistics Handbook" - https://www.itl.nist.gov/div898/handbook/eda/section3/eda3666.htm + The power function distribution is just the inverse of the Pareto + distribution. It may also be seen as a special case of the Beta + distribution. 
- Examples:: - - >>> bm.random.chisquare(2,4) - array([ 1.89920014, 9.00867716, 3.13710533, 5.62318272]) # random - """ - return DEFAULT.chisquare(df, size, key=key) + It is used, for example, in modeling the over-reporting of insurance + claims. + References:: -def dirichlet(alpha, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from the Dirichlet distribution. - - Draw `size` samples of dimension k from a Dirichlet distribution. A - Dirichlet-distributed random variable can be seen as a multivariate - generalization of a Beta distribution. The Dirichlet distribution - is a conjugate prior of a multinomial distribution in Bayesian - inference. - - Parameters:: - - alpha : sequence of floats, length k - Parameter of the distribution (length ``k`` for sample of - length ``k``). - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - vector of length ``k`` is returned. - - Returns:: - - samples : ndarray, - The drawn samples, of shape ``(size, k)``. - - Raises:: - - ValueError - If any value in ``alpha`` is less than or equal to zero - - Notes:: - - The Dirichlet distribution is a distribution over vectors - :math:`x` that fulfil the conditions :math:`x_i>0` and - :math:`\sum_{i=1}^k x_i = 1`. - - The probability density function :math:`p` of a - Dirichlet-distributed random vector :math:`X` is - proportional to - - .. math:: p(x) \propto \prod_{i=1}^{k}{x^{\alpha_i-1}_i}, - - where :math:`\alpha` is a vector containing the positive - concentration parameters. - - The method uses the following property for computation: let :math:`Y` - be a random vector which has components that follow a standard gamma - distribution, then :math:`X = \frac{1}{\sum_{i=1}^k{Y_i}} Y` - is Dirichlet-distributed - - References:: - - .. 
[1] David McKay, "Information Theory, Inference and Learning - Algorithms," chapter 23, - http://www.inference.org.uk/mackay/itila/ - .. [2] Wikipedia, "Dirichlet distribution", - https://en.wikipedia.org/wiki/Dirichlet_distribution - - Examples:: - - Taking an example cited in Wikipedia, this distribution can be used if - one wanted to cut strings (each of initial length 1.0) into K pieces - with different lengths, where each piece had, on average, a designated - average length, but allowing some variation in the relative sizes of - the pieces. - - >>> s = bm.random.dirichlet((10, 5, 3), 20).transpose() - - >>> import matplotlib.pyplot as plt - >>> plt.barh(range(20), s[0]) - >>> plt.barh(range(20), s[1], left=s[0], color='g') - >>> plt.barh(range(20), s[2], left=s[0]+s[1], color='r') - >>> plt.title("Lengths of Strings") - """ - return DEFAULT.dirichlet(alpha, size, key=key) + .. [1] Christian Kleiber, Samuel Kotz, "Statistical size distributions + in economics and actuarial sciences", Wiley, 2003. + .. [2] Heckert, N. A. and Filliben, James J. "NIST Handbook 148: + Dataplot Reference Manual, Volume 2: Let Subcommands and Library + Functions", National Institute of Standards and Technology + Handbook Series, June 2003. + https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/powpdf.pdf + Examples:: -def geometric(p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from the geometric distribution. + Draw samples from the distribution: - Bernoulli trials are experiments with one of two outcomes: - success or failure (an example of such an experiment is flipping - a coin). The geometric distribution models the number of trials - that must be run in order to achieve success. It is therefore - supported on the positive integers, ``k = 1, 2, ...``. + >>> a = 5. 
# shape + >>> samples = 1000 + >>> s = bm.random.power(a, samples) - The probability mass function of the geometric distribution is + Display the histogram of the samples, along with + the probability density function: - .. math:: f(k) = (1 - p)^{k - 1} p + >>> import matplotlib.pyplot as plt + >>> count, bins, ignored = plt.hist(s, bins=30) + >>> x = np.linspace(0, 1, 100) + >>> y = a*x**(a-1.) + >>> normed_y = samples*np.diff(bins)[0]*y + >>> plt.plot(x, normed_y) + >>> plt.show() - where `p` is the probability of success of an individual trial. + Compare the power function distribution to the inverse of the Pareto. - Parameters:: - - p : float or array_like of floats - The probability of success of an individual trial. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``p`` is a scalar. Otherwise, - ``np.array(p).size`` samples are drawn. + >>> from scipy import stats # doctest: +SKIP + >>> rvs = bm.random.power(5, 1000000) + >>> rvsp = bm.random.pareto(5, 1000000) + >>> xx = np.linspace(0,1,100) + >>> powpdf = stats.powerlaw.pdf(xx,5) # doctest: +SKIP - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized geometric distribution. 
+ >>> plt.figure() + >>> plt.hist(rvs, bins=50, density=True) + >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP + >>> plt.title('bm.random.power(5)') - Examples:: - - Draw ten thousand values from the geometric distribution, - with the probability of an individual success equal to 0.35: + >>> plt.figure() + >>> plt.hist(1./(1.+rvsp), bins=50, density=True) + >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP + >>> plt.title('inverse of 1 + bm.random.pareto(5)') - >>> z = bm.random.geometric(p=0.35, size=10000) + >>> plt.figure() + >>> plt.hist(1./(1.+rvsp), bins=50, density=True) + >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP + >>> plt.title('inverse of stats.pareto(5)') + """ + return DEFAULT.power(a, size, key=key) - How many trials succeeded after a single run? - >>> (z == 1).sum() / 10000. - 0.34889999999999999 #random - """ - return DEFAULT.geometric(p, size, key=key) +def rayleigh(scale=1.0, + size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from a Rayleigh distribution. + The :math:`\chi` and Weibull distributions are generalizations of the + Rayleigh. -def f(dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from an F distribution. - - Samples are drawn from an F distribution with specified parameters, - `dfnum` (degrees of freedom in numerator) and `dfden` (degrees of - freedom in denominator), where both parameters must be greater than - zero. - - The random variate of the F distribution (also known as the - Fisher distribution) is a continuous probability distribution - that arises in ANOVA tests, and is the ratio of two chi-square - variates. - - Parameters:: - - dfnum : float or array_like of floats - Degrees of freedom in numerator, must be > 0. - dfden : float or array_like of float - Degrees of freedom in denominator, must be > 0. - size : int or tuple of ints, optional - Output shape. 
If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``dfnum`` and ``dfden`` are both scalars. - Otherwise, ``np.broadcast(dfnum, dfden).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Fisher distribution. - - See Also:: - - scipy.stats.f : probability density function, distribution or - cumulative density function, etc. - - Notes:: - - The F statistic is used to compare in-group variances to between-group - variances. Calculating the distribution depends on the sampling, and - so it is a function of the respective degrees of freedom in the - problem. The variable `dfnum` is the number of samples minus one, the - between-groups degrees of freedom, while `dfden` is the within-groups - degrees of freedom, the sum of the number of samples in each group - minus the number of groups. - - References:: - - .. [1] Glantz, Stanton A. "Primer of Biostatistics.", McGraw-Hill, - Fifth Edition, 2002. - .. [2] Wikipedia, "F-distribution", - https://en.wikipedia.org/wiki/F-distribution - - Examples:: - - An example from Glantz[1], pp 47-40: - - Two groups, children of diabetics (25 people) and children from people - without diabetes (25 controls). Fasting blood glucose was measured, - case group had a mean value of 86.1, controls had a mean value of - 82.2. Standard deviations were 2.09 and 2.49 respectively. Are these - data consistent with the null hypothesis that the parents diabetic - status does not affect their children's blood glucose levels? - Calculating the F statistic from the data gives a value of 36.01. - - Draw samples from the distribution: - - >>> dfnum = 1. # between group degrees of freedom - >>> dfden = 48. 
# within groups degrees of freedom - >>> s = bm.random.f(dfnum, dfden, 1000) - - The lower bound for the top 1% of the samples is : - - >>> np.sort(s)[-10] - 7.61988120985 # random - - So there is about a 1% chance that the F statistic will exceed 7.62, - the measured value is 36, so the null hypothesis is rejected at the 1% - level. - """ - return DEFAULT.f(dfnum, dfden, size, key=key) + Parameters:: + scale : float or array_like of floats, optional + Scale, also equals the mode. Must be non-negative. Default is 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``scale`` is a scalar. Otherwise, + ``np.array(scale).size`` samples are drawn. -def hypergeometric(ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Hypergeometric distribution. - - Samples are drawn from a hypergeometric distribution with specified - parameters, `ngood` (ways to make a good selection), `nbad` (ways to make - a bad selection), and `nsample` (number of items sampled, which is less - than or equal to the sum ``ngood + nbad``). - - Parameters:: - - ngood : int or array_like of ints - Number of ways to make a good selection. Must be nonnegative. - nbad : int or array_like of ints - Number of ways to make a bad selection. Must be nonnegative. - nsample : int or array_like of ints - Number of items sampled. Must be at least 1 and at most - ``ngood + nbad``. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if `ngood`, `nbad`, and `nsample` - are all scalars. Otherwise, ``np.broadcast(ngood, nbad, nsample).size`` - samples are drawn. 
- - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized hypergeometric distribution. Each - sample is the number of good items within a randomly selected subset of - size `nsample` taken from a set of `ngood` good items and `nbad` bad items. - - See Also:: - - scipy.stats.hypergeom : probability density function, distribution or - cumulative density function, etc. - - Notes:: - - The probability density for the Hypergeometric distribution is - - .. math:: P(x) = \frac{\binom{g}{x}\binom{b}{n-x}}{\binom{g+b}{n}}, - - where :math:`0 \le x \le n` and :math:`n-b \le x \le g` - - for P(x) the probability of ``x`` good results in the drawn sample, - g = `ngood`, b = `nbad`, and n = `nsample`. - - Consider an urn with black and white marbles in it, `ngood` of them - are black and `nbad` are white. If you draw `nsample` balls without - replacement, then the hypergeometric distribution describes the - distribution of black balls in the drawn sample. - - Note that this distribution is very similar to the binomial - distribution, except that in this case, samples are drawn without - replacement, whereas in the Binomial case samples are drawn with - replacement (or the sample space is infinite). As the sample space - becomes large, this distribution approaches the binomial. - - References:: - - .. [1] Lentner, Marvin, "Elementary Applied Statistics", Bogden - and Quigley, 1972. - .. [2] Weisstein, Eric W. "Hypergeometric Distribution." From - MathWorld--A Wolfram Web Resource. - http://mathworld.wolfram.com/HypergeometricDistribution.html - .. 
[3] Wikipedia, "Hypergeometric distribution", - https://en.wikipedia.org/wiki/Hypergeometric_distribution - - Examples:: - - Draw samples from the distribution: - - >>> ngood, nbad, nsamp = 100, 2, 10 - # number of good, number of bad, and number of samples - >>> s = bm.random.hypergeometric(ngood, nbad, nsamp, 1000) - >>> from matplotlib.pyplot import hist - >>> hist(s) - # note that it is very unlikely to grab both bad items - - Suppose you have an urn with 15 white and 15 black marbles. - If you pull 15 marbles at random, how likely is it that - 12 or more of them are one color? - - >>> s = bm.random.hypergeometric(15, 15, 15, 100000) - >>> sum(s>=12)/100000. + sum(s<=3)/100000. - # answer = 0.003 ... pretty unlikely! - """ - return DEFAULT.hypergeometric(ngood, nbad, nsample, size, key=key) + Returns:: + out : ndarray or scalar + Drawn samples from the parameterized Rayleigh distribution. -def logseries(p, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a logarithmic series distribution. - - Samples are drawn from a log series distribution with specified - shape parameter, 0 <= ``p`` < 1. - - Parameters:: - - p : float or array_like of floats - Shape parameter for the distribution. Must be in the range [0, 1). - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``p`` is a scalar. Otherwise, - ``np.array(p).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized logarithmic series distribution. - - See Also:: - - scipy.stats.logser : probability density function, distribution or - cumulative density function, etc. - - Notes:: - - The probability density for the Log Series distribution is - - .. math:: P(k) = \frac{-p^k}{k \ln(1-p)}, - - where p = probability. 
- - The log series distribution is frequently used to represent species - richness and occurrence, first proposed by Fisher, Corbet, and - Williams in 1943 [2]. It may also be used to model the numbers of - occupants seen in cars [3]. - - References:: - - .. [1] Buzas, Martin A.; Culver, Stephen J., Understanding regional - species diversity through the log series distribution of - occurrences: BIODIVERSITY RESEARCH Diversity & Distributions, - Volume 5, Number 5, September 1999 , pp. 187-195(9). - .. [2] Fisher, R.A,, A.S. Corbet, and C.B. Williams. 1943. The - relation between the number of species and the number of - individuals in a random sample of an animal population. - Journal of Animal Ecology, 12:42-58. - .. [3] D. J. Hand, F. Daly, D. Lunn, E. Ostrowski, A Handbook of Small - Data Sets, CRC Press, 1994. - .. [4] Wikipedia, "Logarithmic distribution", - https://en.wikipedia.org/wiki/Logarithmic_distribution - - Examples:: - - Draw samples from the distribution: - - >>> a = .6 - >>> s = bm.random.logseries(a, 10000) - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s) - - # plot against distribution - - >>> def logseries(k, p): - ... return -p**k/(k*np.log(1-p)) - >>> plt.plot(bins, logseries(bins, a)*count.max()/ - ... logseries(bins, a).max(), 'r') - >>> plt.show() - """ - return DEFAULT.logseries(p, size, key=key) + Notes:: + The probability density function for the Rayleigh distribution is -def multinomial(n, pvals, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a multinomial distribution. + .. math:: P(x;scale) = \frac{x}{scale^2}e^{\frac{-x^2}{2 \cdotp scale^2}} - The multinomial distribution is a multivariate generalization of the - binomial distribution. Take an experiment with one of ``p`` - possible outcomes. An example of such an experiment is throwing a dice, - where the outcome can be 1 through 6. 
Each sample drawn from the - distribution represents `n` such experiments. Its values, - ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the - outcome was ``i``. + The Rayleigh distribution would arise, for example, if the East + and North components of the wind velocity had identical zero-mean + Gaussian distributions. Then the wind speed would have a Rayleigh + distribution. - Parameters:: - - n : int - Number of experiments. - pvals : sequence of floats, length p - Probabilities of each of the ``p`` different outcomes. These - must sum to 1 (however, the last element is always assumed to - account for the remaining probability, as long as - ``sum(pvals[:-1]) <= 1)``. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. Default is None, in which case a - single value is returned. + References:: - Returns:: - - out : ndarray - The drawn samples, of shape *size*, if that was provided. If not, - the shape is ``(N,)``. + .. [1] Brighton Webs Ltd., "Rayleigh Distribution," + https://web.archive.org/web/20090514091424/http://brighton-webs.co.uk:80/distributions/rayleigh.asp + .. [2] Wikipedia, "Rayleigh distribution" + https://en.wikipedia.org/wiki/Rayleigh_distribution - In other words, each entry ``out[i,j,...,:]`` is an N-dimensional - value drawn from the distribution. + Examples:: - Examples:: - - Throw a dice 20 times: + Draw values from the distribution and plot the histogram - >>> bm.random.multinomial(20, [1/6.]*6, size=1) - array([[4, 1, 7, 5, 2, 1]]) # random + >>> from matplotlib.pyplot import hist + >>> values = hist(bm.random.rayleigh(3, 100000), bins=200, density=True) - It landed 4 times on 1, once on 2, etc. + Wave heights tend to follow a Rayleigh distribution. If the mean wave + height is 1 meter, what fraction of waves are likely to be larger than 3 + meters? 
- Now, throw the dice 20 times, and 20 times again: + >>> meanvalue = 1 + >>> modevalue = np.sqrt(2 / np.pi) * meanvalue + >>> s = bm.random.rayleigh(modevalue, 1000000) - >>> bm.random.multinomial(20, [1/6.]*6, size=2) - array([[3, 4, 3, 3, 4, 3], # random - [2, 4, 3, 4, 0, 7]]) + The percentage of waves larger than 3 meters is: - For the first run, we threw 3 times 1, 4 times 2, etc. For the second, - we threw 2 times 1, 4 times 2, etc. + >>> 100.*sum(s>3)/1000000. + 0.087300000000000003 # random + """ + return DEFAULT.rayleigh(scale, size, key=key) - A loaded die is more likely to land on number 6: - >>> bm.random.multinomial(100, [1/7.]*5 + [2/7.]) - array([11, 16, 14, 17, 16, 26]) # random +def triangular(size: Optional[Union[int, Sequence[int]]] = None, + key: Optional[Union[int, JAX_RAND_KEY]] = None): + r""" + Draw samples from the triangular distribution over the + interval ``[left, right]``. - The probability inputs should be normalized. As an implementation - detail, the value of the last entry is ignored and assumed to take - up any leftover probability mass, but this should not be relied on. - A biased coin which has twice as much weight on one side as on the - other should be sampled like so: + The triangular distribution is a continuous probability + distribution with lower limit left, peak at mode, and upper + limit right. Unlike the other distributions, these parameters + directly define the shape of the pdf. - >>> bm.random.multinomial(100, [1.0 / 3, 2.0 / 3]) # RIGHT - array([38, 62]) # random + Parameters:: - not like: + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``left``, ``mode``, and ``right`` + are all scalars. Otherwise, ``np.broadcast(left, mode, right).size`` + samples are drawn. 
- >>> bm.random.multinomial(100, [1.0, 2.0]) # WRONG - Traceback (most recent call last): - ValueError: pvals < 0, pvals > 1 or pvals contains NaNs - """ - return DEFAULT.multinomial(n, pvals, size, key=key) + Returns:: + out : ndarray or scalar + Drawn samples from the parameterized triangular distribution. -def multivariate_normal(mean, cov, size: Optional[Union[int, Sequence[int]]] = None, method: str = 'cholesky', - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw random samples from a multivariate normal distribution. - - The multivariate normal, multinormal or Gaussian distribution is a - generalization of the one-dimensional normal distribution to higher - dimensions. Such a distribution is specified by its mean and - covariance matrix. These parameters are analogous to the mean - (average or "center") and variance (standard deviation, or "width," - squared) of the one-dimensional normal distribution. - - Parameters:: - - mean : 1-D array_like, of length N - Mean of the N-dimensional distribution. - cov : 2-D array_like, of shape (N, N) - Covariance matrix of the distribution. It must be symmetric and - positive-semidefinite for proper sampling. - size : int or tuple of ints, optional - Given a shape of, for example, ``(m,n,k)``, ``m*n*k`` samples are - generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because - each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``. - If no shape is specified, a single (`N`-D) sample is returned. - check_valid : { 'warn', 'raise', 'ignore' }, optional - Behavior when the covariance matrix is not positive semidefinite. - tol : float, optional - Tolerance when checking the singular values in covariance matrix. - cov is cast to double before the check. - - Returns:: - - out : ndarray - The drawn samples, of shape *size*, if that was provided. If not, - the shape is ``(N,)``. - - In other words, each entry ``out[i,j,...,:]`` is an N-dimensional - value drawn from the distribution. 
- - Notes:: - - The mean is a coordinate in N-dimensional space, which represents the - location where samples are most likely to be generated. This is - analogous to the peak of the bell curve for the one-dimensional or - univariate normal distribution. - - Covariance indicates the level to which two variables vary together. - From the multivariate normal distribution, we draw N-dimensional - samples, :math:`X = [x_1, x_2, ... x_N]`. The covariance matrix - element :math:`C_{ij}` is the covariance of :math:`x_i` and :math:`x_j`. - The element :math:`C_{ii}` is the variance of :math:`x_i` (i.e. its - "spread"). - - Instead of specifying the full covariance matrix, popular - approximations include: - - - Spherical covariance (`cov` is a multiple of the identity matrix) - - Diagonal covariance (`cov` has non-negative elements, and only on - the diagonal) - - This geometrical property can be seen in two dimensions by plotting - generated data-points: - - >>> mean = [0, 0] - >>> cov = [[1, 0], [0, 100]] # diagonal covariance - - Diagonal covariance means that points are oriented along x or y-axis: - - >>> import matplotlib.pyplot as plt - >>> x, y = bm.random.multivariate_normal(mean, cov, 5000).T - >>> plt.plot(x, y, 'x') - >>> plt.axis('equal') - >>> plt.show() - - Note that the covariance matrix must be positive semidefinite (a.k.a. - nonnegative-definite). Otherwise, the behavior of this method is - undefined and backwards compatibility is not guaranteed. - - References:: - - .. [1] Papoulis, A., "Probability, Random Variables, and Stochastic - Processes," 3rd ed., New York: McGraw-Hill, 1991. - .. [2] Duda, R. O., Hart, P. E., and Stork, D. G., "Pattern - Classification," 2nd ed., New York: Wiley, 2001. 
- - Examples:: - - >>> mean = (1, 2) - >>> cov = [[1, 0], [0, 1]] - >>> x = bm.random.multivariate_normal(mean, cov, (3, 3)) - >>> x.shape - (3, 3, 2) - - Here we generate 800 samples from the bivariate normal distribution - with mean [0, 0] and covariance matrix [[6, -3], [-3, 3.5]]. The - expected variances of the first and second components of the sample - are 6 and 3.5, respectively, and the expected correlation - coefficient is -3/sqrt(6*3.5) ≈ -0.65465. - - >>> cov = np.array([[6, -3], [-3, 3.5]]) - >>> pts = bm.random.multivariate_normal([0, 0], cov, size=800) - - Check that the mean, covariance, and correlation coefficient of the - sample are close to the expected values: - - >>> pts.mean(axis=0) - array([ 0.0326911 , -0.01280782]) # may vary - >>> np.cov(pts.T) - array([[ 5.96202397, -2.85602287], - [-2.85602287, 3.47613949]]) # may vary - >>> np.corrcoef(pts.T)[0, 1] - -0.6273591314603949 # may vary - - We can visualize this data with a scatter plot. The orientation - of the point cloud illustrates the negative correlation of the - components of this sample. - - >>> import matplotlib.pyplot as plt - >>> plt.plot(pts[:, 0], pts[:, 1], '.', alpha=0.5) - >>> plt.axis('equal') - >>> plt.grid() - >>> plt.show() - """ - return DEFAULT.multivariate_normal(mean, cov, size, method, key=key) + Notes:: + The probability density function for the triangular distribution is -def negative_binomial(n, p, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a negative binomial distribution. - - Samples are drawn from a negative binomial distribution with specified - parameters, `n` successes and `p` probability of success where `n` - is > 0 and `p` is in the interval [0, 1]. - - Parameters:: - - n : float or array_like of floats - Parameter of the distribution, > 0. - p : float or array_like of floats - Parameter of the distribution, >= 0 and <=1. 
- size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``n`` and ``p`` are both scalars. - Otherwise, ``np.broadcast(n, p).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized negative binomial distribution, - where each sample is equal to N, the number of failures that - occurred before a total of n successes was reached. - - Notes:: - - The probability mass function of the negative binomial distribution is - - .. math:: P(N;n,p) = \frac{\Gamma(N+n)}{N!\Gamma(n)}p^{n}(1-p)^{N}, - - where :math:`n` is the number of successes, :math:`p` is the - probability of success, :math:`N+n` is the number of trials, and - :math:`\Gamma` is the gamma function. When :math:`n` is an integer, - :math:`\frac{\Gamma(N+n)}{N!\Gamma(n)} = \binom{N+n-1}{N}`, which is - the more common form of this term in the pmf. The negative - binomial distribution gives the probability of N failures given n - successes, with a success on the last trial. - - If one throws a die repeatedly until the third time a "1" appears, - then the probability distribution of the number of non-"1"s that - appear before the third "1" is a negative binomial distribution. - - References:: - - .. [1] Weisstein, Eric W. "Negative Binomial Distribution." From - MathWorld--A Wolfram Web Resource. - http://mathworld.wolfram.com/NegativeBinomialDistribution.html - .. [2] Wikipedia, "Negative binomial distribution", - https://en.wikipedia.org/wiki/Negative_binomial_distribution - - Examples:: - - Draw samples from the distribution: - - A real world example. A company drills wild-cat oil - exploration wells, each with an estimated probability of - success of 0.1. 
What is the probability of having one success - for each successive well, that is what is the probability of a - single success after drilling 5 wells, after 6 wells, etc.? - - >>> s = bm.random.negative_binomial(1, 0.1, 100000) - >>> for i in range(1, 11): # doctest: +SKIP - ... probability = sum(s 0. - nonc : float or array_like of floats - Non-centrality, must be non-negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``df`` and ``nonc`` are both scalars. - Otherwise, ``np.broadcast(df, nonc).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized noncentral chi-square distribution. - - Notes:: - - The probability density function for the noncentral Chi-square - distribution is - - .. math:: P(x;df,nonc) = \sum^{\infty}_{i=0} - \frac{e^{-nonc/2}(nonc/2)^{i}}{i!} - P_{Y_{df+2i}}(x), - - where :math:`Y_{q}` is the Chi-square with q degrees of freedom. - - References:: - - .. [1] Wikipedia, "Noncentral chi-squared distribution" - https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution - - Examples:: - - Draw values from the distribution and plot the histogram - - >>> import matplotlib.pyplot as plt - >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000), - ... bins=200, density=True) - >>> plt.show() - - Draw values from a noncentral chisquare with very small noncentrality, - and compare to a chisquare. - - >>> plt.figure() - >>> values = plt.hist(bm.random.noncentral_chisquare(3, .0000001, 100000), - ... bins=np.arange(0., 25, .1), density=True) - >>> values2 = plt.hist(bm.random.chisquare(3, 100000), - ... bins=np.arange(0., 25, .1), density=True) - >>> plt.plot(values[1][0:-1], values[0]-values2[0], 'ob') - >>> plt.show() - - Demonstrate how large values of non-centrality lead to a more symmetric - distribution. 
- - >>> plt.figure() - >>> values = plt.hist(bm.random.noncentral_chisquare(3, 20, 100000), - ... bins=200, density=True) - >>> plt.show() - """ - return DEFAULT.noncentral_chisquare(df, nonc, size, key=key) + References:: + .. [1] Wikipedia, "Triangular distribution" + https://en.wikipedia.org/wiki/Triangular_distribution -def noncentral_f(dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from the noncentral F distribution. - - Samples are drawn from an F distribution with specified parameters, - `dfnum` (degrees of freedom in numerator) and `dfden` (degrees of - freedom in denominator), where both parameters > 1. - `nonc` is the non-centrality parameter. - - Parameters:: - - dfnum : float or array_like of floats - Numerator degrees of freedom, must be > 0. - dfden : float or array_like of floats - Denominator degrees of freedom, must be > 0. - nonc : float or array_like of floats - Non-centrality parameter, the sum of the squares of the numerator - means, must be >= 0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``dfnum``, ``dfden``, and ``nonc`` - are all scalars. Otherwise, ``np.broadcast(dfnum, dfden, nonc).size`` - samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized noncentral Fisher distribution. - - Notes:: - - When calculating the power of an experiment (power = probability of - rejecting the null hypothesis when a specific alternative is true) the - non-central F statistic becomes important. When the null hypothesis is - true, the F statistic follows a central F distribution. When the null - hypothesis is not true, then it follows a non-central F statistic. - - References:: - - .. [1] Weisstein, Eric W. "Noncentral F-Distribution." 
- From MathWorld--A Wolfram Web Resource. - http://mathworld.wolfram.com/NoncentralF-Distribution.html - .. [2] Wikipedia, "Noncentral F-distribution", - https://en.wikipedia.org/wiki/Noncentral_F-distribution - - Examples:: - - In a study, testing for a specific alternative to the null hypothesis - requires use of the Noncentral F distribution. We need to calculate the - area in the tail of the distribution that exceeds the value of the F - distribution for the null hypothesis. We'll plot the two probability - distributions for comparison. - - >>> dfnum = 3 # between group deg of freedom - >>> dfden = 20 # within groups degrees of freedom - >>> nonc = 3.0 - >>> nc_vals = bm.random.noncentral_f(dfnum, dfden, nonc, 1000000) - >>> NF = np.histogram(nc_vals, bins=50, density=True) - >>> c_vals = bm.random.f(dfnum, dfden, 1000000) - >>> F = np.histogram(c_vals, bins=50, density=True) - >>> import matplotlib.pyplot as plt - >>> plt.plot(F[1][1:], F[0]) - >>> plt.plot(NF[1][1:], NF[0]) - >>> plt.show() - """ - return DEFAULT.noncentral_f(dfnum, dfden, nonc, size, key=key) + Examples:: + Draw values from the distribution and plot the histogram: -def power(a, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draws samples in [0, 1] from a power distribution with positive - exponent a - 1. - - Also known as the power function distribution. - - Parameters:: - - a : float or array_like of floats - Parameter of the distribution. Must be non-negative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``a`` is a scalar. Otherwise, - ``np.array(a).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized power distribution. - - Raises:: - - ValueError - If a <= 0. 
- - Notes:: - - The probability density function is - - .. math:: P(x; a) = ax^{a-1}, 0 \le x \le 1, a>0. - - The power function distribution is just the inverse of the Pareto - distribution. It may also be seen as a special case of the Beta - distribution. - - It is used, for example, in modeling the over-reporting of insurance - claims. - - References:: - - .. [1] Christian Kleiber, Samuel Kotz, "Statistical size distributions - in economics and actuarial sciences", Wiley, 2003. - .. [2] Heckert, N. A. and Filliben, James J. "NIST Handbook 148: - Dataplot Reference Manual, Volume 2: Let Subcommands and Library - Functions", National Institute of Standards and Technology - Handbook Series, June 2003. - https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/powpdf.pdf - - Examples:: - - Draw samples from the distribution: - - >>> a = 5. # shape - >>> samples = 1000 - >>> s = bm.random.power(a, samples) - - Display the histogram of the samples, along with - the probability density function: - - >>> import matplotlib.pyplot as plt - >>> count, bins, ignored = plt.hist(s, bins=30) - >>> x = np.linspace(0, 1, 100) - >>> y = a*x**(a-1.) - >>> normed_y = samples*np.diff(bins)[0]*y - >>> plt.plot(x, normed_y) - >>> plt.show() - - Compare the power function distribution to the inverse of the Pareto. 
- - >>> from scipy import stats # doctest: +SKIP - >>> rvs = bm.random.power(5, 1000000) - >>> rvsp = bm.random.pareto(5, 1000000) - >>> xx = np.linspace(0,1,100) - >>> powpdf = stats.powerlaw.pdf(xx,5) # doctest: +SKIP - - >>> plt.figure() - >>> plt.hist(rvs, bins=50, density=True) - >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP - >>> plt.title('bm.random.power(5)') - - >>> plt.figure() - >>> plt.hist(1./(1.+rvsp), bins=50, density=True) - >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP - >>> plt.title('inverse of 1 + bm.random.pareto(5)') - - >>> plt.figure() - >>> plt.hist(1./(1.+rvsp), bins=50, density=True) - >>> plt.plot(xx,powpdf,'r-') # doctest: +SKIP - >>> plt.title('inverse of stats.pareto(5)') - """ - return DEFAULT.power(a, size, key=key) + >>> import matplotlib.pyplot as plt + >>> h = plt.hist(bm.random.triangular(-3, 0, 8, 100000), bins=200, + ... density=True) + >>> plt.show() + """ + return DEFAULT.triangular(size, key=key) -def rayleigh(scale=1.0, +def vonmises(mu, + kappa, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Rayleigh distribution. + r""" + Draw samples from a von Mises distribution. - The :math:`\chi` and Weibull distributions are generalizations of the - Rayleigh. + Samples are drawn from a von Mises distribution with specified mode + (mu) and dispersion (kappa), on the interval [-pi, pi]. - Parameters:: - - scale : float or array_like of floats, optional - Scale, also equals the mode. Must be non-negative. Default is 1. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``scale`` is a scalar. Otherwise, - ``np.array(scale).size`` samples are drawn. + The von Mises distribution (also known as the circular normal + distribution) is a continuous probability distribution on the unit + circle. 
It may be thought of as the circular analogue of the normal + distribution. - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Rayleigh distribution. + Parameters:: - Notes:: - - The probability density function for the Rayleigh distribution is + mu : float or array_like of floats + Mode ("center") of the distribution. + kappa : float or array_like of floats + Dispersion of the distribution, has to be >=0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``mu`` and ``kappa`` are both scalars. + Otherwise, ``np.broadcast(mu, kappa).size`` samples are drawn. - .. math:: P(x;scale) = \frac{x}{scale^2}e^{\frac{-x^2}{2 \cdotp scale^2}} + Returns:: - The Rayleigh distribution would arise, for example, if the East - and North components of the wind velocity had identical zero-mean - Gaussian distributions. Then the wind speed would have a Rayleigh - distribution. + out : ndarray or scalar + Drawn samples from the parameterized von Mises distribution. - References:: - - .. [1] Brighton Webs Ltd., "Rayleigh Distribution," - https://web.archive.org/web/20090514091424/http://brighton-webs.co.uk:80/distributions/rayleigh.asp - .. [2] Wikipedia, "Rayleigh distribution" - https://en.wikipedia.org/wiki/Rayleigh_distribution + See Also:: - Examples:: - - Draw values from the distribution and plot the histogram + scipy.stats.vonmises : probability density function, distribution, or + cumulative density function, etc. - >>> from matplotlib.pyplot import hist - >>> values = hist(bm.random.rayleigh(3, 100000), bins=200, density=True) + Notes:: - Wave heights tend to follow a Rayleigh distribution. If the mean wave - height is 1 meter, what fraction of waves are likely to be larger than 3 - meters? 
+ The probability density for the von Mises distribution is - >>> meanvalue = 1 - >>> modevalue = np.sqrt(2 / np.pi) * meanvalue - >>> s = bm.random.rayleigh(modevalue, 1000000) + .. math:: p(x) = \frac{e^{\kappa cos(x-\mu)}}{2\pi I_0(\kappa)}, - The percentage of waves larger than 3 meters is: + where :math:`\mu` is the mode and :math:`\kappa` the dispersion, + and :math:`I_0(\kappa)` is the modified Bessel function of order 0. - >>> 100.*sum(s>3)/1000000. - 0.087300000000000003 # random - """ - return DEFAULT.rayleigh(scale, size, key=key) + The von Mises is named for Richard Edler von Mises, who was born in + Austria-Hungary, in what is now the Ukraine. He fled to the United + States in 1939 and became a professor at Harvard. He worked in + probability theory, aerodynamics, fluid mechanics, and philosophy of + science. + References:: -def triangular(size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from the triangular distribution over the - interval ``[left, right]``. - - The triangular distribution is a continuous probability - distribution with lower limit left, peak at mode, and upper - limit right. Unlike the other distributions, these parameters - directly define the shape of the pdf. - - Parameters:: - - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``left``, ``mode``, and ``right`` - are all scalars. Otherwise, ``np.broadcast(left, mode, right).size`` - samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized triangular distribution. - - Notes:: - - The probability density function for the triangular distribution is - - .. 
math:: P(x;l, m, r) = \begin{cases} - \frac{2(x-l)}{(r-l)(m-l)}& \text{for $l \leq x \leq m$},\\ - \frac{2(r-x)}{(r-l)(r-m)}& \text{for $m \leq x \leq r$},\\ - 0& \text{otherwise}. - \end{cases} - - The triangular distribution is often used in ill-defined - problems where the underlying distribution is not known, but - some knowledge of the limits and mode exists. Often it is used - in simulations. - - References:: - - .. [1] Wikipedia, "Triangular distribution" - https://en.wikipedia.org/wiki/Triangular_distribution - - Examples:: - - Draw values from the distribution and plot the histogram: - - >>> import matplotlib.pyplot as plt - >>> h = plt.hist(bm.random.triangular(-3, 0, 8, 100000), bins=200, - ... density=True) - >>> plt.show() - """ - return DEFAULT.triangular(size, key=key) + .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of + Mathematical Functions with Formulas, Graphs, and Mathematical + Tables, 9th printing," New York: Dover, 1972. + .. [2] von Mises, R., "Mathematical Theory of Probability + and Statistics", New York: Academic Press, 1964. + Examples:: -def vonmises(mu, - kappa, - size: Optional[Union[int, Sequence[int]]] = None, - key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a von Mises distribution. - - Samples are drawn from a von Mises distribution with specified mode - (mu) and dispersion (kappa), on the interval [-pi, pi]. - - The von Mises distribution (also known as the circular normal - distribution) is a continuous probability distribution on the unit - circle. It may be thought of as the circular analogue of the normal - distribution. - - Parameters:: - - mu : float or array_like of floats - Mode ("center") of the distribution. - kappa : float or array_like of floats - Dispersion of the distribution, has to be >=0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. 
If size is ``None`` (default), - a single value is returned if ``mu`` and ``kappa`` are both scalars. - Otherwise, ``np.broadcast(mu, kappa).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized von Mises distribution. - - See Also:: - - scipy.stats.vonmises : probability density function, distribution, or - cumulative density function, etc. - - Notes:: - - The probability density for the von Mises distribution is - - .. math:: p(x) = \frac{e^{\kappa cos(x-\mu)}}{2\pi I_0(\kappa)}, - - where :math:`\mu` is the mode and :math:`\kappa` the dispersion, - and :math:`I_0(\kappa)` is the modified Bessel function of order 0. - - The von Mises is named for Richard Edler von Mises, who was born in - Austria-Hungary, in what is now the Ukraine. He fled to the United - States in 1939 and became a professor at Harvard. He worked in - probability theory, aerodynamics, fluid mechanics, and philosophy of - science. - - References:: - - .. [1] Abramowitz, M. and Stegun, I. A. (Eds.). "Handbook of - Mathematical Functions with Formulas, Graphs, and Mathematical - Tables, 9th printing," New York: Dover, 1972. - .. [2] von Mises, R., "Mathematical Theory of Probability - and Statistics", New York: Academic Press, 1964. 
- - Examples:: - - Draw samples from the distribution: - - >>> mu, kappa = 0.0, 4.0 # mean and dispersion - >>> s = bm.random.vonmises(mu, kappa, 1000) - - Display the histogram of the samples, along with - the probability density function: - - >>> import matplotlib.pyplot as plt - >>> from scipy.special import i0 # doctest: +SKIP - >>> plt.hist(s, 50, density=True) - >>> x = np.linspace(-np.pi, np.pi, num=51) - >>> y = np.exp(kappa*np.cos(x-mu))/(2*np.pi*i0(kappa)) # doctest: +SKIP - >>> plt.plot(x, y, linewidth=2, color='r') # doctest: +SKIP - >>> plt.show() - """ - return DEFAULT.vonmises(mu, kappa, size, key=key) + Draw samples from the distribution: + + >>> mu, kappa = 0.0, 4.0 # mean and dispersion + >>> s = bm.random.vonmises(mu, kappa, 1000) + + Display the histogram of the samples, along with + the probability density function: + + >>> import matplotlib.pyplot as plt + >>> from scipy.special import i0 # doctest: +SKIP + >>> plt.hist(s, 50, density=True) + >>> x = np.linspace(-np.pi, np.pi, num=51) + >>> y = np.exp(kappa*np.cos(x-mu))/(2*np.pi*i0(kappa)) # doctest: +SKIP + >>> plt.plot(x, y, linewidth=2, color='r') # doctest: +SKIP + >>> plt.show() + """ + return DEFAULT.vonmises(mu, kappa, size, key=key) def wald(mean, scale, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Wald, or inverse Gaussian, distribution. - - As the scale approaches infinity, the distribution becomes more like a - Gaussian. Some references claim that the Wald is an inverse Gaussian - with mean equal to 1, but this is by no means universal. - - The inverse Gaussian distribution was first studied in relationship to - Brownian motion. In 1956 M.C.K. Tweedie used the name inverse Gaussian - because there is an inverse relationship between the time to cover a - unit distance and distance covered in unit time. - - Parameters:: - - mean : float or array_like of floats - Distribution mean, must be > 0. 
- scale : float or array_like of floats - Scale parameter, must be > 0. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``mean`` and ``scale`` are both scalars. - Otherwise, ``np.broadcast(mean, scale).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Wald distribution. - - Notes:: - - The probability density function for the Wald distribution is - - .. math:: P(x;mean,scale) = \sqrt{\frac{scale}{2\pi x^3}}e^ - \frac{-scale(x-mean)^2}{2\cdotp mean^2x} - - As noted above the inverse Gaussian distribution first arise - from attempts to model Brownian motion. It is also a - competitor to the Weibull for use in reliability modeling and - modeling stock returns and interest rate processes. - - References:: - - .. [1] Brighton Webs Ltd., Wald Distribution, - https://web.archive.org/web/20090423014010/http://www.brighton-webs.co.uk:80/distributions/wald.asp - .. [2] Chhikara, Raj S., and Folks, J. Leroy, "The Inverse Gaussian - Distribution: Theory : Methodology, and Applications", CRC Press, - 1988. - .. [3] Wikipedia, "Inverse Gaussian distribution" - https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution - - Examples:: - - Draw values from the distribution and plot the histogram: - - >>> import matplotlib.pyplot as plt - >>> h = plt.hist(bm.random.wald(3, 2, 100000), bins=200, density=True) - >>> plt.show() - """ - return DEFAULT.wald(mean, scale, size, key=key) + r""" + Draw samples from a Wald, or inverse Gaussian, distribution. + + As the scale approaches infinity, the distribution becomes more like a + Gaussian. Some references claim that the Wald is an inverse Gaussian + with mean equal to 1, but this is by no means universal. + + The inverse Gaussian distribution was first studied in relationship to + Brownian motion. In 1956 M.C.K. 
Tweedie used the name inverse Gaussian + because there is an inverse relationship between the time to cover a + unit distance and distance covered in unit time. + + Parameters:: + + mean : float or array_like of floats + Distribution mean, must be > 0. + scale : float or array_like of floats + Scale parameter, must be > 0. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``mean`` and ``scale`` are both scalars. + Otherwise, ``np.broadcast(mean, scale).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized Wald distribution. + + Notes:: + + The probability density function for the Wald distribution is + + .. math:: P(x;mean,scale) = \sqrt{\frac{scale}{2\pi x^3}}e^ + \frac{-scale(x-mean)^2}{2\cdotp mean^2x} + + As noted above the inverse Gaussian distribution first arise + from attempts to model Brownian motion. It is also a + competitor to the Weibull for use in reliability modeling and + modeling stock returns and interest rate processes. + + References:: + + .. [1] Brighton Webs Ltd., Wald Distribution, + https://web.archive.org/web/20090423014010/http://www.brighton-webs.co.uk:80/distributions/wald.asp + .. [2] Chhikara, Raj S., and Folks, J. Leroy, "The Inverse Gaussian + Distribution: Theory : Methodology, and Applications", CRC Press, + 1988. + .. 
[3] Wikipedia, "Inverse Gaussian distribution" + https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution + + Examples:: + + Draw values from the distribution and plot the histogram: + + >>> import matplotlib.pyplot as plt + >>> h = plt.hist(bm.random.wald(3, 2, 100000), bins=200, density=True) + >>> plt.show() + """ + return DEFAULT.wald(mean, scale, size, key=key) def weibull(a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Weibull distribution. + r""" + Draw samples from a Weibull distribution. + + Draw samples from a 1-parameter Weibull distribution with the given + shape parameter `a`. + + .. math:: X = (-ln(U))^{1/a} + + Here, U is drawn from the uniform distribution over (0,1]. - Draw samples from a 1-parameter Weibull distribution with the given - shape parameter `a`. + The more common 2-parameter Weibull, including a scale parameter + :math:`\lambda` is just :math:`X = \lambda(-ln(U))^{1/a}`. - .. math:: X = (-ln(U))^{1/a} + .. note:: + New code should use the ``weibull`` method of a ``default_rng()`` + instance instead; please see the :ref:`random-quick-start`. - Here, U is drawn from the uniform distribution over (0,1]. + Parameters:: + + a : float or array_like of floats + Shape parameter of the distribution. Must be nonnegative. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns:: - The more common 2-parameter Weibull, including a scale parameter - :math:`\lambda` is just :math:`X = \lambda(-ln(U))^{1/a}`. + out : ndarray or scalar + Drawn samples from the parameterized Weibull distribution. - .. note:: - New code should use the ``weibull`` method of a ``default_rng()`` - instance instead; please see the :ref:`random-quick-start`. 
+ Notes:: - Parameters:: - - a : float or array_like of floats - Shape parameter of the distribution. Must be nonnegative. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``a`` is a scalar. Otherwise, - ``np.array(a).size`` samples are drawn. + The Weibull (or Type III asymptotic extreme value distribution + for smallest values, SEV Type III, or Rosin-Rammler + distribution) is one of a class of Generalized Extreme Value + (GEV) distributions used in modeling extreme value problems. + This class includes the Gumbel and Frechet distributions. - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Weibull distribution. + The probability density for the Weibull distribution is - Notes:: - - The Weibull (or Type III asymptotic extreme value distribution - for smallest values, SEV Type III, or Rosin-Rammler - distribution) is one of a class of Generalized Extreme Value - (GEV) distributions used in modeling extreme value problems. - This class includes the Gumbel and Frechet distributions. + .. math:: p(x) = \frac{a} + {\lambda}(\frac{x}{\lambda})^{a-1}e^{-(x/\lambda)^a}, - The probability density for the Weibull distribution is + where :math:`a` is the shape and :math:`\lambda` the scale. - .. math:: p(x) = \frac{a} - {\lambda}(\frac{x}{\lambda})^{a-1}e^{-(x/\lambda)^a}, + The function has its peak (the mode) at + :math:`\lambda(\frac{a-1}{a})^{1/a}`. - where :math:`a` is the shape and :math:`\lambda` the scale. + When ``a = 1``, the Weibull distribution reduces to the exponential + distribution. - The function has its peak (the mode) at - :math:`\lambda(\frac{a-1}{a})^{1/a}`. + References:: - When ``a = 1``, the Weibull distribution reduces to the exponential - distribution. + .. 
[1] Waloddi Weibull, Royal Technical University, Stockholm, + 1939 "A Statistical Theory Of The Strength Of Materials", + Ingeniorsvetenskapsakademiens Handlingar Nr 151, 1939, + Generalstabens Litografiska Anstalts Forlag, Stockholm. + .. [2] Waloddi Weibull, "A Statistical Distribution Function of + Wide Applicability", Journal Of Applied Mechanics ASME Paper + 1951. + .. [3] Wikipedia, "Weibull distribution", + https://en.wikipedia.org/wiki/Weibull_distribution - References:: - - .. [1] Waloddi Weibull, Royal Technical University, Stockholm, - 1939 "A Statistical Theory Of The Strength Of Materials", - Ingeniorsvetenskapsakademiens Handlingar Nr 151, 1939, - Generalstabens Litografiska Anstalts Forlag, Stockholm. - .. [2] Waloddi Weibull, "A Statistical Distribution Function of - Wide Applicability", Journal Of Applied Mechanics ASME Paper - 1951. - .. [3] Wikipedia, "Weibull distribution", - https://en.wikipedia.org/wiki/Weibull_distribution + Examples:: - Examples:: - - Draw samples from the distribution: + Draw samples from the distribution: - >>> a = 5. # shape - >>> s = brainpy.math.random.weibull(a, 1000) + >>> a = 5. # shape + >>> s = brainpy.math.random.weibull(a, 1000) - Display the histogram of the samples, along with - the probability density function: + Display the histogram of the samples, along with + the probability density function: - >>> import matplotlib.pyplot as plt - >>> x = np.arange(1,100.)/50. - >>> def weib(x,n,a): - ... return (a / n) * (x / n)**(a - 1) * np.exp(-(x / n)**a) + >>> import matplotlib.pyplot as plt + >>> x = np.arange(1,100.)/50. + >>> def weib(x,n,a): + ... return (a / n) * (x / n)**(a - 1) * np.exp(-(x / n)**a) - >>> count, bins, ignored = plt.hist(brainpy.math.random.weibull(5.,1000)) - >>> x = np.arange(1,100.)/50. 
- >>> scale = count.max()/weib(x, 1., 5.).max() - >>> plt.plot(x, weib(x, 1., 5.)*scale) - >>> plt.show() + >>> count, bins, ignored = plt.hist(brainpy.math.random.weibull(5.,1000)) + >>> x = np.arange(1,100.)/50. + >>> scale = count.max()/weib(x, 1., 5.).max() + >>> plt.plot(x, weib(x, 1., 5.)*scale) + >>> plt.show() - """ - return DEFAULT.weibull(a, size, key=key) + """ + return DEFAULT.weibull(a, size, key=key) def weibull_min(a, scale=None, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Sample from a Weibull distribution. + """Sample from a Weibull distribution. - The scipy counterpart is `scipy.stats.weibull_min`. + The scipy counterpart is `scipy.stats.weibull_min`. - Args: - scale: The scale parameter of the distribution. - concentration: The concentration parameter of the distribution. - shape: The shape added to the parameters loc and scale broadcastable shape. - dtype: The type used for samples. - key: a PRNG key or a seed. + Args: + scale: The scale parameter of the distribution. + concentration: The concentration parameter of the distribution. + shape: The shape added to the parameters loc and scale broadcastable shape. + dtype: The type used for samples. + key: a PRNG key or a seed. - Returns: - A jnp.array of samples. + Returns: + A jnp.array of samples. - """ - return DEFAULT.weibull_min(a, scale, size, key=key) + """ + return DEFAULT.weibull_min(a, scale, size, key=key) def zipf(a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - r""" - Draw samples from a Zipf distribution. - - Samples are drawn from a Zipf distribution with specified parameter - `a` > 1. - - The Zipf distribution (also known as the zeta distribution) is a - discrete probability distribution that satisfies Zipf's law: the - frequency of an item is inversely proportional to its rank in a - frequency table. - - .. 
note:: - New code should use the ``zipf`` method of a ``default_rng()`` - instance instead; please see the :ref:`random-quick-start`. - - Parameters:: - - a : float or array_like of floats - Distribution parameter. Must be greater than 1. - size : int or tuple of ints, optional - Output shape. If the given shape is, e.g., ``(m, n, k)``, then - ``m * n * k`` samples are drawn. If size is ``None`` (default), - a single value is returned if ``a`` is a scalar. Otherwise, - ``np.array(a).size`` samples are drawn. - - Returns:: - - out : ndarray or scalar - Drawn samples from the parameterized Zipf distribution. - - See Also:: - - scipy.stats.zipf : probability density function, distribution, or - cumulative density function, etc. - - Notes:: - - The probability density for the Zipf distribution is - - .. math:: p(k) = \frac{k^{-a}}{\zeta(a)}, - - for integers :math:`k \geq 1`, where :math:`\zeta` is the Riemann Zeta - function. - - It is named for the American linguist George Kingsley Zipf, who noted - that the frequency of any word in a sample of a language is inversely - proportional to its rank in the frequency table. - - References:: - - .. [1] Zipf, G. K., "Selected Studies of the Principle of Relative - Frequency in Language," Cambridge, MA: Harvard Univ. Press, - 1932. - - Examples:: - - Draw samples from the distribution: - - >>> a = 4.0 - >>> n = 20000 - >>> s = brainpy.math.random.zipf(a, n) - - Display the histogram of the samples, along with - the expected histogram based on the probability - density function: - - >>> import matplotlib.pyplot as plt - >>> from scipy.special import zeta # doctest: +SKIP - - `bincount` provides a fast histogram for small integers. - - >>> count = np.bincount(s) - >>> k = np.arange(1, s.max() + 1) - - >>> plt.bar(k, count[1:], alpha=0.5, label='sample count') - >>> plt.plot(k, n*(k**-a)/zeta(a), 'k.-', alpha=0.5, - ... 
label='expected count') # doctest: +SKIP - >>> plt.semilogy() - >>> plt.grid(alpha=0.4) - >>> plt.legend() - >>> plt.title(f'Zipf sample, a={a}, size={n}') - >>> plt.show() - """ - return DEFAULT.zipf(a, size, key=key) + r""" + Draw samples from a Zipf distribution. + + Samples are drawn from a Zipf distribution with specified parameter + `a` > 1. + + The Zipf distribution (also known as the zeta distribution) is a + discrete probability distribution that satisfies Zipf's law: the + frequency of an item is inversely proportional to its rank in a + frequency table. + + .. note:: + New code should use the ``zipf`` method of a ``default_rng()`` + instance instead; please see the :ref:`random-quick-start`. + + Parameters:: + + a : float or array_like of floats + Distribution parameter. Must be greater than 1. + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. If size is ``None`` (default), + a single value is returned if ``a`` is a scalar. Otherwise, + ``np.array(a).size`` samples are drawn. + + Returns:: + + out : ndarray or scalar + Drawn samples from the parameterized Zipf distribution. + + See Also:: + + scipy.stats.zipf : probability density function, distribution, or + cumulative density function, etc. + + Notes:: + + The probability density for the Zipf distribution is + + .. math:: p(k) = \frac{k^{-a}}{\zeta(a)}, + + for integers :math:`k \geq 1`, where :math:`\zeta` is the Riemann Zeta + function. + + It is named for the American linguist George Kingsley Zipf, who noted + that the frequency of any word in a sample of a language is inversely + proportional to its rank in the frequency table. + + References:: + + .. [1] Zipf, G. K., "Selected Studies of the Principle of Relative + Frequency in Language," Cambridge, MA: Harvard Univ. Press, + 1932. 
+ + Examples:: + + Draw samples from the distribution: + + >>> a = 4.0 + >>> n = 20000 + >>> s = brainpy.math.random.zipf(a, n) + + Display the histogram of the samples, along with + the expected histogram based on the probability + density function: + + >>> import matplotlib.pyplot as plt + >>> from scipy.special import zeta # doctest: +SKIP + + `bincount` provides a fast histogram for small integers. + + >>> count = np.bincount(s) + >>> k = np.arange(1, s.max() + 1) + + >>> plt.bar(k, count[1:], alpha=0.5, label='sample count') + >>> plt.plot(k, n*(k**-a)/zeta(a), 'k.-', alpha=0.5, + ... label='expected count') # doctest: +SKIP + >>> plt.semilogy() + >>> plt.grid(alpha=0.4) + >>> plt.legend() + >>> plt.title(f'Zipf sample, a={a}, size={n}') + >>> plt.show() + """ + return DEFAULT.zipf(a, size, key=key) def maxwell(size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Sample from a one sided Maxwell distribution. + """Sample from a one sided Maxwell distribution. - The scipy counterpart is `scipy.stats.maxwell`. + The scipy counterpart is `scipy.stats.maxwell`. - Args: - key: a PRNG key. - size: The shape of the returned samples. - dtype: The type used for samples. + Args: + key: a PRNG key. + size: The shape of the returned samples. + dtype: The type used for samples. - Returns: - A jnp.array of samples, of shape `shape`. + Returns: + A jnp.array of samples, of shape `shape`. - """ - return DEFAULT.maxwell(size, key=key) + """ + return DEFAULT.maxwell(size, key=key) def t(df, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Sample Student’s t random values. + """Sample Student’s t random values. + + Parameters:: - Parameters:: - - df: float, array_like - A float or array of floats broadcast-compatible with shape representing the parameter of the distribution. - size: optional, int, tuple of int - A tuple of non-negative integers specifying the result shape. 
- Must be broadcast-compatible with `df`. The default (None) produces a result shape equal to `df.shape`. + df: float, array_like + A float or array of floats broadcast-compatible with shape representing the parameter of the distribution. + size: optional, int, tuple of int + A tuple of non-negative integers specifying the result shape. + Must be broadcast-compatible with `df`. The default (None) produces a result shape equal to `df.shape`. + + Returns:: - Returns:: - - out: array_like - The sampled value. - """ - return DEFAULT.t(df, size, key=key) + out: array_like + The sampled value. + """ + return DEFAULT.t(df, size, key=key) def orthogonal(n: int, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Sample uniformly from the orthogonal group `O(n)`. + """Sample uniformly from the orthogonal group `O(n)`. + + Parameters:: + + n: int + An integer indicating the resulting dimension. + size: optional, int, tuple of int + The batch dimensions of the result. - Parameters:: - - n: int - An integer indicating the resulting dimension. - size: optional, int, tuple of int - The batch dimensions of the result. + Returns:: - Returns:: - - out: Array - The sampled results. - """ - return DEFAULT.orthogonal(n, size, key=key) + out: Array + The sampled results. + """ + return DEFAULT.orthogonal(n, size, key=key) def loggamma(a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Sample log-gamma random values. + """Sample log-gamma random values. + + Parameters:: + + a: float, array_like + A float or array of floats broadcast-compatible with shape representing the parameter of the distribution. + size: optional, int, tuple of int + A tuple of nonnegative integers specifying the result shape. + Must be broadcast-compatible with `a`. The default (None) produces a result shape equal to `a.shape`. 
- Parameters:: - - a: float, array_like - A float or array of floats broadcast-compatible with shape representing the parameter of the distribution. - size: optional, int, tuple of int - A tuple of nonnegative integers specifying the result shape. - Must be broadcast-compatible with `a`. The default (None) produces a result shape equal to `a.shape`. + Returns:: - Returns:: - - out: array_like - The sampled results. - """ - return DEFAULT.loggamma(a, size, key=key) + out: array_like + The sampled results. + """ + return DEFAULT.loggamma(a, size, key=key) def categorical(logits, axis: int = -1, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Sample random values from categorical distributions. + """Sample random values from categorical distributions. - Args: - logits: Unnormalized log probabilities of the categorical distribution(s) to sample from, - so that `softmax(logits, axis)` gives the corresponding probabilities. - axis: Axis along which logits belong to the same categorical distribution. - shape: Optional, a tuple of nonnegative integers representing the result shape. - Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. - The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. - key: a PRNG key used as the random key. + Args: + logits: Unnormalized log probabilities of the categorical distribution(s) to sample from, + so that `softmax(logits, axis)` gives the corresponding probabilities. + axis: Axis along which logits belong to the same categorical distribution. + shape: Optional, a tuple of nonnegative integers representing the result shape. + Must be broadcast-compatible with ``np.delete(logits.shape, axis)``. + The default (None) produces a result shape equal to ``np.delete(logits.shape, axis)``. + key: a PRNG key used as the random key. 
- Returns: - A random array with int dtype and shape given by ``shape`` if ``shape`` - is not None, or else ``np.delete(logits.shape, axis)``. - """ - return DEFAULT.categorical(logits, axis, size, key=key) + Returns: + A random array with int dtype and shape given by ``shape`` if ``shape`` + is not None, or else ``np.delete(logits.shape, axis)``. + """ + return DEFAULT.categorical(logits, axis, size, key=key) def rand_like(input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Similar to ``rand_like`` in torch. + """Similar to ``rand_like`` in torch. - Returns a tensor with the same size as input that is filled with random - numbers from a uniform distribution on the interval ``[0, 1)``. + Returns a tensor with the same size as input that is filled with random + numbers from a uniform distribution on the interval ``[0, 1)``. - Args: - input: the ``size`` of input will determine size of the output tensor. - dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. - key: the seed or key for the random. + Args: + input: the ``size`` of input will determine size of the output tensor. + dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. + key: the seed or key for the random. - Returns: - The random data. - """ - return DEFAULT.rand_like(input, dtype=dtype, key=key) + Returns: + The random data. + """ + return DEFAULT.rand_like(input, dtype=dtype, key=key) def randn_like(input, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Similar to ``randn_like`` in torch. + """Similar to ``randn_like`` in torch. - Returns a tensor with the same size as ``input`` that is filled with - random numbers from a normal distribution with mean 0 and variance 1. + Returns a tensor with the same size as ``input`` that is filled with + random numbers from a normal distribution with mean 0 and variance 1. 
- Args: - input: the ``size`` of input will determine size of the output tensor. - dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. - key: the seed or key for the random. + Args: + input: the ``size`` of input will determine size of the output tensor. + dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. + key: the seed or key for the random. - Returns: - The random data. - """ - return DEFAULT.randn_like(input, dtype=dtype, key=key) + Returns: + The random data. + """ + return DEFAULT.randn_like(input, dtype=dtype, key=key) def randint_like(input, low=0, high=None, *, dtype=None, key: Optional[Union[int, JAX_RAND_KEY]] = None): - """Similar to ``randint_like`` in torch. + """Similar to ``randint_like`` in torch. - Returns a tensor with the same shape as Tensor ``input`` filled with - random integers generated uniformly between ``low`` (inclusive) and ``high`` (exclusive). + Returns a tensor with the same shape as Tensor ``input`` filled with + random integers generated uniformly between ``low`` (inclusive) and ``high`` (exclusive). - Args: - input: the ``size`` of input will determine size of the output tensor. - low: Lowest integer to be drawn from the distribution. Default: 0. - high: One above the highest integer to be drawn from the distribution. - dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. - key: the seed or key for the random. + Args: + input: the ``size`` of input will determine size of the output tensor. + low: Lowest integer to be drawn from the distribution. Default: 0. + high: One above the highest integer to be drawn from the distribution. + dtype: the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of input. + key: the seed or key for the random. - Returns: - The random data. 
- """ - return DEFAULT.randint_like(input=input, low=low, high=high, dtype=dtype, key=key) + Returns: + The random data. + """ + return DEFAULT.randint_like(input=input, low=low, high=high, dtype=dtype, key=key) for __k in dir(RandomState): - __t = getattr(RandomState, __k) - if not __k.startswith('__') and callable(__t) and (not __t.__doc__): - __r = globals().get(__k, None) - if __r is not None and callable(__r): - __t.__doc__ = __r.__doc__ + __t = getattr(RandomState, __k) + if not __k.startswith('__') and callable(__t) and (not __t.__doc__): + __r = globals().get(__k, None) + if __r is not None and callable(__r): + __t.__doc__ = __r.__doc__ diff --git a/brainpy/_src/math/remove_vmap.py b/brainpy/_src/math/remove_vmap.py index 72417120f..82abb6c42 100644 --- a/brainpy/_src/math/remove_vmap.py +++ b/brainpy/_src/math/remove_vmap.py @@ -1,50 +1,51 @@ # -*- coding: utf-8 -*- -import jax.numpy as jnp import jax +import jax.numpy as jnp + if jax.__version__ >= '0.5.0': - from jax.extend.core import Primitive + from jax.extend.core import Primitive else: - from jax.core import Primitive + from jax.core import Primitive from jax.core import ShapedArray from jax.interpreters import batching, mlir, xla from .ndarray import BaseArray __all__ = [ - 'remove_vmap' + 'remove_vmap' ] def remove_vmap(x, op='any'): - if isinstance(x, BaseArray): - x = x.value - if op == 'any': - return _any_without_vmap(x) - elif op == 'all': - return _all_without_vmap(x) - else: - raise ValueError(f'Do not support type: {op}') + if isinstance(x, BaseArray): + x = x.value + if op == 'any': + return _any_without_vmap(x) + elif op == 'all': + return _all_without_vmap(x) + else: + raise ValueError(f'Do not support type: {op}') _any_no_vmap_prim = Primitive('any_no_vmap') def _any_without_vmap(x): - return _any_no_vmap_prim.bind(x) + return _any_no_vmap_prim.bind(x) def _any_without_vmap_imp(x): - return jnp.any(x) + return jnp.any(x) def _any_without_vmap_abs(x): - return ShapedArray(shape=(), 
dtype=jnp.bool_) + return ShapedArray(shape=(), dtype=jnp.bool_) def _any_without_vmap_batch(x, batch_axes): - (x, ) = x - return _any_without_vmap(x), batching.not_mapped + (x,) = x + return _any_without_vmap(x), batching.not_mapped _any_no_vmap_prim.def_impl(_any_without_vmap_imp) @@ -55,25 +56,24 @@ def _any_without_vmap_batch(x, batch_axes): xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True)) mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False)) - _all_no_vmap_prim = Primitive('all_no_vmap') def _all_without_vmap(x): - return _all_no_vmap_prim.bind(x) + return _all_no_vmap_prim.bind(x) def _all_without_vmap_imp(x): - return jnp.all(x) + return jnp.all(x) def _all_without_vmap_abs(x): - return ShapedArray(shape=(), dtype=jnp.bool_) + return ShapedArray(shape=(), dtype=jnp.bool_) def _all_without_vmap_batch(x, batch_axes): - (x, ) = x - return _all_without_vmap(x), batching.not_mapped + (x,) = x + return _all_without_vmap(x), batching.not_mapped _all_no_vmap_prim.def_impl(_all_without_vmap_imp) @@ -83,4 +83,3 @@ def _all_without_vmap_batch(x, batch_axes): xla.register_translation(_all_no_vmap_prim, xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True)) mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False)) - diff --git a/brainpy/_src/math/scales.py b/brainpy/_src/math/scales.py index 367ccb479..406c080ba 100644 --- a/brainpy/_src/math/scales.py +++ b/brainpy/_src/math/scales.py @@ -4,74 +4,74 @@ from typing import Sequence, Union __all__ = [ - 'Scaling', - 'IdScaling', + 'Scaling', + 'IdScaling', ] class Scaling(object): - def __init__(self, scale, bias): - self.scale = scale - self.bias = bias - - @classmethod - def transform( - cls, - V_range: Sequence[Union[float, int]], - scaled_V_range: Sequence[Union[float, int]] = (0., 1.) 
- ) -> 'Scaling': - """Transform the membrane potential range to a ``Scaling`` instance. - - Args: - V_range: [V_min, V_max] - scaled_V_range: [scaled_V_min, scaled_V_max] - - Returns: - The instanced scaling object. - """ - V_min, V_max = V_range - scaled_V_min, scaled_V_max = scaled_V_range - scale = (V_max - V_min) / (scaled_V_max - scaled_V_min) - bias = scaled_V_min * scale - V_min - return cls(scale=scale, bias=bias) - - def offset_scaling(self, x, bias=None, scale=None): - if bias is None: - bias = self.bias - if scale is None: - scale = self.scale - return (x + bias) / scale - - def std_scaling(self, x, scale=None): - if scale is None: - scale = self.scale - return x / scale - - def inv_scaling(self, x, scale=None): - if scale is None: - scale = self.scale - return x * scale - - def clone(self, bias=None, scale=None): - if bias is None: - bias = self.bias - if scale is None: - scale = self.scale - return Scaling(bias=bias, scale=scale) + def __init__(self, scale, bias): + self.scale = scale + self.bias = bias + + @classmethod + def transform( + cls, + V_range: Sequence[Union[float, int]], + scaled_V_range: Sequence[Union[float, int]] = (0., 1.) + ) -> 'Scaling': + """Transform the membrane potential range to a ``Scaling`` instance. + + Args: + V_range: [V_min, V_max] + scaled_V_range: [scaled_V_min, scaled_V_max] + + Returns: + The instanced scaling object. 
+ """ + V_min, V_max = V_range + scaled_V_min, scaled_V_max = scaled_V_range + scale = (V_max - V_min) / (scaled_V_max - scaled_V_min) + bias = scaled_V_min * scale - V_min + return cls(scale=scale, bias=bias) + + def offset_scaling(self, x, bias=None, scale=None): + if bias is None: + bias = self.bias + if scale is None: + scale = self.scale + return (x + bias) / scale + + def std_scaling(self, x, scale=None): + if scale is None: + scale = self.scale + return x / scale + + def inv_scaling(self, x, scale=None): + if scale is None: + scale = self.scale + return x * scale + + def clone(self, bias=None, scale=None): + if bias is None: + bias = self.bias + if scale is None: + scale = self.scale + return Scaling(bias=bias, scale=scale) class IdScaling(Scaling): - def __init__(self): - super().__init__(scale=1., bias=0.) + def __init__(self): + super().__init__(scale=1., bias=0.) - def offset_scaling(self, x, bias=None, scale=None): - return x + def offset_scaling(self, x, bias=None, scale=None): + return x - def std_scaling(self, x, scale=None): - return x + def std_scaling(self, x, scale=None): + return x - def inv_scaling(self, x, scale=None): - return x + def inv_scaling(self, x, scale=None): + return x - def clone(self, bias=None, scale=None): - return IdScaling() + def clone(self, bias=None, scale=None): + return IdScaling() diff --git a/brainpy/_src/math/sharding.py b/brainpy/_src/math/sharding.py index a63cedd96..cfb4f029f 100644 --- a/brainpy/_src/math/sharding.py +++ b/brainpy/_src/math/sharding.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- +from contextlib import contextmanager from functools import partial from typing import Optional, Any, Union, Sequence -from contextlib import contextmanager import jax import numpy as np @@ -11,19 +11,19 @@ from .ndarray import Array, ShardedArray, BaseArray __all__ = [ - 'device_mesh', - 'get_sharding', - 'partition_by_axname', - 'partition_by_sharding', - 'partition', - 'keep_constraint', - - 'NEU_AXIS', - 'PRE_AXIS', - 
'POST_AXIS', - 'SYN_AXIS', - 'BATCH_AXIS', - 'TIME_AXIS', + 'device_mesh', + 'get_sharding', + 'partition_by_axname', + 'partition_by_sharding', + 'partition', + 'keep_constraint', + + 'NEU_AXIS', + 'PRE_AXIS', + 'POST_AXIS', + 'SYN_AXIS', + 'BATCH_AXIS', + 'TIME_AXIS', ] NEU_AXIS = 'neuron' @@ -42,7 +42,7 @@ def is_bp_array(x): - return isinstance(x, BaseArray) + return isinstance(x, BaseArray) @contextmanager @@ -50,68 +50,68 @@ def device_mesh( devices: Any, axis_names: Sequence[str], ): - global _default_mesh - _old_mesh = _default_mesh + global _default_mesh + _old_mesh = _default_mesh - devices = np.asarray(devices) - assert devices.ndim == len(axis_names) - mesh = Mesh(devices, axis_names=axis_names) + devices = np.asarray(devices) + assert devices.ndim == len(axis_names) + mesh = Mesh(devices, axis_names=axis_names) - _default_mesh = mesh + _default_mesh = mesh - try: - yield _default_mesh - finally: - _default_mesh = _old_mesh + try: + yield _default_mesh + finally: + _default_mesh = _old_mesh def _device_put(x: Union[Array, jax.Array, np.ndarray], device: Union[None, jax.Device, Sharding] = None): - """Transfers ``x`` to ``device``. - - Note that this function can only transfer ``brainpy.math.Array``, ``jax.Array``, - and ``numpy.ndarray``. Other value will be directly returned. - - Args: - x: The input array. - device: The given device. - - Returns: - A copy of ``x`` that resides on ``device``. - """ - if isinstance(x, BaseArray): - x.value = jax.device_put(x.value, device=device) - return x - else: - if isinstance(x, (jax.Array, np.ndarray)): - # wrap the data as brainpy.math.Array is important (experimental) - return ShardedArray(jax.device_put(x, device=device), keep_sharding=True) + """Transfers ``x`` to ``device``. + + Note that this function can only transfer ``brainpy.math.Array``, ``jax.Array``, + and ``numpy.ndarray``. Other value will be directly returned. + + Args: + x: The input array. + device: The given device. 
+ + Returns: + A copy of ``x`` that resides on ``device``. + """ + if isinstance(x, BaseArray): + x.value = jax.device_put(x.value, device=device) + return x else: - return x + if isinstance(x, (jax.Array, np.ndarray)): + # wrap the data as brainpy.math.Array is important (experimental) + return ShardedArray(jax.device_put(x, device=device), keep_sharding=True) + else: + return x def get_sharding( axis_names: Optional[Sequence[str]] = None, mesh: Optional[Mesh] = None ) -> Optional[NamedSharding]: - """Get sharding according to the given axes information. - - Args: - axis_names: list of str, or tuple of str. The name for each axis in the array. - mesh: Mesh. The given device mesh. - - Returns: - The instance of NamedSharding. - """ - if axis_names is None: - return None - if mesh is None: - mesh = _default_mesh - if mesh is None: - return None - else: - axis_names = [(name if name in mesh.axis_names else None) for name in axis_names] - return NamedSharding(mesh, PartitionSpec(*axis_names)) + """Get sharding according to the given axes information. + + Args: + axis_names: list of str, or tuple of str. The name for each axis in the array. + mesh: Mesh. The given device mesh. + + Returns: + The instance of NamedSharding. + """ + if axis_names is None: + return None + if mesh is None: + mesh = _default_mesh + if mesh is None: + return None + else: + axis_names = [(name if name in mesh.axis_names else None) for name in axis_names] + return NamedSharding(mesh, PartitionSpec(*axis_names)) def partition_by_axname( @@ -119,104 +119,104 @@ def partition_by_axname( axis_names: Optional[Sequence[str]] = None, mesh: Optional[Mesh] = None ): - """Put the given arrays into the mesh devices. - - Args: - x: any. Any array. - axis_names: sequence of str. The name for each axis in the array. - mesh: Mesh. The given device mesh. - - Returns: - The re-sharded arrays. 
- """ - if axis_names is None: - return x - else: - for _leaf in jax.tree_util.tree_leaves(x, is_leaf=is_bp_array): - if np.ndim(_leaf) != len(axis_names): - raise ValueError(f'The input array shape is {np.shape(_leaf)}, ' - f'while the given axis names are {axis_names}. ' - f'Dimensions are mismatch.') - if mesh is None: - if _default_mesh is None: - return x - mesh = _default_mesh - sharding = get_sharding(axis_names, mesh) - if sharding is None: - return x - else: - return jax.tree_util.tree_map(partial(_device_put, device=sharding), - x, is_leaf=is_bp_array) + """Put the given arrays into the mesh devices. + + Args: + x: any. Any array. + axis_names: sequence of str. The name for each axis in the array. + mesh: Mesh. The given device mesh. + + Returns: + The re-sharded arrays. + """ + if axis_names is None: + return x + else: + for _leaf in jax.tree_util.tree_leaves(x, is_leaf=is_bp_array): + if np.ndim(_leaf) != len(axis_names): + raise ValueError(f'The input array shape is {np.shape(_leaf)}, ' + f'while the given axis names are {axis_names}. ' + f'Dimensions are mismatch.') + if mesh is None: + if _default_mesh is None: + return x + mesh = _default_mesh + sharding = get_sharding(axis_names, mesh) + if sharding is None: + return x + else: + return jax.tree_util.tree_map(partial(_device_put, device=sharding), + x, is_leaf=is_bp_array) def partition_by_sharding( x: Any, sharding: Optional[Sharding] = None, ): - """Partition inputs with the given sharding strategy. + """Partition inputs with the given sharding strategy. - Args: - x: The input arrays. It can be a pyTree of arrays. - sharding: The `jax.sharding.Sharding` instance. + Args: + x: The input arrays. It can be a pyTree of arrays. + sharding: The `jax.sharding.Sharding` instance. - Returns: - The sharded ``x``, which has been partitioned by the given sharding stragety. 
- """ - if sharding is None: - return x - else: - if not isinstance(sharding, Sharding): - raise TypeError(f'sharding must be instance of jax.sharding.Sharding. While we got {sharding}.') - return jax.tree_util.tree_map(partial(_device_put, device=sharding), - x, - is_leaf=is_bp_array) + Returns: + The sharded ``x``, which has been partitioned by the given sharding stragety. + """ + if sharding is None: + return x + else: + if not isinstance(sharding, Sharding): + raise TypeError(f'sharding must be instance of jax.sharding.Sharding. While we got {sharding}.') + return jax.tree_util.tree_map(partial(_device_put, device=sharding), + x, + is_leaf=is_bp_array) def partition( x: Any, sharding: Optional[Union[Sequence[str], jax.Device, Sharding]] = None, ): - """Partition the input arrays onto devices by the given sharding strategies. - - Args: - x: Any input arrays. It can also be a PyTree of arrays. - sharding: The sharding strategy. - - Returns: - The partitioned arrays. - Notably, the - """ - if sharding is None: - return x - elif isinstance(sharding, (jax.Device, Sharding)): - return jax.tree_util.tree_map(partial(_device_put, device=sharding), - x, is_leaf=is_bp_array) - elif isinstance(sharding, (tuple, list)) and any([isinstance(s, str) for s in sharding]): - return partition_by_axname(x, sharding) - else: - raise TypeError('"sharding" only supports jax.sharding.Sharding or a sequence of axis names. \n' - f'But we got {sharding}') + """Partition the input arrays onto devices by the given sharding strategies. + + Args: + x: Any input arrays. It can also be a PyTree of arrays. + sharding: The sharding strategy. + + Returns: + The partitioned arrays. 
+ Notably, the + """ + if sharding is None: + return x + elif isinstance(sharding, (jax.Device, Sharding)): + return jax.tree_util.tree_map(partial(_device_put, device=sharding), + x, is_leaf=is_bp_array) + elif isinstance(sharding, (tuple, list)) and any([isinstance(s, str) for s in sharding]): + return partition_by_axname(x, sharding) + else: + raise TypeError('"sharding" only supports jax.sharding.Sharding or a sequence of axis names. \n' + f'But we got {sharding}') def _keep_constraint(x: Any): - if isinstance(x, BaseArray): - x = x.value - if isinstance(x, jax.Array): - if hasattr(x, 'sharding'): - if x.sharding is not None: - return jax.lax.with_sharding_constraint(x, x.sharding) - return x - else: - return x + if isinstance(x, BaseArray): + x = x.value + if isinstance(x, jax.Array): + if hasattr(x, 'sharding'): + if x.sharding is not None: + return jax.lax.with_sharding_constraint(x, x.sharding) + return x + else: + return x def keep_constraint(x: Any): - """Keep the sharding constraint of the given inputs during computation. + """Keep the sharding constraint of the given inputs during computation. - Args: - x: Any. + Args: + x: Any. - Returns: - constraint_x: Same as ``x``. - """ - return jax.tree_util.tree_map(_keep_constraint, x, is_leaf=is_bp_array) + Returns: + constraint_x: Same as ``x``. 
+ """ + return jax.tree_util.tree_map(_keep_constraint, x, is_leaf=is_bp_array) diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index eec5f53c0..68439c742 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,7 +1,5 @@ # from ._coo_mv import * -from .csr_mv import * from .csr_mm import * -from .utils import * +from .csr_mv import * from .jax_prim import * - - +from .utils import * diff --git a/brainpy/_src/math/sparse/coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py index d03679e30..d8ac7d2d0 100644 --- a/brainpy/_src/math/sparse/coo_mv.py +++ b/brainpy/_src/math/sparse/coo_mv.py @@ -3,14 +3,13 @@ from typing import Union, Tuple +import brainevent from jax import numpy as jnp from brainpy._src.math.ndarray import BaseArray as Array -import brainevent - __all__ = [ - 'coomv', + 'coomv', ] @@ -26,50 +25,50 @@ def coomv( transpose: bool = False, method: str = 'cusparse' ): - """Product of COO sparse matrix and a dense vector using cuSPARSE algorithm. + """Product of COO sparse matrix and a dense vector using cuSPARSE algorithm. - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. - Parameters:: + Parameters:: - data: ndarray, float - An array of shape ``(nse,)``. - row: ndarray - An array of shape ``(nse,)``. - col: ndarray - An array of shape ``(nse,)`` and dtype ``row.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` and - dtype ``data.dtype``. - shape: tuple of int - The shape of the sparse matrix. - rows_sorted: bool - Row index are sorted. - cols_sorted: bool - Column index are sorted. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - method: str - The method used to compute the matrix-vector multiplication. 
+ data: ndarray, float + An array of shape ``(nse,)``. + row: ndarray + An array of shape ``(nse,)``. + col: ndarray + An array of shape ``(nse,)`` and dtype ``row.dtype``. + vector: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` and + dtype ``data.dtype``. + shape: tuple of int + The shape of the sparse matrix. + rows_sorted: bool + Row index are sorted. + cols_sorted: bool + Column index are sorted. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + method: str + The method used to compute the matrix-vector multiplication. - Returns:: + Returns:: - y: ndarray - An array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - if isinstance(data, Array): - data = data.value - if isinstance(row, Array): - row = row.value - if isinstance(col, Array): - col = col.value - if isinstance(vector, Array): - vector = vector.value - csr = brainevent.COO((data, row, col), shape=shape) - if transpose: - return vector @ csr - else: - return csr @ vector + y: ndarray + An array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + if isinstance(data, Array): + data = data.value + if isinstance(row, Array): + row = row.value + if isinstance(col, Array): + col = col.value + if isinstance(vector, Array): + vector = vector.value + csr = brainevent.COO((data, row, col), shape=shape) + if transpose: + return vector @ csr + else: + return csr @ vector diff --git a/brainpy/_src/math/sparse/csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py index 3978953df..27bf9fc4c 100644 --- a/brainpy/_src/math/sparse/csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -9,7 +9,7 @@ from brainpy._src.math.ndarray import BaseArray as Array __all__ = [ - 'csrmv', + 'csrmv', ] @@ -22,54 +22,53 @@ def csrmv( shape: Tuple[int, int], transpose: bool = False, ): - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. 
+ """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. - Parameters:: + Parameters:: - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - method: str - The method used to compute Matrix-Vector Multiplication. Default is ``taichi``. - The candidate methods are: + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + vector: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + method: str + The method used to compute Matrix-Vector Multiplication. Default is ``taichi``. + The candidate methods are: - - ``None``: default using Taichi kernel. - - ``cusparse``: using cuSPARSE library. - - ``scalar``: - - ``vector``: - - ``adaptive``: + - ``None``: default using Taichi kernel. + - ``cusparse``: using cuSPARSE library. + - ``scalar``: + - ``vector``: + - ``adaptive``: - Returns:: - - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. 
- """ - if isinstance(data, Array): - data = data.value - if isinstance(indices, Array): - indices = indices.value - if isinstance(indptr, Array): - indptr = indptr.value - if isinstance(vector, Array): - vector = vector.value - csr = brainevent.CSR((data, indices, indptr), shape=shape) - if transpose: - return vector @ csr - else: - return csr @ vector + Returns:: + y : ndarry + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + if isinstance(data, Array): + data = data.value + if isinstance(indices, Array): + indices = indices.value + if isinstance(indptr, Array): + indptr = indptr.value + if isinstance(vector, Array): + vector = vector.value + csr = brainevent.CSR((data, indices, indptr), shape=shape) + if transpose: + return vector @ csr + else: + return csr @ vector diff --git a/brainpy/_src/math/sparse/jax_prim.py b/brainpy/_src/math/sparse/jax_prim.py index 316cbe245..40da4a64b 100644 --- a/brainpy/_src/math/sparse/jax_prim.py +++ b/brainpy/_src/math/sparse/jax_prim.py @@ -6,9 +6,8 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import BaseArray as Array - __all__ = [ - 'seg_matmul', + 'seg_matmul', ] @@ -16,150 +15,148 @@ def _matmul_with_left_sparse( sparse: Dict, dense: Union[Array, jnp.ndarray] ): - r"""Matrix multiplication with sparse matrix on the left. - - .. math:: - - Y = M_{\mathrm{sparse}} @ M_{\mathrm{dense}} - - Parameters:: - - sparse: dict - The sparse matrix with shape of :math:`(N, M)`. - dense: ArrayType - The dense matrix with the shape of :math:`(M, K)`. - - Returns:: - - matrix - A tensor the the shape of :math:`(N, K)`. - """ - assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' - values = sparse['data'] - rows, cols = sparse['index'] - shape = sparse['shape'] - if len(shape) != 2: - raise ValueError(f'Sparse matrix must be a two-dimensional matrix. 
But we got {shape}') - values = as_jax(values) - rows = as_jax(rows) - cols = as_jax(cols) - dense = as_jax(dense) - B = dense.take(cols, axis=0) - if B.ndim == 2: - prod = B * jnp.reshape(values, (-1, 1)) - else: - prod = B * values - return ops.segment_sum(prod, rows, shape[0]) + r"""Matrix multiplication with sparse matrix on the left. + + .. math:: + + Y = M_{\mathrm{sparse}} @ M_{\mathrm{dense}} + + Parameters:: + + sparse: dict + The sparse matrix with shape of :math:`(N, M)`. + dense: ArrayType + The dense matrix with the shape of :math:`(M, K)`. + + Returns:: + + matrix + A tensor the the shape of :math:`(N, K)`. + """ + assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' + values = sparse['data'] + rows, cols = sparse['index'] + shape = sparse['shape'] + if len(shape) != 2: + raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') + values = as_jax(values) + rows = as_jax(rows) + cols = as_jax(cols) + dense = as_jax(dense) + B = dense.take(cols, axis=0) + if B.ndim == 2: + prod = B * jnp.reshape(values, (-1, 1)) + else: + prod = B * values + return ops.segment_sum(prod, rows, shape[0]) def _matmul_with_right_sparse( dense: Union[Array, jnp.ndarray], sparse: Dict ): - r"""Matrix multiplication with sparse matrix on the left. - - .. math:: - - Y = M_{\mathrm{dense}} @ M_{\mathrm{sparse}} - - Parameters:: - - dense: ArrayType - The dense matrix with the shape of :math:`(N, M)`. - sparse: dict - The sparse matrix with shape of :math:`(M, K)`. - - Returns:: - - matrix - A tensor the the shape of :math:`(N, K)`. - """ - assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' - values = sparse['data'] - rows, cols = sparse['index'] - shape = sparse['shape'] - if len(shape) != 2: - raise ValueError(f'Sparse matrix must be a two-dimensional matrix. 
But we got {shape}') - values = as_jax(values) - rows = as_jax(rows) - cols = as_jax(cols) - dense = as_jax(dense) - if dense.ndim == 2: - A = dense[:, rows] - prod = (A * values).T - res = ops.segment_sum(prod, cols, shape[1]).T - else: - prod = dense[rows] * values - res = ops.segment_sum(prod, cols, shape[1]) - return res + r"""Matrix multiplication with sparse matrix on the left. + + .. math:: + + Y = M_{\mathrm{dense}} @ M_{\mathrm{sparse}} + + Parameters:: + + dense: ArrayType + The dense matrix with the shape of :math:`(N, M)`. + sparse: dict + The sparse matrix with shape of :math:`(M, K)`. + + Returns:: + + matrix + A tensor the the shape of :math:`(N, K)`. + """ + assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' + values = sparse['data'] + rows, cols = sparse['index'] + shape = sparse['shape'] + if len(shape) != 2: + raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') + values = as_jax(values) + rows = as_jax(rows) + cols = as_jax(cols) + dense = as_jax(dense) + if dense.ndim == 2: + A = dense[:, rows] + prod = (A * values).T + res = ops.segment_sum(prod, cols, shape[1]).T + else: + prod = dense[rows] * values + res = ops.segment_sum(prod, cols, shape[1]) + return res def seg_matmul(A, B): - r"""Sparse matrix multiplication. - - .. math:: - - y = A @ B - - where :math:`A` or :math:`B` is a sparse matrix. - :math:`A` and :math:`B` cannot be both sparse. - - Examples:: - - >>> import brainpy.math as bm - - 1. 
when the left matrix :math:`A` is a sparse matrix with the shape of :math:`(N, M)`, - - >>> # A is a sparse matrix (3, 4): - >>> # [[0, 2, 0, 4], - >>> # [1, 0, 0, 0], - >>> # [0, 3, 0, 2]] - >>> values = bm.asarray([2, 4, 1, 3, 2]) - >>> rows = bm.asarray([0, 0, 1, 2, 2]) - >>> cols = bm.asarray([1, 3, 0, 1, 3]) - >>> sparse = {'data': values, 'index': (rows, cols), 'shape': (3, 4)} - >>> B = bm.arange(4) - >>> bm.sparse.sparse_matmul(sparse, B) - ArrayType([14, 0, 9], dtype=int32) - >>> B = bm.random.rand(4, 3) - >>> bm.sparse.sparse_matmul(sparse, B) - ArrayType([[3.8331761 , 1.3708692 , 4.510223 ], - [0.9960836 , 0.37550318, 0.7370341 ], - [2.3700516 , 0.7574289 , 4.1124535 ]], dtype=float32) - - 2. when the right matrix :math:`B` is a sparse matrix with the shape of :math:`(M, K)`, - - >>> A = bm.arange(3) - >>> bm.sparse.sparse_matmul(A, sparse) - ArrayType([1, 6, 0, 4], dtype=int32) - >>> A = bm.random.rand(2, 3) - >>> bm.sparse.sparse_matmul(A, sparse) - ArrayType([[0.438388 , 1.4346815 , 0. , 2.361964 ], - [0.9171978 , 1.1214957 , 0. , 0.90534496]], dtype=float32) - - Parameters:: - - A: tensor, sequence - The dense or sparse matrix with the shape of :math:`(N, M)`. - B: tensor, sequence - The dense or sparse matrix with the shape of :math:`(M, K)`. - - Returns:: - - results: ArrayType - The tensor with the shape of :math:`(N, K)`. - """ - if isinstance(A, dict): - if not isinstance(B, (Array, jnp.ndarray)): - raise ValueError('A and B cannot be both sparse. \n' - f'A:\n{A}\n' - f'B:\n{B}') - return _matmul_with_left_sparse(A, B) - else: - if not isinstance(B, dict): - raise ValueError('A and B cannot be both dense. \n' - f'A:\n{A}\n' - f'B:\n{B}') - return _matmul_with_right_sparse(A, B) - - + r"""Sparse matrix multiplication. + + .. math:: + + y = A @ B + + where :math:`A` or :math:`B` is a sparse matrix. + :math:`A` and :math:`B` cannot be both sparse. + + Examples:: + + >>> import brainpy.math as bm + + 1. 
when the left matrix :math:`A` is a sparse matrix with the shape of :math:`(N, M)`, + + >>> # A is a sparse matrix (3, 4): + >>> # [[0, 2, 0, 4], + >>> # [1, 0, 0, 0], + >>> # [0, 3, 0, 2]] + >>> values = bm.asarray([2, 4, 1, 3, 2]) + >>> rows = bm.asarray([0, 0, 1, 2, 2]) + >>> cols = bm.asarray([1, 3, 0, 1, 3]) + >>> sparse = {'data': values, 'index': (rows, cols), 'shape': (3, 4)} + >>> B = bm.arange(4) + >>> bm.sparse.sparse_matmul(sparse, B) + ArrayType([14, 0, 9], dtype=int32) + >>> B = bm.random.rand(4, 3) + >>> bm.sparse.sparse_matmul(sparse, B) + ArrayType([[3.8331761 , 1.3708692 , 4.510223 ], + [0.9960836 , 0.37550318, 0.7370341 ], + [2.3700516 , 0.7574289 , 4.1124535 ]], dtype=float32) + + 2. when the right matrix :math:`B` is a sparse matrix with the shape of :math:`(M, K)`, + + >>> A = bm.arange(3) + >>> bm.sparse.sparse_matmul(A, sparse) + ArrayType([1, 6, 0, 4], dtype=int32) + >>> A = bm.random.rand(2, 3) + >>> bm.sparse.sparse_matmul(A, sparse) + ArrayType([[0.438388 , 1.4346815 , 0. , 2.361964 ], + [0.9171978 , 1.1214957 , 0. , 0.90534496]], dtype=float32) + + Parameters:: + + A: tensor, sequence + The dense or sparse matrix with the shape of :math:`(N, M)`. + B: tensor, sequence + The dense or sparse matrix with the shape of :math:`(M, K)`. + + Returns:: + + results: ArrayType + The tensor with the shape of :math:`(N, K)`. + """ + if isinstance(A, dict): + if not isinstance(B, (Array, jnp.ndarray)): + raise ValueError('A and B cannot be both sparse. \n' + f'A:\n{A}\n' + f'B:\n{B}') + return _matmul_with_left_sparse(A, B) + else: + if not isinstance(B, dict): + raise ValueError('A and B cannot be both dense. 
\n' + f'A:\n{A}\n' + f'B:\n{B}') + return _matmul_with_right_sparse(A, B) diff --git a/brainpy/_src/math/sparse/utils.py b/brainpy/_src/math/sparse/utils.py index 0c7f58348..643fe70f9 100644 --- a/brainpy/_src/math/sparse/utils.py +++ b/brainpy/_src/math/sparse/utils.py @@ -4,48 +4,51 @@ from functools import partial from typing import Tuple - import jax import numpy as np -from brainpy._src.math.interoperability import as_jax from jax import core, numpy as jnp from jax import lax from jax.interpreters import batching from jax.interpreters import mlir, ad from jax.tree_util import tree_flatten, tree_unflatten from jaxlib import gpu_sparse + +from brainpy._src.math.interoperability import as_jax + if jax.__version__ >= '0.5.0': - from jax.extend.core import Primitive + from jax.extend.core import Primitive else: - from jax.core import Primitive + from jax.core import Primitive __all__ = [ - 'coo_to_csr', - 'csr_to_coo', - 'csr_to_dense' + 'coo_to_csr', + 'csr_to_coo', + 'csr_to_dense' ] + def _general_batching_rule(prim, args, axes, **kwargs): - batch_axes, batch_args, non_batch_args = [], {}, {} - for ax_i, ax in enumerate(axes): - if ax is None: - non_batch_args[f'ax{ax_i}'] = args[ax_i] - else: - batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0) - batch_axes.append(ax_i) - - def f(_, x): - pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}']) - for i in range(len(axes))]) - return 0, prim.bind(*pars, **kwargs) - - _, outs = lax.scan(f, 0, batch_args) - out_vals, out_tree = tree_flatten(outs) - out_dim = tree_unflatten(out_tree, (0,) * len(out_vals)) - return outs, out_dim + batch_axes, batch_args, non_batch_args = [], {}, {} + for ax_i, ax in enumerate(axes): + if ax is None: + non_batch_args[f'ax{ax_i}'] = args[ax_i] + else: + batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0) + batch_axes.append(ax_i) + + def f(_, x): + pars = tuple([(x[f'ax{i}'] if i in batch_axes else 
non_batch_args[f'ax{i}']) + for i in range(len(axes))]) + return 0, prim.bind(*pars, **kwargs) + + _, outs = lax.scan(f, 0, batch_args) + out_vals, out_tree = tree_flatten(outs) + out_dim = tree_unflatten(out_tree, (0,) * len(out_vals)) + return outs, out_dim + def _register_general_batching(prim): - batching.primitive_batchers[prim] = partial(_general_batching_rule, prim) + batching.primitive_batchers[prim] = partial(_general_batching_rule, prim) def coo_to_csr( @@ -54,35 +57,35 @@ def coo_to_csr( *, num_row: int ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """convert pre_ids, post_ids to (indices, indptr).""" - pre_ids = as_jax(pre_ids) - post_ids = as_jax(post_ids) + """convert pre_ids, post_ids to (indices, indptr).""" + pre_ids = as_jax(pre_ids) + post_ids = as_jax(post_ids) - # sorting - sort_ids = jnp.argsort(pre_ids, kind='stable') - post_ids = post_ids[sort_ids] + # sorting + sort_ids = jnp.argsort(pre_ids, kind='stable') + post_ids = post_ids[sort_ids] - indices = post_ids - unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True) - final_pre_count = jnp.zeros(num_row) - final_pre_count[unique_pre_ids] = pre_count - indptr = final_pre_count.cumsum() - indptr = jnp.insert(indptr, 0, 0) - return indices, indptr + indices = post_ids + unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True) + final_pre_count = jnp.zeros(num_row) + final_pre_count[unique_pre_ids] = pre_count + indptr = final_pre_count.cumsum() + indptr = jnp.insert(indptr, 0, 0) + return indices, indptr def csr_to_coo( indices: jnp.ndarray, indptr: jnp.ndarray ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Given CSR (indices, indptr) return COO (row, col)""" - indices = as_jax(indices) - indptr = as_jax(indptr) - return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices + """Given CSR (indices, indptr) return COO (row, col)""" + indices = as_jax(indices) + indptr = as_jax(indptr) + return jnp.cumsum(jnp.zeros_like(indices).at[indptr].add(1)) - 1, indices def 
csr_to_csc(): - pass + pass def coo_to_dense( @@ -92,7 +95,7 @@ def coo_to_dense( *, shape: Tuple[int, int] ) -> jnp.ndarray: - pass + pass def csr_to_dense( @@ -102,79 +105,79 @@ def csr_to_dense( *, shape: Tuple[int, int] ) -> jnp.ndarray: - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - return csr_to_dense_p.bind(data, indices, indptr, shape=shape) + data = as_jax(data) + indices = as_jax(indices) + indptr = as_jax(indptr) + return csr_to_dense_p.bind(data, indices, indptr, shape=shape) def _coo_extract(row, col, mat): - """Extract values of dense matrix mat at given COO indices.""" - return mat[row, col] + """Extract values of dense matrix mat at given COO indices.""" + return mat[row, col] def _csr_extract(indices, indptr, mat): - """Extract values of dense matrix mat at given CSR indices.""" - return _coo_extract(*csr_to_coo(indices, indptr), mat) + """Extract values of dense matrix mat at given CSR indices.""" + return _coo_extract(*csr_to_coo(indices, indptr), mat) def _coo_todense(data, row, col, *, shape): - """Convert CSR-format sparse matrix to a dense matrix. + """Convert CSR-format sparse matrix to a dense matrix. - Args: - data : array of shape ``(nse,)``. - row : array of shape ``(nse,)`` - col : array of shape ``(nse,)`` and dtype ``row.dtype`` - shape : COOInfo object containing matrix metadata + Args: + data : array of shape ``(nse,)``. 
+ row : array of shape ``(nse,)`` + col : array of shape ``(nse,)`` and dtype ``row.dtype`` + shape : COOInfo object containing matrix metadata - Returns: - mat : array with specified shape and dtype matching ``data`` - """ - return jnp.zeros(shape, data.dtype).at[row, col].add(data) + Returns: + mat : array with specified shape and dtype matching ``data`` + """ + return jnp.zeros(shape, data.dtype).at[row, col].add(data) def _csr_to_dense_impl(data, indices, indptr, *, shape): - return _coo_todense(data, *csr_to_coo(indices, indptr), shape=shape) + return _coo_todense(data, *csr_to_coo(indices, indptr), shape=shape) def _csr_to_dense_abstract_eval(data, indices, indptr, *, shape): - assert data.ndim == indices.ndim == indptr.ndim == 1 - assert indices.dtype == indptr.dtype - assert data.shape == indices.shape - assert indptr.shape[0] == shape[0] + 1 - return core.ShapedArray(shape, data.dtype) + assert data.ndim == indices.ndim == indptr.ndim == 1 + assert indices.dtype == indptr.dtype + assert data.shape == indices.shape + assert indptr.shape[0] == shape[0] + 1 + return core.ShapedArray(shape, data.dtype) _csr_to_dense_lowering = mlir.lower_fun(_csr_to_dense_impl, multiple_results=False) def _csr_to_dense_gpu_lowering(ctx, data, indices, indptr, *, shape): - data_aval, indices_aval, _ = ctx.avals_in - dtype = data_aval.dtype - if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): - warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. 
" - "Falling back to default implementation.", - UserWarning) - return _csr_to_dense_lowering(ctx, data, indices, indptr, shape=shape) - return [gpu_sparse.cuda_csr_todense(data, indices, indptr, - shape=shape, data_dtype=dtype, - index_dtype=indices_aval.dtype)] + data_aval, indices_aval, _ = ctx.avals_in + dtype = data_aval.dtype + if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): + warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for dtype={dtype}. " + "Falling back to default implementation.", + UserWarning) + return _csr_to_dense_lowering(ctx, data, indices, indptr, shape=shape) + return [gpu_sparse.cuda_csr_todense(data, indices, indptr, + shape=shape, data_dtype=dtype, + index_dtype=indices_aval.dtype)] def _csr_to_dense_jvp(data_dot, data, indices, indptr, *, shape): - return csr_to_dense(data_dot, indices, indptr, shape=shape) + return csr_to_dense(data_dot, indices, indptr, shape=shape) def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape): - # Note: we assume that transpose has the same sparsity pattern. - # Can we check this? - assert ad.is_undefined_primal(data) - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices") - assert ct.shape == shape - assert indices.aval.dtype == indptr.aval.dtype - assert ct.dtype == data.aval.dtype - return _csr_extract(indices, indptr, ct), indices, indptr + # Note: we assume that transpose has the same sparsity pattern. + # Can we check this? 
+ assert ad.is_undefined_primal(data) + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices") + assert ct.shape == shape + assert indices.aval.dtype == indptr.aval.dtype + assert ct.dtype == data.aval.dtype + return _csr_extract(indices, indptr, ct), indices, indptr csr_to_dense_p = Primitive('csr_to_dense') @@ -185,4 +188,4 @@ def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape): mlir.register_lowering(csr_to_dense_p, _csr_to_dense_lowering) _register_general_batching(csr_to_dense_p) if gpu_sparse.cuda_is_supported: - mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda') + mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda') diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py index a6797a32d..41520b02b 100644 --- a/brainpy/_src/math/surrogate/_one_input.py +++ b/brainpy/_src/math/surrogate/_one_input.py @@ -10,89 +10,89 @@ from brainpy._src.math.ndarray import Array __all__ = [ - 'sigmoid', - 'piecewise_quadratic', - 'piecewise_exp', - 'soft_sign', - 'arctan', - 'nonzero_sign_log', - 'erf', - 'piecewise_leaky_relu', - 'squarewave_fourier_series', - 's2nn', - 'q_pseudo_spike', - 'leaky_relu', - 'log_tailed_relu', - 'relu_grad', - 'gaussian_grad', - 'inv_square_grad', - 'multi_gaussian_grad', - 'slayer_grad', + 'sigmoid', + 'piecewise_quadratic', + 'piecewise_exp', + 'soft_sign', + 'arctan', + 'nonzero_sign_log', + 'erf', + 'piecewise_leaky_relu', + 'squarewave_fourier_series', + 's2nn', + 'q_pseudo_spike', + 'leaky_relu', + 'log_tailed_relu', + 'relu_grad', + 'gaussian_grad', + 'inv_square_grad', + 'multi_gaussian_grad', + 'slayer_grad', ] class Surrogate(object): - """The base surrograte gradient function.""" + """The base surrograte gradient function.""" - def __call__(self, *args, **kwargs): - raise NotImplementedError + def __call__(self, *args, **kwargs): + 
raise NotImplementedError - def __repr__(self): - return f'{self.__class__.__name__}()' + def __repr__(self): + return f'{self.__class__.__name__}()' class _OneInpSurrogate(Surrogate): - def __init__(self, forward_use_surrogate=False): - self.forward_use_surrogate = forward_use_surrogate - self._true_call_ = jax.custom_gradient(self.call) + def __init__(self, forward_use_surrogate=False): + self.forward_use_surrogate = forward_use_surrogate + self._true_call_ = jax.custom_gradient(self.call) - def __call__(self, x: jax.Array): - return self._true_call_(as_jax(x)) + def __call__(self, x: jax.Array): + return self._true_call_(as_jax(x)) - def call(self, x): - """Call the function for surrogate gradient propagation.""" - y = self.surrogate_fun(x) if self.forward_use_surrogate else self.true_fun(x) - return y, functools.partial(self.surrogate_grad, x=x) + def call(self, x): + """Call the function for surrogate gradient propagation.""" + y = self.surrogate_fun(x) if self.forward_use_surrogate else self.true_fun(x) + return y, functools.partial(self.surrogate_grad, x=x) - def true_fun(self, x): - """The original true function.""" - return jnp.asarray(x >= 0, dtype=x.dtype) + def true_fun(self, x): + """The original true function.""" + return jnp.asarray(x >= 0, dtype=x.dtype) - def surrogate_fun(self, x): - """The surrogate function.""" - raise NotImplementedError + def surrogate_fun(self, x): + """The surrogate function.""" + raise NotImplementedError - def surrogate_grad(self, dz, x): - """The gradient for the surrogate function.""" - raise NotImplementedError + def surrogate_grad(self, dz, x): + """The gradient for the surrogate function.""" + raise NotImplementedError - def __repr__(self): - return f'{self.__class__.__name__}(forward_use_surrogate={self.forward_use_surrogate})' + def __repr__(self): + return f'{self.__class__.__name__}(forward_use_surrogate={self.forward_use_surrogate})' class Sigmoid(_OneInpSurrogate): - """Spike function with the sigmoid-shaped 
surrogate gradient. + """Spike function with the sigmoid-shaped surrogate gradient. - See Also:: + See Also:: - sigmoid + sigmoid - """ + """ - def __init__(self, alpha: float = 4., forward_use_surrogate=False): - super().__init__(forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha: float = 4., forward_use_surrogate=False): + super().__init__(forward_use_surrogate) + self.alpha = alpha - def surrogate_fun(self, x): - return sci.special.expit(x) + def surrogate_fun(self, x): + return sci.special.expit(x) - def surrogate_grad(self, dz, x): - sgax = sci.special.expit(as_jax(x) * self.alpha) - dx = as_jax(dz) * (1. - sgax) * sgax * self.alpha - return dx + def surrogate_grad(self, dz, x): + sgax = sci.special.expit(as_jax(x) * self.alpha) + dx = as_jax(dz) * (1. - sgax) * sgax * self.alpha + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def sigmoid( @@ -100,89 +100,89 @@ def sigmoid( alpha: float = 4., origin: bool = False, ): - r"""Spike function with the sigmoid-shaped surrogate gradient. + r"""Spike function with the sigmoid-shaped surrogate gradient. - If `origin=False`, return the forward function: + If `origin=False`, return the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}} + g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x) + g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x) - .. 
plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-2, 2, 1000) - >>> for alpha in [1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.sigmoid)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-2, 2, 1000) + >>> for alpha in [1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.sigmoid)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. - """ - return Sigmoid(alpha=alpha, forward_use_surrogate=origin)(x) + out: jax.Array + The spiking state. + """ + return Sigmoid(alpha=alpha, forward_use_surrogate=origin)(x) class PiecewiseQuadratic(_OneInpSurrogate): - """Judge spiking state with a piecewise quadratic function. + """Judge spiking state with a piecewise quadratic function. 
- See Also:: + See Also:: - piecewise_quadratic + piecewise_quadratic - """ + """ - def __init__(self, alpha: float = 1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha: float = 1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate) + self.alpha = alpha - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < -1 / self.alpha, - 0., - jnp.where(x > 1 / self.alpha, - 1., - (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5)) - return z + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < -1 / self.alpha, + 0., + jnp.where(x > 1 / self.alpha, + 1., + (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5)) + return z - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-(self.alpha * x) ** 2 + self.alpha)) - return dx + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., dz * (-(self.alpha * x) ** 2 + self.alpha)) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def piecewise_quadratic( @@ -190,100 +190,100 @@ def piecewise_quadratic( alpha: float = 1., origin: bool = False ): - r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. + r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. 
math:: - g(x) = - \begin{cases} - 0, & x < -\frac{1}{\alpha} \\ - -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ - 1, & x > \frac{1}{\alpha} \\ - \end{cases} + g(x) = + \begin{cases} + 0, & x < -\frac{1}{\alpha} \\ + -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ + 1, & x > \frac{1}{\alpha} \\ + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = - \begin{cases} - 0, & |x| > \frac{1}{\alpha} \\ - -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} - \end{cases} + g'(x) = + \begin{cases} + 0, & |x| > \frac{1}{\alpha} \\ + -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} + \end{cases} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.piecewise_quadratic)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.piecewise_quadratic)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. 
+ out: jax.Array + The spiking state. - References:: + References:: - .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805. - .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. - .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14. - """ - return PiecewiseQuadratic(alpha=alpha, forward_use_surrogate=origin)(x) + .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805. + .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. 
IEEE Signal Processing Magazine, 2019, 36(6): 51-63. + .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14. + """ + return PiecewiseQuadratic(alpha=alpha, forward_use_surrogate=origin)(x) class PiecewiseExp(_OneInpSurrogate): - """Judge spiking state with a piecewise exponential function. + """Judge spiking state with a piecewise exponential function. - See Also:: + See Also:: - piecewise_exp - """ + piecewise_exp + """ - def __init__(self, alpha: float = 1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha: float = 1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate) + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x)) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x)) + return dx * as_jax(dz) - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2) + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def piecewise_exp( @@ -291,90 +291,90 @@ def piecewise_exp( alpha: float = 1., origin: bool = False ): - r"""Judge spiking state with a piecewise exponential function [1]_. + r"""Judge spiking state with a piecewise exponential function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. 
math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - g(x) = \begin{cases} - \frac{1}{2}e^{\alpha x}, & x < 0 \\ - 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 - \end{cases} + g(x) = \begin{cases} + \frac{1}{2}e^{\alpha x}, & x < 0 \\ + 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \frac{\alpha}{2}e^{-\alpha |x|} + g'(x) = \frac{\alpha}{2}e^{-\alpha |x|} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.piecewise_exp)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.piecewise_exp)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. 
- References:: + References:: - .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. - """ - return PiecewiseExp(alpha=alpha, forward_use_surrogate=origin)(x) + .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. + """ + return PiecewiseExp(alpha=alpha, forward_use_surrogate=origin)(x) class SoftSign(_OneInpSurrogate): - """Judge spiking state with a soft sign function. + """Judge spiking state with a soft sign function. - See Also:: + See Also:: - soft_sign - """ + soft_sign + """ - def __init__(self, alpha=1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2 - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2 + return dx * as_jax(dz) - def surrogate_fun(self, x): - x = as_jax(x) - return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5 + def surrogate_fun(self, x): + x = as_jax(x) + return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5 - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def soft_sign( @@ -382,85 +382,85 @@ def soft_sign( alpha: float = 1., origin: bool = False ): - r"""Judge spiking state with a soft sign function. + r"""Judge spiking state with a soft sign function. 
- If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) - = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1) + g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) + = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1) - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}} + g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.soft_sign)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.soft_sign)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. 
+ alpha: float + Parameter to control smoothness of gradient + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - """ - return SoftSign(alpha=alpha, forward_use_surrogate=origin)(x) + """ + return SoftSign(alpha=alpha, forward_use_surrogate=origin)(x) class Arctan(_OneInpSurrogate): - """Judge spiking state with an arctan function. + """Judge spiking state with an arctan function. - See Also:: + See Also:: - arctan - """ + arctan + """ - def __init__(self, alpha=1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2) + return dx * as_jax(dz) - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def arctan( @@ -468,84 +468,84 @@ def arctan( alpha: float = 1., origin: bool = False ): - r"""Judge spiking state with an arctan function. + r"""Judge spiking state with an arctan function. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. 
math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2} + g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)} + g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.arctan)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.arctan)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. 
- """ - return Arctan(alpha=alpha, forward_use_surrogate=origin)(x) + """ + return Arctan(alpha=alpha, forward_use_surrogate=origin)(x) class NonzeroSignLog(_OneInpSurrogate): - """Judge spiking state with a nonzero sign log function. + """Judge spiking state with a nonzero sign log function. - See Also:: + See Also:: - nonzero_sign_log - """ + nonzero_sign_log + """ - def __init__(self, alpha=1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = as_jax(dz) / (1 / self.alpha + jnp.abs(x)) - return dx + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = as_jax(dz) / (1 / self.alpha + jnp.abs(x)) + return dx - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1) + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def nonzero_sign_log( @@ -553,97 +553,97 @@ def nonzero_sign_log( alpha: float = 1., origin: bool = False ): - r"""Judge spiking state with a nonzero sign log function. + r"""Judge spiking state with a nonzero sign log function. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. 
math:: - g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1) + g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1) - where + where - .. math:: + .. math:: - \begin{split}\mathrm{NonzeroSign}(x) = - \begin{cases} - 1, & x \geq 0 \\ - -1, & x < 0 \\ - \end{cases}\end{split} + \begin{split}\mathrm{NonzeroSign}(x) = + \begin{cases} + 1, & x \geq 0 \\ + -1, & x < 0 \\ + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|} + g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|} - This surrogate function has the advantage of low computation cost during the backward. + This surrogate function has the advantage of low computation cost during the backward. - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. 
+ alpha: float + Parameter to control smoothness of gradient + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - """ - return NonzeroSignLog(alpha=alpha, forward_use_surrogate=origin)(x) + """ + return NonzeroSignLog(alpha=alpha, forward_use_surrogate=origin)(x) class ERF(_OneInpSurrogate): - """Judge spiking state with an erf function. + """Judge spiking state with an erf function. - See Also:: + See Also:: - erf - """ + erf + """ - def __init__(self, alpha=1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x) + return dx * as_jax(dz) - def surrogate_fun(self, x): - x = as_jax(x) - return sci.special.erf(-self.alpha * x) * 0.5 + def surrogate_fun(self, x): + x = as_jax(x) + return sci.special.erf(-self.alpha * x) * 0.5 - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def erf( @@ -651,100 +651,100 @@ def erf( alpha: float = 1., origin: bool = False ): - r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. + r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. 
math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split} - g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ - &= \frac{1}{2} \text{erfc}(-\alpha x) \\ - &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt - \end{split} + \begin{split} + g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ + &= \frac{1}{2} \text{erfc}(-\alpha x) \\ + &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt + \end{split} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2} + g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. 
+ alpha: float + Parameter to control smoothness of gradient + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8. + .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8. - """ - return ERF(alpha=alpha, forward_use_surrogate=origin)(x) + """ + return ERF(alpha=alpha, forward_use_surrogate=origin)(x) class PiecewiseLeakyRelu(_OneInpSurrogate): - """Judge spiking state with a piecewise leaky relu function. + """Judge spiking state with a piecewise leaky relu function. 
- See Also:: + See Also:: - piecewise_leaky_relu - """ + piecewise_leaky_relu + """ - def __init__(self, c=0.01, w=1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.c = c - self.w = w + def __init__(self, c=0.01, w=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.c = c + self.w = w - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < -self.w, - self.c * x + self.c * self.w, - jnp.where(x > self.w, - self.c * x - self.c * self.w + 1, - 0.5 * x / self.w + 0.5)) - return z + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < -self.w, + self.c * x + self.c * self.w, + jnp.where(x > self.w, + self.c * x - self.c * self.w + 1, + 0.5 * x / self.w + 0.5)) + return z - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w) + return dx * as_jax(dz) - def __repr__(self): - return f'{self.__class__.__name__}(c={self.c}, w={self.w})' + def __repr__(self): + return f'{self.__class__.__name__}(c={self.c}, w={self.w})' def piecewise_leaky_relu( @@ -753,118 +753,118 @@ def piecewise_leaky_relu( w: float = 1., origin: bool = False ): - r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. + r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. 
math:: - \begin{split}g(x) = - \begin{cases} - cx + cw, & x < -w \\ - \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ - cx - cw + 1, & x > w \\ - \end{cases}\end{split} - - Backward function: - - .. math:: - - \begin{split}g'(x) = - \begin{cases} - \frac{1}{w}, & |x| \leq w \\ - c, & |x| > w - \end{cases}\end{split} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for c in [0.01, 0.05, 0.1]: - >>> for w in [1., 2.]: - >>> grads1 = bm.vector_grad(bm.surrogate.piecewise_leaky_relu)(xs, c=c, w=w) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'x={c}, w={w}') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - c: float - When :math:`|x| > w` the gradient is `c`. - w: float - When :math:`|x| <= w` the gradient is `1 / w`. - origin: bool - Whether to compute the original function as the feedfoward output. - - Returns:: - - out: jax.Array - The spiking state. + \begin{split}g(x) = + \begin{cases} + cx + cw, & x < -w \\ + \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ + cx - cw + 1, & x > w \\ + \end{cases}\end{split} - References:: + Backward function: - .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450. - .. [4] Wu Y, Deng L, Li G, et al. 
Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318. - .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372. - .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58. - .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525. - .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424. + .. math:: - """ - return PiecewiseLeakyRelu(c=c, w=w)(x) + \begin{split}g'(x) = + \begin{cases} + \frac{1}{w}, & |x| \leq w \\ + c, & |x| > w + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for c in [0.01, 0.05, 0.1]: + >>> for w in [1., 2.]: + >>> grads1 = bm.vector_grad(bm.surrogate.piecewise_leaky_relu)(xs, c=c, w=w) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'x={c}, w={w}') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + c: float + When :math:`|x| > w` the gradient is `c`. + w: float + When :math:`|x| <= w` the gradient is `1 / w`. + origin: bool + Whether to compute the original function as the feedfoward output. + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. 
Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450. + .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318. + .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372. + .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58. + .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525. + .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424. + + """ + return PiecewiseLeakyRelu(c=c, w=w)(x) class SquarewaveFourierSeries(_OneInpSurrogate): - """Judge spiking state with a squarewave fourier series. + """Judge spiking state with a squarewave fourier series. 
- See Also:: + See Also:: - squarewave_fourier_series - """ + squarewave_fourier_series + """ - def __init__(self, n=2, t_period=8., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.n = n - self.t_period = t_period + def __init__(self, n=2, t_period=8., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.n = n + self.t_period = t_period - def surrogate_grad(self, dz, x): - x = as_jax(x) - w = jnp.pi * 2. / self.t_period - dx = jnp.cos(w * x) - for i in range(2, self.n): - dx += jnp.cos((2 * i - 1.) * w * x) - dx *= 4. / self.t_period - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + w = jnp.pi * 2. / self.t_period + dx = jnp.cos(w * x) + for i in range(2, self.n): + dx += jnp.cos((2 * i - 1.) * w * x) + dx *= 4. / self.t_period + return dx * as_jax(dz) - def surrogate_fun(self, x): - x = as_jax(x) - w = jnp.pi * 2. / self.t_period - ret = jnp.sin(w * x) - for i in range(2, self.n): - c = (2 * i - 1.) - ret += jnp.sin(c * w * x) / c - z = 0.5 + 2. / jnp.pi * ret - return z + def surrogate_fun(self, x): + x = as_jax(x) + w = jnp.pi * 2. / self.t_period + ret = jnp.sin(w * x) + for i in range(2, self.n): + c = (2 * i - 1.) + ret += jnp.sin(c * w * x) / c + z = 0.5 + 2. / jnp.pi * ret + return z - def __repr__(self): - return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' + def __repr__(self): + return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' def squarewave_fourier_series( @@ -873,92 +873,92 @@ def squarewave_fourier_series( t_period: float = 8., origin: bool = False ): - r"""Judge spiking state with a squarewave fourier series. + r"""Judge spiking state with a squarewave fourier series. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. 
math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 } + g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 } - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T} + g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for n in [2, 4, 8]: - >>> f = bm.surrogate.SquarewaveFourierSeries(n=n) - >>> grads1 = bm.vector_grad(f)(xs) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'n={n}') - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for n in [2, 4, 8]: + >>> f = bm.surrogate.SquarewaveFourierSeries(n=n) + >>> grads1 = bm.vector_grad(f)(xs) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'n={n}') + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - n: int - t_period: float - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + n: int + t_period: float + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. 
- """ + """ - return SquarewaveFourierSeries(n=n, t_period=t_period, forward_use_surrogate=origin)(x) + return SquarewaveFourierSeries(n=n, t_period=t_period, forward_use_surrogate=origin)(x) class S2NN(_OneInpSurrogate): - """Judge spiking state with the S2NN surrogate spiking function. + """Judge spiking state with the S2NN surrogate spiking function. - See Also:: + See Also:: - s2nn - """ + s2nn + """ - def __init__(self, alpha=4., beta=1., epsilon=1e-8, forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha - self.beta = beta - self.epsilon = epsilon + def __init__(self, alpha=4., beta=1., epsilon=1e-8, forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha + self.beta = beta + self.epsilon = epsilon - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < 0., - sci.special.expit(x * self.alpha), - self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5) - return z + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < 0., + sci.special.expit(x * self.alpha), + self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5) + return z - def surrogate_grad(self, dz, x): - x = as_jax(x) - sg = sci.special.expit(self.alpha * x) - dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.)) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + sg = sci.special.expit(self.alpha * x) + dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.)) + return dx * as_jax(dz) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' def s2nn( @@ -968,102 +968,102 @@ def s2nn( epsilon: float = 1e-8, origin: bool = False ): - r"""Judge spiking state with the S2NN surrogate spiking function [1]_. 
+ r"""Judge spiking state with the S2NN surrogate spiking function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split}g(x) = \begin{cases} - \mathrm{sigmoid} (\alpha x), x < 0 \\ - \beta \ln(|x + 1|) + 0.5, x \ge 0 - \end{cases}\end{split} + \begin{split}g(x) = \begin{cases} + \mathrm{sigmoid} (\alpha x), x < 0 \\ + \beta \ln(|x + 1|) + 0.5, x \ge 0 + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. math:: - \begin{split}g'(x) = \begin{cases} - \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\ - \frac{\beta}{(x + 1)}, x \ge 0 - \end{cases}\end{split} + \begin{split}g'(x) = \begin{cases} + \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\ + \frac{\beta}{(x + 1)}, x \ge 0 + \end{cases}\end{split} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 4., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=4, \beta=1$') - >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 8., 2.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=8, \beta=2$') - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 4., 1.) 
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=4, \beta=1$') + >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 8., 2.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=8, \beta=2$') + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - The param that controls the gradient when ``x < 0``. - beta: float - The param that controls the gradient when ``x >= 0`` - epsilon: float - Avoid nan - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + alpha: float + The param that controls the gradient when ``x < 0``. + beta: float + The param that controls the gradient when ``x >= 0`` + epsilon: float + Avoid nan + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag. + .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag. - """ - return S2NN(alpha=alpha, beta=beta, epsilon=epsilon, forward_use_surrogate=origin)(x) + """ + return S2NN(alpha=alpha, beta=beta, epsilon=epsilon, forward_use_surrogate=origin)(x) class QPseudoSpike(_OneInpSurrogate): - """Judge spiking state with the q-PseudoSpike surrogate function. + """Judge spiking state with the q-PseudoSpike surrogate function. 
- See Also:: + See Also:: - q_pseudo_spike - """ + q_pseudo_spike + """ - def __init__(self, alpha=2., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha=2., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) + return dx * as_jax(dz) - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < 0., - 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), - 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) - return z + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < 0., + 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), + 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) + return z - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def q_pseudo_spike( @@ -1071,92 +1071,92 @@ def q_pseudo_spike( alpha: float = 2., origin: bool = False ): - r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. + r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. 
math:: - \begin{split}g(x) = - \begin{cases} - \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ - 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. - \end{cases}\end{split} + \begin{split}g(x) = + \begin{cases} + \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ + 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha} + g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.q_pseudo_spike)(xs, alpha) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.q_pseudo_spike)(xs, alpha) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - The parameter to control tail fatness of gradient. - origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + alpha: float + The parameter to control tail fatness of gradient. + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. 
[1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag. - """ - return QPseudoSpike(alpha=alpha, forward_use_surrogate=origin)(x) + .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag. + """ + return QPseudoSpike(alpha=alpha, forward_use_surrogate=origin)(x) class LeakyRelu(_OneInpSurrogate): - """Judge spiking state with the Leaky ReLU function. + """Judge spiking state with the Leaky ReLU function. - See Also:: + See Also:: - leaky_relu - """ + leaky_relu + """ - def __init__(self, alpha=0.1, beta=1., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha - self.beta = beta + def __init__(self, alpha=0.1, beta=1., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha + self.beta = beta - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0., self.alpha * x, self.beta * x) + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0., self.alpha * x, self.beta * x) - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = jnp.where(x < 0., self.alpha, self.beta) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = jnp.where(x < 0., self.alpha, self.beta) + return dx * as_jax(dz) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' def leaky_relu( @@ -1165,101 +1165,101 @@ def leaky_relu( beta: float = 1., origin: bool = False ): - r"""Judge spiking state with the Leaky ReLU function. - - If `origin=False`, computes the forward function: + r"""Judge spiking state with the Leaky ReLU function. - .. 
math:: + If `origin=False`, computes the forward function: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: + .. math:: - .. math:: - - \begin{split}g(x) = - \begin{cases} - \beta \cdot x, & x \geq 0 \\ - \alpha \cdot x, & x < 0 \\ - \end{cases}\end{split} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - Backward function: - - .. math:: - - \begin{split}g'(x) = - \begin{cases} - \beta, & x \geq 0 \\ - \alpha, & x < 0 \\ - \end{cases}\end{split} + If `origin=True`, computes the original function: - .. plot:: - :include-source: True + .. math:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') - >>> plt.legend() - >>> plt.show() - - Parameters:: + \begin{split}g(x) = + \begin{cases} + \beta \cdot x, & x \geq 0 \\ + \alpha \cdot x, & x < 0 \\ + \end{cases}\end{split} - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient when :math:`x < 0`. - beta: float - The parameter to control the gradient when :math:`x >= 0`. - origin: bool - Whether to compute the original function as the feedfoward output. + Backward function: - Returns:: + .. math:: - out: jax.Array - The spiking state. - """ - return LeakyRelu(alpha=alpha, beta=beta, forward_use_surrogate=origin)(x) + \begin{split}g'(x) = + \begin{cases} + \beta, & x \geq 0 \\ + \alpha, & x < 0 \\ + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) 
+ >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient when :math:`x < 0`. + beta: float + The parameter to control the gradient when :math:`x >= 0`. + origin: bool + Whether to compute the original function as the feedfoward output. + + Returns:: + + out: jax.Array + The spiking state. + """ + return LeakyRelu(alpha=alpha, beta=beta, forward_use_surrogate=origin)(x) class LogTailedRelu(_OneInpSurrogate): - """Judge spiking state with the Log-tailed ReLU function. + """Judge spiking state with the Log-tailed ReLU function. - See Also:: + See Also:: - log_tailed_relu - """ + log_tailed_relu + """ - def __init__(self, alpha=0., forward_use_surrogate=False): - super().__init__(forward_use_surrogate=forward_use_surrogate) - self.alpha = alpha + def __init__(self, alpha=0., forward_use_surrogate=False): + super().__init__(forward_use_surrogate=forward_use_surrogate) + self.alpha = alpha - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x > 1, - jnp.log(x), - jnp.where(x > 0, - x, - self.alpha * x)) - return z + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x > 1, + jnp.log(x), + jnp.where(x > 0, + x, + self.alpha * x)) + return z - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = jnp.where(x > 1, - 1 / x, - jnp.where(x > 0, - 1., - self.alpha)) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = jnp.where(x > 1, + 1 / x, + jnp.where(x > 0, + 1., + self.alpha)) + return dx * as_jax(dz) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def log_tailed_relu( @@ -1267,93 +1267,93 @@ def log_tailed_relu( alpha: float = 0., origin: bool = False ): - r"""Judge spiking state with the Log-tailed ReLU function [1]_. 
+ r"""Judge spiking state with the Log-tailed ReLU function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split}g(x) = - \begin{cases} - \alpha x, & x \leq 0 \\ - x, & 0 < x \leq 0 \\ - log(x), x > 1 \\ - \end{cases}\end{split} + \begin{split}g(x) = + \begin{cases} + \alpha x, & x \leq 0 \\ + x, & 0 < x \leq 0 \\ + log(x), x > 1 \\ + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. math:: - \begin{split}g'(x) = - \begin{cases} - \alpha, & x \leq 0 \\ - 1, & 0 < x \leq 0 \\ - \frac{1}{x}, x > 1 \\ - \end{cases}\end{split} + \begin{split}g'(x) = + \begin{cases} + \alpha, & x \leq 0 \\ + 1, & 0 < x \leq 0 \\ + \frac{1}{x}, x > 1 \\ + \end{cases}\end{split} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient. 
- origin: bool - Whether to compute the original function as the feedfoward output. + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient. + origin: bool + Whether to compute the original function as the feedfoward output. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414. - """ - return LogTailedRelu(alpha=alpha, forward_use_surrogate=origin)(x) + .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414. + """ + return LogTailedRelu(alpha=alpha, forward_use_surrogate=origin)(x) class ReluGrad(_OneInpSurrogate): - """Judge spiking state with the ReLU gradient function. + """Judge spiking state with the ReLU gradient function. 
- See Also:: + See Also:: - relu_grad - """ + relu_grad + """ - def __init__(self, alpha=0.3, width=1.): - super().__init__() - self.alpha = alpha - self.width = width + def __init__(self, alpha=0.3, width=1.): + super().__init__() + self.alpha = alpha + self.width = width - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0) - return dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0) + return dx * as_jax(dz) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' def relu_grad( @@ -1361,79 +1361,79 @@ def relu_grad( alpha: float = 0.3, width: float = 1., ): - r"""Spike function with the ReLU gradient function [1]_. + r"""Spike function with the ReLU gradient function [1]_. - The forward function: + The forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|)) + g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|)) - .. plot:: - :include-source: True + .. 
plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for s in [0.5, 1.]: - >>> for w in [1, 2.]: - >>> grads = bm.vector_grad(bm.surrogate.relu_grad)(xs, s, w) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + f'{s}, width={w}') - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for s in [0.5, 1.]: + >>> for w in [1, 2.]: + >>> grads = bm.vector_grad(bm.surrogate.relu_grad)(xs, s, w) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + f'{s}, width={w}') + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient. - width: float - The parameter to control the width of the gradient. + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient. + width: float + The parameter to control the width of the gradient. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019). - """ - return ReluGrad(alpha=alpha, width=width)(x) + .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019). + """ + return ReluGrad(alpha=alpha, width=width)(x) class GaussianGrad(_OneInpSurrogate): - """Judge spiking state with the Gaussian gradient function. + """Judge spiking state with the Gaussian gradient function. 
- See Also:: + See Also:: - gaussian_grad - """ + gaussian_grad + """ - def __init__(self, sigma=0.5, alpha=0.5): - super().__init__() - self.sigma = sigma - self.alpha = alpha + def __init__(self, sigma=0.5, alpha=0.5): + super().__init__() + self.sigma = sigma + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) - return self.alpha * dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + return self.alpha * dx * as_jax(dz) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' def gaussian_grad( @@ -1441,85 +1441,85 @@ def gaussian_grad( sigma: float = 0.5, alpha: float = 0.5, ): - r"""Spike function with the Gaussian gradient function [1]_. + r"""Spike function with the Gaussian gradient function [1]_. - The forward function: + The forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \alpha * \text{gaussian}(x, 0., \sigma) + g'(x) = \alpha * \text{gaussian}(x, 0., \sigma) - .. plot:: - :include-source: True + .. 
plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for s in [0.5, 1., 2.]: - >>> grads = bm.vector_grad(bm.surrogate.gaussian_grad)(xs, s, 0.5) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.5, \sigma=$' + str(s)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for s in [0.5, 1., 2.]: + >>> grads = bm.vector_grad(bm.surrogate.gaussian_grad)(xs, s, 0.5) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.5, \sigma=$' + str(s)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - sigma: float - The parameter to control the variance of gaussian distribution. - alpha: float - The parameter to control the scale of the gradient. + x: jax.Array, Array + The input data. + sigma: float + The parameter to control the variance of gaussian distribution. + alpha: float + The parameter to control the scale of the gradient. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). - """ - return GaussianGrad(sigma=sigma, alpha=alpha)(x) + .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). + """ + return GaussianGrad(sigma=sigma, alpha=alpha)(x) class MultiGaussianGrad(_OneInpSurrogate): - """Judge spiking state with the multi-Gaussian gradient function. 
+ """Judge spiking state with the multi-Gaussian gradient function. - See Also:: + See Also:: - multi_gaussian_grad - """ + multi_gaussian_grad + """ - def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): - super().__init__() - self.h = h - self.s = s - self.sigma = sigma - self.scale = scale + def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): + super().__init__() + self.h = h + self.s = s + self.sigma = sigma + self.scale = scale - def surrogate_grad(self, dz, x): - x = as_jax(x) - g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) - g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) - ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) - g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) - ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) - dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h - return self.scale * dx * as_jax(dz) + def surrogate_grad(self, dz, x): + x = as_jax(x) + g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h + return self.scale * dx * as_jax(dz) - def __repr__(self): - return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' + def __repr__(self): + return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' def multi_gaussian_grad( @@ -1529,207 +1529,207 @@ def multi_gaussian_grad( sigma: float = 0.5, scale: float = 0.5, ): - r"""Spike function with the multi-Gaussian gradient function [1]_. + r"""Spike function with the multi-Gaussian gradient function [1]_. 
- The forward function: + The forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - \begin{array}{l} - g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2}) - -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})- - h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2}) - \end{array} + \begin{array}{l} + g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2}) + -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})- + h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2}) + \end{array} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.multi_gaussian_grad)(xs) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads)) - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.multi_gaussian_grad)(xs) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads)) + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - h: float - The hyper-parameters of approximate function - s: float - The hyper-parameters of approximate function - sigma: float - The gaussian sigma. - scale: float - The gradient scale. + x: jax.Array, Array + The input data. + h: float + The hyper-parameters of approximate function + s: float + The hyper-parameters of approximate function + sigma: float + The gaussian sigma. + scale: float + The gradient scale. - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. [1] Yin, B., Corradi, F. 
& Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). - """ - return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x) + .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). + """ + return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x) class InvSquareGrad(_OneInpSurrogate): - """Judge spiking state with the inverse-square surrogate gradient function. + """Judge spiking state with the inverse-square surrogate gradient function. - See Also:: + See Also:: - inv_square_grad - """ + inv_square_grad + """ - def __init__(self, alpha=100.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha=100.): + super().__init__() + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = as_jax(dz) / (self.alpha * jnp.abs(x) + 1.0) ** 2 - return dx + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = as_jax(dz) / (self.alpha * jnp.abs(x) + 1.0) ** 2 + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def inv_square_grad( x: Union[jax.Array, Array], alpha: float = 100. ): - r"""Spike function with the inverse-square surrogate gradient. + r"""Spike function with the inverse-square surrogate gradient. - Forward function: + Forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2} + g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2} - .. plot:: - :include-source: True + .. 
plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-1, 1, 1000) - >>> for alpha in [1., 10., 100.]: - >>> grads = bm.vector_grad(bm.surrogate.inv_square_grad)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-1, 1, 1000) + >>> for alpha in [1., 10., 100.]: + >>> grads = bm.vector_grad(bm.surrogate.inv_square_grad)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Returns:: + Returns:: - out: jax.Array - The spiking state. - """ - return InvSquareGrad(alpha=alpha)(x) + out: jax.Array + The spiking state. + """ + return InvSquareGrad(alpha=alpha)(x) class SlayerGrad(_OneInpSurrogate): - """Judge spiking state with the slayer surrogate gradient function. + """Judge spiking state with the slayer surrogate gradient function. - See Also:: + See Also:: - slayer_grad - """ + slayer_grad + """ - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha - def surrogate_grad(self, dz, x): - x = as_jax(x) - dx = as_jax(dz) * jnp.exp(-self.alpha * jnp.abs(x)) - return dx + def surrogate_grad(self, dz, x): + x = as_jax(x) + dx = as_jax(dz) * jnp.exp(-self.alpha * jnp.abs(x)) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def slayer_grad( x: Union[jax.Array, Array], alpha: float = 1. 
): - r"""Spike function with the slayer surrogate gradient function. + r"""Spike function with the slayer surrogate gradient function. - Forward function: + Forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \exp(-\alpha |x|) + g'(x) = \exp(-\alpha |x|) - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.slayer_grad)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.slayer_grad)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: + Parameters:: - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Returns:: + Returns:: - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: + References:: - .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018). - """ - return SlayerGrad(alpha=alpha)(x) + .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. 
In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018). + """ + return SlayerGrad(alpha=alpha)(x) diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py index a51339ace..c213a6ab6 100644 --- a/brainpy/_src/math/surrogate/_one_input_new.py +++ b/brainpy/_src/math/surrogate/_one_input_new.py @@ -7,74 +7,74 @@ import jax.scipy as sci if jax.__version__ >= '0.5.0': - from jax.extend.core import Primitive + from jax.extend.core import Primitive else: - from jax.core import Primitive + from jax.core import Primitive from jax.interpreters import batching, ad, mlir from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import BaseArray as Array __all__ = [ - 'Surrogate', - 'Sigmoid', - 'sigmoid', - 'PiecewiseQuadratic', - 'piecewise_quadratic', - 'PiecewiseExp', - 'piecewise_exp', - 'SoftSign', - 'soft_sign', - 'Arctan', - 'arctan', - 'NonzeroSignLog', - 'nonzero_sign_log', - 'ERF', - 'erf', - 'PiecewiseLeakyRelu', - 'piecewise_leaky_relu', - 'SquarewaveFourierSeries', - 'squarewave_fourier_series', - 'S2NN', - 's2nn', - 'QPseudoSpike', - 'q_pseudo_spike', - 'LeakyRelu', - 'leaky_relu', - 'LogTailedRelu', - 'log_tailed_relu', - 'ReluGrad', - 'relu_grad', - 'GaussianGrad', - 'gaussian_grad', - 'InvSquareGrad', - 'inv_square_grad', - 'MultiGaussianGrad', - 'multi_gaussian_grad', - 'SlayerGrad', - 'slayer_grad', + 'Surrogate', + 'Sigmoid', + 'sigmoid', + 'PiecewiseQuadratic', + 'piecewise_quadratic', + 'PiecewiseExp', + 'piecewise_exp', + 'SoftSign', + 'soft_sign', + 'Arctan', + 'arctan', + 'NonzeroSignLog', + 'nonzero_sign_log', + 'ERF', + 'erf', + 'PiecewiseLeakyRelu', + 'piecewise_leaky_relu', + 'SquarewaveFourierSeries', + 'squarewave_fourier_series', + 'S2NN', + 's2nn', + 'QPseudoSpike', + 'q_pseudo_spike', + 'LeakyRelu', + 'leaky_relu', + 'LogTailedRelu', + 'log_tailed_relu', + 'ReluGrad', + 'relu_grad', + 'GaussianGrad', + 'gaussian_grad', + 
'InvSquareGrad', + 'inv_square_grad', + 'MultiGaussianGrad', + 'multi_gaussian_grad', + 'SlayerGrad', + 'slayer_grad', ] def _heaviside_abstract(x, dx): - return [x] + return [x] def _heaviside_imp(x, dx): - z = jnp.asarray(x >= 0, dtype=x.dtype) - return [z] + z = jnp.asarray(x >= 0, dtype=x.dtype) + return [z] def _heaviside_batching(args, axes): - return heaviside_p.bind(*args), [axes[0]] + return heaviside_p.bind(*args), [axes[0]] def _heaviside_jvp(primals, tangents): - x, dx = primals - tx, tdx = tangents - primal_outs = heaviside_p.bind(x, dx) - tangent_outs = [dx * tx, ] - return primal_outs, tangent_outs + x, dx = primals + tx, tdx = tangents + primal_outs = heaviside_p.bind(x, dx) + tangent_outs = [dx * tx, ] + return primal_outs, tangent_outs heaviside_p = Primitive('heaviside_p') @@ -87,265 +87,265 @@ def _heaviside_jvp(primals, tangents): def _is_bp_array(x): - return isinstance(x, Array) + return isinstance(x, Array) def _as_jax(x): - return x.value if _is_bp_array(x) else x + return x.value if _is_bp_array(x) else x class Surrogate(object): - """The base surrograte gradient function. + """The base surrograte gradient function. - To customize a surrogate gradient function, you can inherit this class and - implement the `surrogate_fun` and `surrogate_grad` methods. + To customize a surrogate gradient function, you can inherit this class and + implement the `surrogate_fun` and `surrogate_grad` methods. - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import jax.numpy as jnp + Examples:: - >>> class MySurrogate(bm.Surrogate): - ... def __init__(self, alpha=1.): - ... super().__init__() - ... self.alpha = alpha - ... - ... def surrogate_fun(self, x): - ... return jnp.sin(x) * self.alpha - ... - ... def surrogate_grad(self, x): - ... return jnp.cos(x) * self.alpha + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import jax.numpy as jnp - """ + >>> class MySurrogate(bm.Surrogate): + ... 
def __init__(self, alpha=1.): + ... super().__init__() + ... self.alpha = alpha + ... + ... def surrogate_fun(self, x): + ... return jnp.sin(x) * self.alpha + ... + ... def surrogate_grad(self, x): + ... return jnp.cos(x) * self.alpha - def __call__(self, x): - x = _as_jax(x) - dx = self.surrogate_grad(x) - return heaviside_p.bind(x, dx)[0] + """ - def __repr__(self): - return f'{self.__class__.__name__}()' + def __call__(self, x): + x = _as_jax(x) + dx = self.surrogate_grad(x) + return heaviside_p.bind(x, dx)[0] - def surrogate_fun(self, x) -> jax.Array: - """The surrogate function.""" - raise NotImplementedError + def __repr__(self): + return f'{self.__class__.__name__}()' - def surrogate_grad(self, x) -> jax.Array: - """The gradient function of the surrogate function.""" - raise NotImplementedError + def surrogate_fun(self, x) -> jax.Array: + """The surrogate function.""" + raise NotImplementedError + + def surrogate_grad(self, x) -> jax.Array: + """The gradient function of the surrogate function.""" + raise NotImplementedError class Sigmoid(Surrogate): - """Spike function with the sigmoid-shaped surrogate gradient. + """Spike function with the sigmoid-shaped surrogate gradient. + + See Also:: - See Also:: - - sigmoid + sigmoid - """ + """ - def __init__(self, alpha: float = 4.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha: float = 4.): + super().__init__() + self.alpha = alpha - def surrogate_fun(self, x): - return sci.special.expit(self.alpha * x) + def surrogate_fun(self, x): + return sci.special.expit(self.alpha * x) - def surrogate_grad(self, x): - sgax = sci.special.expit(x * self.alpha) - dx = (1. - sgax) * sgax * self.alpha - return dx + def surrogate_grad(self, x): + sgax = sci.special.expit(x * self.alpha) + dx = (1. 
- sgax) * sgax * self.alpha + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def sigmoid( x: Union[jax.Array, Array], alpha: float = 4., ): - r"""Spike function with the sigmoid-shaped surrogate gradient. + r"""Spike function with the sigmoid-shaped surrogate gradient. - If `origin=False`, return the forward function: + If `origin=False`, return the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}} + g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x) + g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x) - .. plot:: - :include-source: True + .. 
plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-2, 2, 1000) - >>> for alpha in [1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.sigmoid)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-2, 2, 1000) + >>> for alpha in [1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.sigmoid)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + Parameters:: + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Returns:: - - out: jax.Array - The spiking state. - """ - return Sigmoid(alpha=alpha)(x) + + Returns:: + + out: jax.Array + The spiking state. + """ + return Sigmoid(alpha=alpha)(x) class PiecewiseQuadratic(Surrogate): - """Judge spiking state with a piecewise quadratic function. + """Judge spiking state with a piecewise quadratic function. 
+ + See Also:: - See Also:: - - piecewise_quadratic + piecewise_quadratic - """ + """ - def __init__(self, alpha: float = 1.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha: float = 1.): + super().__init__() + self.alpha = alpha - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < -1 / self.alpha, - 0., - jnp.where(x > 1 / self.alpha, - 1., - (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5)) - return z + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < -1 / self.alpha, + 0., + jnp.where(x > 1 / self.alpha, + 1., + (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5)) + return z - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha)) - return dx + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha)) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def piecewise_quadratic( x: Union[jax.Array, Array], alpha: float = 1., ): - r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. + r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. 
math:: - g(x) = - \begin{cases} - 0, & x < -\frac{1}{\alpha} \\ - -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ - 1, & x > \frac{1}{\alpha} \\ - \end{cases} + g(x) = + \begin{cases} + 0, & x < -\frac{1}{\alpha} \\ + -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ + 1, & x > \frac{1}{\alpha} \\ + \end{cases} - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = - \begin{cases} - 0, & |x| > \frac{1}{\alpha} \\ - -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} - \end{cases} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.piecewise_quadratic)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805. - .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. 
IEEE Signal Processing Magazine, 2019, 36(6): 51-63. - .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14. - """ - return PiecewiseQuadratic(alpha=alpha)(x) + g'(x) = + \begin{cases} + 0, & |x| > \frac{1}{\alpha} \\ + -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} + \end{cases} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.piecewise_quadratic)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805. + .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. + .. [5] Panda P, Aketi S A, Roy K. 
Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14. + """ + return PiecewiseQuadratic(alpha=alpha)(x) class PiecewiseExp(Surrogate): - """Judge spiking state with a piecewise exponential function. + """Judge spiking state with a piecewise exponential function. + + See Also:: - See Also:: - - piecewise_exp - """ + piecewise_exp + """ - def __init__(self, alpha: float = 1.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha: float = 1.): + super().__init__() + self.alpha = alpha - def surrogate_grad(self, x): - x = as_jax(x) - dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x)) - return dx + def surrogate_grad(self, x): + x = as_jax(x) + dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x)) + return dx - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2) + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2) - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def piecewise_exp( @@ -353,89 +353,89 @@ def piecewise_exp( alpha: float = 1., ): - r"""Judge spiking state with a piecewise exponential function [1]_. + r"""Judge spiking state with a piecewise exponential function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. 
math:: - If `origin=True`, computes the original function: + g(x) = \begin{cases} + \frac{1}{2}e^{\alpha x}, & x < 0 \\ + 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 + \end{cases} - .. math:: + Backward function: - g(x) = \begin{cases} - \frac{1}{2}e^{\alpha x}, & x < 0 \\ - 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 - \end{cases} + .. math:: - Backward function: + g'(x) = \frac{\alpha}{2}e^{-\alpha |x|} - .. math:: + .. plot:: + :include-source: True - g'(x) = \frac{\alpha}{2}e^{-\alpha |x|} + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.piecewise_exp)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - .. plot:: - :include-source: True + Parameters:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.piecewise_exp)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + Returns:: - Returns:: - - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: - - .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. - """ - return PiecewiseExp(alpha=alpha)(x) + References:: + + .. [1] Neftci E O, Mostafa H, Zenke F. 
Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. + """ + return PiecewiseExp(alpha=alpha)(x) class SoftSign(Surrogate): - """Judge spiking state with a soft sign function. + """Judge spiking state with a soft sign function. + + See Also:: - See Also:: - - soft_sign - """ + soft_sign + """ - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha - def surrogate_grad(self, x): - x = as_jax(x) - dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2 - return dx + def surrogate_grad(self, x): + x = as_jax(x) + dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2 + return dx - def surrogate_fun(self, x): - x = as_jax(x) - return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5 + def surrogate_fun(self, x): + x = as_jax(x) + return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5 - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def soft_sign( @@ -443,84 +443,84 @@ def soft_sign( alpha: float = 1., ): - r"""Judge spiking state with a soft sign function. + r"""Judge spiking state with a soft sign function. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. 
math:: - g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) - = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1) + g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) + = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1) - Backward function: + Backward function: - .. math:: + .. math:: - g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}} + g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}} - .. plot:: - :include-source: True + .. plot:: + :include-source: True - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.soft_sign)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.soft_sign)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + Parameters:: + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Returns:: - - out: jax.Array - The spiking state. - """ - return SoftSign(alpha=alpha)(x) + Returns:: + + out: jax.Array + The spiking state. + + """ + return SoftSign(alpha=alpha)(x) class Arctan(Surrogate): - """Judge spiking state with an arctan function. + """Judge spiking state with an arctan function. 
+ + See Also:: - See Also:: - - arctan - """ + arctan + """ - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha - def surrogate_grad(self, x): - x = as_jax(x) - dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2) - return dx + def surrogate_grad(self, x): + x = as_jax(x) + dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2) + return dx - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def arctan( @@ -528,83 +528,83 @@ def arctan( alpha: float = 1., ): - r"""Judge spiking state with an arctan function. + r"""Judge spiking state with an arctan function. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: - If `origin=True`, computes the original function: + g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2} - .. math:: + Backward function: - g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2} + .. math:: - Backward function: + g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)} - .. math:: + .. 
plot:: + :include-source: True - g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)} + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.arctan)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - .. plot:: - :include-source: True + Parameters:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.arctan)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + Returns:: - Returns:: - - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - """ - return Arctan(alpha=alpha)(x) + """ + return Arctan(alpha=alpha)(x) class NonzeroSignLog(Surrogate): - """Judge spiking state with a nonzero sign log function. + """Judge spiking state with a nonzero sign log function. - See Also:: - - nonzero_sign_log - """ + See Also:: - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha + nonzero_sign_log + """ - def surrogate_grad(self, x): - x = as_jax(x) - dx = 1. / (1 / self.alpha + jnp.abs(x)) - return dx + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1) + def surrogate_grad(self, x): + x = as_jax(x) + dx = 1. 
/ (1 / self.alpha + jnp.abs(x)) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1) + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def nonzero_sign_log( @@ -612,96 +612,96 @@ def nonzero_sign_log( alpha: float = 1., ): - r"""Judge spiking state with a nonzero sign log function. + r"""Judge spiking state with a nonzero sign log function. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1) + g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1) - where + where - .. math:: + .. math:: - \begin{split}\mathrm{NonzeroSign}(x) = - \begin{cases} - 1, & x \geq 0 \\ - -1, & x < 0 \\ - \end{cases}\end{split} + \begin{split}\mathrm{NonzeroSign}(x) = + \begin{cases} + 1, & x \geq 0 \\ + -1, & x < 0 \\ + \end{cases}\end{split} + + Backward function: + + .. math:: - Backward function: + g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|} - .. math:: + This surrogate function has the advantage of low computation cost during the backward. - g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|} - This surrogate function has the advantage of low computation cost during the backward. + .. 
plot:: + :include-source: True + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - .. plot:: - :include-source: True + Parameters:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + Returns:: - Returns:: - - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - """ - return NonzeroSignLog(alpha=alpha)(x) + """ + return NonzeroSignLog(alpha=alpha)(x) class ERF(Surrogate): - """Judge spiking state with an erf function. + """Judge spiking state with an erf function. 
- See Also:: - - erf - """ + See Also:: - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha + erf + """ - def surrogate_grad(self, x): - x = as_jax(x) - dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x) - return dx + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha - def surrogate_fun(self, x): - x = as_jax(x) - return sci.special.erf(-self.alpha * x) * 0.5 + def surrogate_grad(self, x): + x = as_jax(x) + dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def surrogate_fun(self, x): + x = as_jax(x) + return sci.special.erf(-self.alpha * x) * 0.5 + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def erf( @@ -709,99 +709,99 @@ def erf( alpha: float = 1., ): - r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. + r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + \begin{split} + g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ + &= \frac{1}{2} \text{erfc}(-\alpha x) \\ + &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt + \end{split} - If `origin=True`, computes the original function: + Backward function: - .. math:: + .. math:: - \begin{split} - g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ - &= \frac{1}{2} \text{erfc}(-\alpha x) \\ - &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt - \end{split} + g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2} - Backward function: + .. plot:: + :include-source: True - .. 
math:: + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2} + Parameters:: - .. plot:: - :include-source: True + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + Returns:: + out: jax.Array + The spiking state. - Returns:: - - out: jax.Array - The spiking state. + References:: - References:: - - .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8. + .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125. + .. 
[2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8. - """ - return ERF(alpha=alpha)(x) + """ + return ERF(alpha=alpha)(x) class PiecewiseLeakyRelu(Surrogate): - """Judge spiking state with a piecewise leaky relu function. + """Judge spiking state with a piecewise leaky relu function. - See Also:: - - piecewise_leaky_relu - """ + See Also:: - def __init__(self, c=0.01, w=1.): - super().__init__() - self.c = c - self.w = w + piecewise_leaky_relu + """ - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < -self.w, - self.c * x + self.c * self.w, - jnp.where(x > self.w, - self.c * x - self.c * self.w + 1, - 0.5 * x / self.w + 0.5)) - return z + def __init__(self, c=0.01, w=1.): + super().__init__() + self.c = c + self.w = w - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w) - return dx + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < -self.w, + self.c * x + self.c * self.w, + jnp.where(x > self.w, + self.c * x - self.c * self.w + 1, + 0.5 * x / self.w + 0.5)) + return z - def __repr__(self): - return f'{self.__class__.__name__}(c={self.c}, w={self.w})' + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(c={self.c}, w={self.w})' def piecewise_leaky_relu( @@ -810,117 +810,117 @@ def piecewise_leaky_relu( w: float = 1., ): - r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. + r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. 
- If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split}g(x) = - \begin{cases} - cx + cw, & x < -w \\ - \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ - cx - cw + 1, & x > w \\ - \end{cases}\end{split} + \begin{split}g(x) = + \begin{cases} + cx + cw, & x < -w \\ + \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ + cx - cw + 1, & x > w \\ + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. math:: + + \begin{split}g'(x) = + \begin{cases} + \frac{1}{w}, & |x| \leq w \\ + c, & |x| > w + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for c in [0.01, 0.05, 0.1]: + >>> for w in [1., 2.]: + >>> grads1 = bm.vector_grad(bm.surrogate.piecewise_leaky_relu)(xs, c=c, w=w) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'x={c}, w={w}') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + c: float + When :math:`|x| > w` the gradient is `c`. + w: float + When :math:`|x| <= w` the gradient is `1 / w`. - \begin{split}g'(x) = - \begin{cases} - \frac{1}{w}, & |x| \leq w \\ - c, & |x| > w - \end{cases}\end{split} - .. 
plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for c in [0.01, 0.05, 0.1]: - >>> for w in [1., 2.]: - >>> grads1 = bm.vector_grad(bm.surrogate.piecewise_leaky_relu)(xs, c=c, w=w) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'x={c}, w={w}') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - c: float - When :math:`|x| > w` the gradient is `c`. - w: float - When :math:`|x| <= w` the gradient is `1 / w`. - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450. - .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318. - .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372. - .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58. - .. [7] Cheng X, Hao Y, Xu J, et al. 
LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525. - .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424. - - """ - return PiecewiseLeakyRelu(c=c, w=w)(x) + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450. + .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318. + .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372. + .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58. + .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525. + .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424. 
+ + """ + return PiecewiseLeakyRelu(c=c, w=w)(x) class SquarewaveFourierSeries(Surrogate): - """Judge spiking state with a squarewave fourier series. - - See Also:: - - squarewave_fourier_series - """ - - def __init__(self, n=2, t_period=8.): - super().__init__() - self.n = n - self.t_period = t_period - - def surrogate_grad(self, x): - x = as_jax(x) - w = jnp.pi * 2. / self.t_period - dx = jnp.cos(w * x) - for i in range(2, self.n): - dx += jnp.cos((2 * i - 1.) * w * x) - dx *= 4. / self.t_period - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - w = jnp.pi * 2. / self.t_period - ret = jnp.sin(w * x) - for i in range(2, self.n): - c = (2 * i - 1.) - ret += jnp.sin(c * w * x) / c - z = 0.5 + 2. / jnp.pi * ret - return z - - def __repr__(self): - return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' + """Judge spiking state with a squarewave fourier series. + + See Also:: + + squarewave_fourier_series + """ + + def __init__(self, n=2, t_period=8.): + super().__init__() + self.n = n + self.t_period = t_period + + def surrogate_grad(self, x): + x = as_jax(x) + w = jnp.pi * 2. / self.t_period + dx = jnp.cos(w * x) + for i in range(2, self.n): + dx += jnp.cos((2 * i - 1.) * w * x) + dx *= 4. / self.t_period + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + w = jnp.pi * 2. / self.t_period + ret = jnp.sin(w * x) + for i in range(2, self.n): + c = (2 * i - 1.) + ret += jnp.sin(c * w * x) / c + z = 0.5 + 2. / jnp.pi * ret + return z + + def __repr__(self): + return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' def squarewave_fourier_series( @@ -929,91 +929,91 @@ def squarewave_fourier_series( t_period: float = 8., ): - r"""Judge spiking state with a squarewave fourier series. + r"""Judge spiking state with a squarewave fourier series. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. 
math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: - If `origin=True`, computes the original function: + g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 } - .. math:: + Backward function: - g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 } + .. math:: - Backward function: + g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T} - .. math:: + .. plot:: + :include-source: True - g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T} + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for n in [2, 4, 8]: + >>> f = bm.surrogate.SquarewaveFourierSeries(n=n) + >>> grads1 = bm.vector_grad(f)(xs) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'n={n}') + >>> plt.legend() + >>> plt.show() - .. plot:: - :include-source: True + Parameters:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for n in [2, 4, 8]: - >>> f = bm.surrogate.SquarewaveFourierSeries(n=n) - >>> grads1 = bm.vector_grad(f)(xs) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'n={n}') - >>> plt.legend() - >>> plt.show() + x: jax.Array, Array + The input data. + n: int + t_period: float - Parameters:: - - x: jax.Array, Array - The input data. - n: int - t_period: float + Returns:: - Returns:: - - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. 
- """ + """ - return SquarewaveFourierSeries(n=n, t_period=t_period)(x) + return SquarewaveFourierSeries(n=n, t_period=t_period)(x) class S2NN(Surrogate): - """Judge spiking state with the S2NN surrogate spiking function. + """Judge spiking state with the S2NN surrogate spiking function. - See Also:: - - s2nn - """ + See Also:: - def __init__(self, alpha=4., beta=1., epsilon=1e-8): - super().__init__() - self.alpha = alpha - self.beta = beta - self.epsilon = epsilon + s2nn + """ - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < 0., - sci.special.expit(x * self.alpha), - self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5) - return z + def __init__(self, alpha=4., beta=1., epsilon=1e-8): + super().__init__() + self.alpha = alpha + self.beta = beta + self.epsilon = epsilon - def surrogate_grad(self, x): - x = as_jax(x) - sg = sci.special.expit(self.alpha * x) - dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.)) - return dx + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < 0., + sci.special.expit(x * self.alpha), + self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5) + return z - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' + def surrogate_grad(self, x): + x = as_jax(x) + sg = sci.special.expit(self.alpha * x) + dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.)) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' def s2nn( @@ -1023,101 +1023,101 @@ def s2nn( epsilon: float = 1e-8, ): - r"""Judge spiking state with the S2NN surrogate spiking function [1]_. + r"""Judge spiking state with the S2NN surrogate spiking function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. 
math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split}g(x) = \begin{cases} - \mathrm{sigmoid} (\alpha x), x < 0 \\ - \beta \ln(|x + 1|) + 0.5, x \ge 0 - \end{cases}\end{split} + \begin{split}g(x) = \begin{cases} + \mathrm{sigmoid} (\alpha x), x < 0 \\ + \beta \ln(|x + 1|) + 0.5, x \ge 0 + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. math:: - \begin{split}g'(x) = \begin{cases} - \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\ - \frac{\beta}{(x + 1)}, x \ge 0 - \end{cases}\end{split} + \begin{split}g'(x) = \begin{cases} + \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\ + \frac{\beta}{(x + 1)}, x \ge 0 + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 4., 1.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=4, \beta=1$') + >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 8., 2.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=8, \beta=2$') + >>> plt.legend() + >>> plt.show() - .. plot:: - :include-source: True + Parameters:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 4., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=4, \beta=1$') - >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 8., 2.) 
- >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=8, \beta=2$') - >>> plt.legend() - >>> plt.show() + x: jax.Array, Array + The input data. + alpha: float + The param that controls the gradient when ``x < 0``. + beta: float + The param that controls the gradient when ``x >= 0`` + epsilon: float + Avoid nan - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The param that controls the gradient when ``x < 0``. - beta: float - The param that controls the gradient when ``x >= 0`` - epsilon: float - Avoid nan + Returns:: - Returns:: - - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: - - .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag. + References:: - """ - return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x) + .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag. + + """ + return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x) class QPseudoSpike(Surrogate): - """Judge spiking state with the q-PseudoSpike surrogate function. + """Judge spiking state with the q-PseudoSpike surrogate function. + + See Also:: - See Also:: - - q_pseudo_spike - """ + q_pseudo_spike + """ - def __init__(self, alpha=2.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha=2.): + super().__init__() + self.alpha = alpha - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) - return dx + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) + return dx - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < 0., - 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), - 1. 
- 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) - return z + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < 0., + 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), + 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) + return z - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def q_pseudo_spike( @@ -1125,91 +1125,91 @@ def q_pseudo_spike( alpha: float = 2., ): - r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. + r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split}g(x) = - \begin{cases} - \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ - 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. - \end{cases}\end{split} + \begin{split}g(x) = + \begin{cases} + \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ + 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. + \end{cases}\end{split} + + Backward function: + + .. math:: - Backward function: + g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha} - .. math:: + .. 
plot:: + :include-source: True - g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha} + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.q_pseudo_spike)(xs, alpha) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + str(alpha)) + >>> plt.legend() + >>> plt.show() - .. plot:: - :include-source: True + Parameters:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.q_pseudo_spike)(xs, alpha) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + str(alpha)) - >>> plt.legend() - >>> plt.show() + x: jax.Array, Array + The input data. + alpha: float + The parameter to control tail fatness of gradient. - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control tail fatness of gradient. + Returns:: - Returns:: - - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: - - .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag. - """ - return QPseudoSpike(alpha=alpha)(x) + References:: + + .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag. + """ + return QPseudoSpike(alpha=alpha)(x) class LeakyRelu(Surrogate): - """Judge spiking state with the Leaky ReLU function. + """Judge spiking state with the Leaky ReLU function. 
+ + See Also:: - See Also:: - - leaky_relu - """ + leaky_relu + """ - def __init__(self, alpha=0.1, beta=1.): - super().__init__() - self.alpha = alpha - self.beta = beta + def __init__(self, alpha=0.1, beta=1.): + super().__init__() + self.alpha = alpha + self.beta = beta - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0., self.alpha * x, self.beta * x) + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0., self.alpha * x, self.beta * x) - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(x < 0., self.alpha, self.beta) - return dx + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(x < 0., self.alpha, self.beta) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' def leaky_relu( @@ -1218,100 +1218,100 @@ def leaky_relu( beta: float = 1., ): - r"""Judge spiking state with the Leaky ReLU function. + r"""Judge spiking state with the Leaky ReLU function. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split}g(x) = - \begin{cases} - \beta \cdot x, & x \geq 0 \\ - \alpha \cdot x, & x < 0 \\ - \end{cases}\end{split} + \begin{split}g(x) = + \begin{cases} + \beta \cdot x, & x \geq 0 \\ + \alpha \cdot x, & x < 0 \\ + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. 
math:: - \begin{split}g'(x) = - \begin{cases} - \beta, & x \geq 0 \\ - \alpha, & x < 0 \\ - \end{cases}\end{split} + \begin{split}g'(x) = + \begin{cases} + \beta, & x \geq 0 \\ + \alpha, & x < 0 \\ + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient when :math:`x < 0`. + beta: float + The parameter to control the gradient when :math:`x >= 0`. - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient when :math:`x < 0`. - beta: float - The parameter to control the gradient when :math:`x >= 0`. - - - Returns:: - - out: jax.Array - The spiking state. - """ - return LeakyRelu(alpha=alpha, beta=beta)(x) + + Returns:: + + out: jax.Array + The spiking state. + """ + return LeakyRelu(alpha=alpha, beta=beta)(x) class LogTailedRelu(Surrogate): - """Judge spiking state with the Log-tailed ReLU function. 
- - See Also:: - - log_tailed_relu - """ - - def __init__(self, alpha=0.): - super().__init__() - self.alpha = alpha - - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x > 1, - jnp.log(x), - jnp.where(x > 0, - x, - self.alpha * x)) - return z - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(x > 1, - 1 / x, - jnp.where(x > 0, - 1., - self.alpha)) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + """Judge spiking state with the Log-tailed ReLU function. + + See Also:: + + log_tailed_relu + """ + + def __init__(self, alpha=0.): + super().__init__() + self.alpha = alpha + + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x > 1, + jnp.log(x), + jnp.where(x > 0, + x, + self.alpha * x)) + return z + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(x > 1, + 1 / x, + jnp.where(x > 0, + 1., + self.alpha)) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def log_tailed_relu( @@ -1319,92 +1319,92 @@ def log_tailed_relu( alpha: float = 0., ): - r"""Judge spiking state with the Log-tailed ReLU function [1]_. + r"""Judge spiking state with the Log-tailed ReLU function [1]_. - If `origin=False`, computes the forward function: + If `origin=False`, computes the forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} - If `origin=True`, computes the original function: + If `origin=True`, computes the original function: - .. math:: + .. math:: - \begin{split}g(x) = - \begin{cases} - \alpha x, & x \leq 0 \\ - x, & 0 < x \leq 0 \\ - log(x), x > 1 \\ - \end{cases}\end{split} + \begin{split}g(x) = + \begin{cases} + \alpha x, & x \leq 0 \\ + x, & 0 < x \leq 0 \\ + log(x), x > 1 \\ + \end{cases}\end{split} - Backward function: + Backward function: - .. math:: + .. 
math:: - \begin{split}g'(x) = - \begin{cases} - \alpha, & x \leq 0 \\ - 1, & 0 < x \leq 0 \\ - \frac{1}{x}, x > 1 \\ - \end{cases}\end{split} + \begin{split}g'(x) = + \begin{cases} + \alpha, & x \leq 0 \\ + 1, & 0 < x \leq 0 \\ + \frac{1}{x}, x > 1 \\ + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') + >>> plt.legend() + >>> plt.show() - .. plot:: - :include-source: True + Parameters:: - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') - >>> plt.legend() - >>> plt.show() + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient. - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient. + Returns:: - Returns:: - - out: jax.Array - The spiking state. + out: jax.Array + The spiking state. - References:: - - .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414. - """ - return LogTailedRelu(alpha=alpha)(x) + References:: + + .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414. + """ + return LogTailedRelu(alpha=alpha)(x) class ReluGrad(Surrogate): - """Judge spiking state with the ReLU gradient function. 
+ """Judge spiking state with the ReLU gradient function. + + See Also:: - See Also:: - - relu_grad - """ + relu_grad + """ - def __init__(self, alpha=0.3, width=1.): - super().__init__() - self.alpha = alpha - self.width = width + def __init__(self, alpha=0.3, width=1.): + super().__init__() + self.alpha = alpha + self.width = width - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0) - return dx + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' def relu_grad( @@ -1412,79 +1412,79 @@ def relu_grad( alpha: float = 0.3, width: float = 1., ): - r"""Spike function with the ReLU gradient function [1]_. + r"""Spike function with the ReLU gradient function [1]_. - The forward function: + The forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|)) - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for s in [0.5, 1.]: - >>> for w in [1, 2.]: - >>> grads = bm.vector_grad(bm.surrogate.relu_grad)(xs, s, w) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + f'{s}, width={w}') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient. - width: float - The parameter to control the width of the gradient. - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Neftci, E. O., Mostafa, H. 
& Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019). - """ - return ReluGrad(alpha=alpha, width=width)(x) + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|)) + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for s in [0.5, 1.]: + >>> for w in [1, 2.]: + >>> grads = bm.vector_grad(bm.surrogate.relu_grad)(xs, s, w) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + f'{s}, width={w}') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient. + width: float + The parameter to control the width of the gradient. + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019). + """ + return ReluGrad(alpha=alpha, width=width)(x) class GaussianGrad(Surrogate): - """Judge spiking state with the Gaussian gradient function. + """Judge spiking state with the Gaussian gradient function. 
- See Also:: - - gaussian_grad - """ + See Also:: - def __init__(self, sigma=0.5, alpha=0.5): - super().__init__() - self.sigma = sigma - self.alpha = alpha + gaussian_grad + """ - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) - return self.alpha * dx + def __init__(self, sigma=0.5, alpha=0.5): + super().__init__() + self.sigma = sigma + self.alpha = alpha - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + return self.alpha * dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' def gaussian_grad( @@ -1492,85 +1492,85 @@ def gaussian_grad( sigma: float = 0.5, alpha: float = 0.5, ): - r"""Spike function with the Gaussian gradient function [1]_. + r"""Spike function with the Gaussian gradient function [1]_. - The forward function: + The forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \alpha * \text{gaussian}(x, 0., \sigma) - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for s in [0.5, 1., 2.]: - >>> grads = bm.vector_grad(bm.surrogate.gaussian_grad)(xs, s, 0.5) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.5, \sigma=$' + str(s)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - sigma: float - The parameter to control the variance of gaussian distribution. - alpha: float - The parameter to control the scale of the gradient. - - Returns:: - - out: jax.Array - The spiking state. 
- - References:: - - .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). - """ - return GaussianGrad(sigma=sigma, alpha=alpha)(x) + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \alpha * \text{gaussian}(x, 0., \sigma) + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for s in [0.5, 1., 2.]: + >>> grads = bm.vector_grad(bm.surrogate.gaussian_grad)(xs, s, 0.5) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.5, \sigma=$' + str(s)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + sigma: float + The parameter to control the variance of gaussian distribution. + alpha: float + The parameter to control the scale of the gradient. + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). + """ + return GaussianGrad(sigma=sigma, alpha=alpha)(x) class MultiGaussianGrad(Surrogate): - """Judge spiking state with the multi-Gaussian gradient function. 
- - See Also:: - - multi_gaussian_grad - """ - - def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): - super().__init__() - self.h = h - self.s = s - self.sigma = sigma - self.scale = scale - - def surrogate_grad(self, x): - x = as_jax(x) - g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) - g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) - ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) - g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) - ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) - dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h - return self.scale * dx - - def __repr__(self): - return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' + """Judge spiking state with the multi-Gaussian gradient function. + + See Also:: + + multi_gaussian_grad + """ + + def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): + super().__init__() + self.h = h + self.s = s + self.sigma = sigma + self.scale = scale + + def surrogate_grad(self, x): + x = as_jax(x) + g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h + return self.scale * dx + + def __repr__(self): + return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' def multi_gaussian_grad( @@ -1580,205 +1580,205 @@ def multi_gaussian_grad( sigma: float = 0.5, scale: float = 0.5, ): - r"""Spike function with the multi-Gaussian gradient function [1]_. + r"""Spike function with the multi-Gaussian gradient function [1]_. - The forward function: + The forward function: - .. 
math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - \begin{array}{l} - g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2}) - -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})- - h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2}) - \end{array} - - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.multi_gaussian_grad)(xs) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads)) - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - h: float - The hyper-parameters of approximate function - s: float - The hyper-parameters of approximate function - sigma: float - The gaussian sigma. - scale: float - The gradient scale. - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). - """ - return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x) + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + \begin{array}{l} + g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2}) + -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})- + h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2}) + \end{array} + + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.multi_gaussian_grad)(xs) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads)) + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. 
+ h: float + The hyper-parameters of approximate function + s: float + The hyper-parameters of approximate function + sigma: float + The gaussian sigma. + scale: float + The gradient scale. + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). + """ + return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x) class InvSquareGrad(Surrogate): - """Judge spiking state with the inverse-square surrogate gradient function. + """Judge spiking state with the inverse-square surrogate gradient function. - See Also:: - - inv_square_grad - """ + See Also:: - def __init__(self, alpha=100.): - super().__init__() - self.alpha = alpha + inv_square_grad + """ - def surrogate_grad(self, x): - dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2 - return dx + def __init__(self, alpha=100.): + super().__init__() + self.alpha = alpha - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def surrogate_grad(self, x): + dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2 + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def inv_square_grad( x: Union[jax.Array, Array], alpha: float = 100. ): - r"""Spike function with the inverse-square surrogate gradient. + r"""Spike function with the inverse-square surrogate gradient. - Forward function: + Forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: - Backward function: + .. math:: - .. math:: + g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2} - g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2} + .. plot:: + :include-source: True - .. 
plot:: - :include-source: True + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-1, 1, 1000) + >>> for alpha in [1., 10., 100.]: + >>> grads = bm.vector_grad(bm.surrogate.inv_square_grad)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-1, 1, 1000) - >>> for alpha in [1., 10., 100.]: - >>> grads = bm.vector_grad(bm.surrogate.inv_square_grad)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() + Parameters:: - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient - Returns:: - - out: jax.Array - The spiking state. - """ - return InvSquareGrad(alpha=alpha)(x) + Returns:: + + out: jax.Array + The spiking state. + """ + return InvSquareGrad(alpha=alpha)(x) class SlayerGrad(Surrogate): - """Judge spiking state with the slayer surrogate gradient function. + """Judge spiking state with the slayer surrogate gradient function. + + See Also:: - See Also:: - - slayer_grad - """ + slayer_grad + """ - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha - def surrogate_grad(self, x): - dx = jnp.exp(-self.alpha * jnp.abs(x)) - return dx + def surrogate_grad(self, x): + dx = jnp.exp(-self.alpha * jnp.abs(x)) + return dx - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' def slayer_grad( x: Union[jax.Array, Array], alpha: float = 1. ): - r"""Spike function with the slayer surrogate gradient function. 
+ r"""Spike function with the slayer surrogate gradient function. - Forward function: + Forward function: - .. math:: + .. math:: - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \exp(-\alpha |x|) - - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.slayer_grad)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018). - """ - return SlayerGrad(alpha=alpha)(x) + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \exp(-\alpha |x|) + + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.slayer_grad)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 
31, 1412–1421 (NeurIPS, 2018). + """ + return SlayerGrad(alpha=alpha)(x) diff --git a/brainpy/_src/math/surrogate/_two_inputs.py b/brainpy/_src/math/surrogate/_two_inputs.py index 04b8a12d4..8bb68aa64 100644 --- a/brainpy/_src/math/surrogate/_two_inputs.py +++ b/brainpy/_src/math/surrogate/_two_inputs.py @@ -11,8 +11,8 @@ from ._utils import vjp_custom __all__ = [ - 'inv_square_grad2', - 'relu_grad2', + 'inv_square_grad2', + 'relu_grad2', ] @@ -22,17 +22,17 @@ def inv_square_grad2( x_old: Union[jax.Array, Array], alpha: float ): - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=x_new.dtype) + x_new_comp = x_new >= 0 + x_old_comp = x_old < 0 + z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=x_new.dtype) - def grad(dz): - dz = as_jax(dz) - dx_new = (dz / (alpha * jnp.abs(x_new) + 1.0) ** 2) * jnp.asarray(x_old_comp, dtype=x_old.dtype) - dx_old = -(dz / (alpha * jnp.abs(x_old) + 1.0) ** 2) * jnp.asarray(x_new_comp, dtype=x_new.dtype) - return dx_new, dx_old, None + def grad(dz): + dz = as_jax(dz) + dx_new = (dz / (alpha * jnp.abs(x_new) + 1.0) ** 2) * jnp.asarray(x_old_comp, dtype=x_old.dtype) + dx_old = -(dz / (alpha * jnp.abs(x_old) + 1.0) ** 2) * jnp.asarray(x_new_comp, dtype=x_new.dtype) + return dx_new, dx_old, None - return z, grad + return z, grad @vjp_custom(['x_new', 'x_old'], dict(alpha=.3, width=1.)) @@ -42,17 +42,14 @@ def relu_grad2( alpha: float, width: float, ): - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=x_new.dtype) - - def grad(dz): - dz = as_jax(dz) - dx_new = (dz * jnp.maximum(width - jnp.abs(x_new), 0) * alpha) * jnp.asarray(x_old_comp, dtype=x_old.dtype) - dx_old = -(dz * jnp.maximum(width - jnp.abs(x_old), 0) * alpha) * jnp.asarray(x_new_comp, dtype=x_new.dtype) - return dx_new, dx_old, None, None - - return z, grad - + x_new_comp = x_new >= 0 + x_old_comp = x_old < 0 + z = 
jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=x_new.dtype) + def grad(dz): + dz = as_jax(dz) + dx_new = (dz * jnp.maximum(width - jnp.abs(x_new), 0) * alpha) * jnp.asarray(x_old_comp, dtype=x_old.dtype) + dx_old = -(dz * jnp.maximum(width - jnp.abs(x_old), 0) * alpha) * jnp.asarray(x_new_comp, dtype=x_new.dtype) + return dx_new, dx_old, None, None + return z, grad diff --git a/brainpy/_src/math/surrogate/_utils.py b/brainpy/_src/math/surrogate/_utils.py index 70bb3ea82..6b9321295 100644 --- a/brainpy/_src/math/surrogate/_utils.py +++ b/brainpy/_src/math/surrogate/_utils.py @@ -8,128 +8,126 @@ import jax from brainpy import check -from brainpy._src.math.ndarray import BaseArray as Array +from brainpy._src.math.ndarray import BaseArray as Array from brainpy.errors import UnsupportedError __all__ = [ - 'get_default', - 'make_return', - 'vjp_custom', + 'get_default', + 'make_return', + 'vjp_custom', ] def get_default(x, default): - if x is None: - return default, False - else: - return x, True + if x is None: + return default, False + else: + return x, True def make_return(r, *args): - if isinstance(r, (tuple, list)): - r = tuple(r) - else: - r = [r] - for a in args: - if a: - r += [None] - return tuple(r) + if isinstance(r, (tuple, list)): + r = tuple(r) + else: + r = [r] + for a in args: + if a: + r += [None] + return tuple(r) def _get_args(f): - reduced_args = [] - for name, par in inspect.signature(f).parameters.items(): - if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: - reduced_args.append(par.name) - - elif par.kind is inspect.Parameter.VAR_POSITIONAL: - reduced_args.append(f'*{par.name}') - - elif par.kind is inspect.Parameter.KEYWORD_ONLY: - raise UnsupportedError() - elif par.kind is inspect.Parameter.POSITIONAL_ONLY: - raise UnsupportedError() - elif par.kind is inspect.Parameter.VAR_KEYWORD: # TODO - raise UnsupportedError() - else: - raise UnsupportedError() - return reduced_args + reduced_args = [] + for name, par in 
inspect.signature(f).parameters.items(): + if par.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + reduced_args.append(par.name) + elif par.kind is inspect.Parameter.VAR_POSITIONAL: + reduced_args.append(f'*{par.name}') -class VJPCustom(object): - def __init__(self, - func: Callable, - args: Sequence[str], - defaults: Dict = None, - statics: Dict = None, ): - if statics is None: statics = dict() - if defaults is None: defaults = dict() - assert isinstance(statics, dict) - assert isinstance(defaults, dict) - assert callable(func) - check.is_sequence(args, elem_type=str) - - self.n_args = len(args) - self.func = func - self.args = args - self.defaults = tuple(defaults.items()) - self.statics = statics - self.all_args = _get_args(func) - - for k in statics: - if k not in defaults: - raise KeyError(f'{k} defined as "static_args" should provide its default value in "defaults"') - self._cached = {} - if len(statics): - static_vals = list(statics.values()) - products = list(itertools.product(*static_vals, repeat=1)) - for args in products: - string = self._str_static_arg(dict(zip(self.statics.keys(), args))) - self._cached[string] = jax.custom_gradient(partial(self.func, - **dict(zip(self.statics.keys(), args)))) - else: - self._cached[''] = jax.custom_gradient(self.func) - - def _str_static_arg(self, args: Dict): - r = [] - for k in self.statics: - r.append(f'{k}={args[k]}') - return '-'.join(r) - - def __call__(self, *args, **kwargs): - args = list(args) - kwargs = dict(kwargs) - for k in self.args[len(args):]: - if k not in kwargs: - raise ValueError(f'Must provide {k} for function {self.func}') - args.append(kwargs.pop(k)) - for k, v in self.defaults[len(args) - self.n_args:]: - if k not in kwargs: - args.append(v) - else: - args.append(kwargs.pop(k)) - if len(kwargs): - raise KeyError(f'Unknown arguments {kwargs} for function {self.func}') - dynamics = [] - statics = dict() - for k, v in zip(self.all_args, args): - if isinstance(v, Array): v = v.value - if k in 
self.statics: - statics[k] = v - else: - dynamics.append(v) - return self._cached[self._str_static_arg(statics)](*dynamics) - - -def vjp_custom(args: Sequence[str], defaults: Dict, statics: Dict=None): - """Generalize a customized gradient function as a general Python function. - """ - - def wrapper(fun): - obj = VJPCustom(fun, args, defaults, statics) - obj.__doc__ = fun.__doc__ - return obj - - return wrapper + elif par.kind is inspect.Parameter.KEYWORD_ONLY: + raise UnsupportedError() + elif par.kind is inspect.Parameter.POSITIONAL_ONLY: + raise UnsupportedError() + elif par.kind is inspect.Parameter.VAR_KEYWORD: # TODO + raise UnsupportedError() + else: + raise UnsupportedError() + return reduced_args +class VJPCustom(object): + def __init__(self, + func: Callable, + args: Sequence[str], + defaults: Dict = None, + statics: Dict = None, ): + if statics is None: statics = dict() + if defaults is None: defaults = dict() + assert isinstance(statics, dict) + assert isinstance(defaults, dict) + assert callable(func) + check.is_sequence(args, elem_type=str) + + self.n_args = len(args) + self.func = func + self.args = args + self.defaults = tuple(defaults.items()) + self.statics = statics + self.all_args = _get_args(func) + + for k in statics: + if k not in defaults: + raise KeyError(f'{k} defined as "static_args" should provide its default value in "defaults"') + self._cached = {} + if len(statics): + static_vals = list(statics.values()) + products = list(itertools.product(*static_vals, repeat=1)) + for args in products: + string = self._str_static_arg(dict(zip(self.statics.keys(), args))) + self._cached[string] = jax.custom_gradient(partial(self.func, + **dict(zip(self.statics.keys(), args)))) + else: + self._cached[''] = jax.custom_gradient(self.func) + + def _str_static_arg(self, args: Dict): + r = [] + for k in self.statics: + r.append(f'{k}={args[k]}') + return '-'.join(r) + + def __call__(self, *args, **kwargs): + args = list(args) + kwargs = dict(kwargs) + for 
k in self.args[len(args):]: + if k not in kwargs: + raise ValueError(f'Must provide {k} for function {self.func}') + args.append(kwargs.pop(k)) + for k, v in self.defaults[len(args) - self.n_args:]: + if k not in kwargs: + args.append(v) + else: + args.append(kwargs.pop(k)) + if len(kwargs): + raise KeyError(f'Unknown arguments {kwargs} for function {self.func}') + dynamics = [] + statics = dict() + for k, v in zip(self.all_args, args): + if isinstance(v, Array): v = v.value + if k in self.statics: + statics[k] = v + else: + dynamics.append(v) + return self._cached[self._str_static_arg(statics)](*dynamics) + + +def vjp_custom(args: Sequence[str], defaults: Dict, statics: Dict = None): + """Generalize a customized gradient function as a general Python function. + """ + + def wrapper(fun): + obj = VJPCustom(fun, args, defaults, statics) + obj.__doc__ = fun.__doc__ + return obj + + return wrapper diff --git a/brainpy/_src/math/surrogate/tests/test_one_input.py b/brainpy/_src/math/surrogate/tests/test_one_input.py index cced63cf8..4c2e297e3 100644 --- a/brainpy/_src/math/surrogate/tests/test_one_input.py +++ b/brainpy/_src/math/surrogate/tests/test_one_input.py @@ -8,44 +8,44 @@ class TestOneInputGrad(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(TestOneInputGrad, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=f'{name}_x64={x64}', - func=getattr(one_input, name), - x64=x64) - for name in one_input.__all__ - for x64 in [True, False] - ) - def test_bm_grad(self, func, x64): - if x64: - bm.enable_x64() - - xs = bm.arange(-3, 3, 0.005) - grads = bm.vector_grad(func)(xs) - self.assertTrue(grads.size == xs.size) - - if x64: - bm.disable_x64() - - @parameterized.named_parameters( - dict(testcase_name=f'{name}_x64={x64}', - func=getattr(one_input, name), - x64=x64, ) - for name in one_input.__all__ - for x64 in [True, False] - ) - def test_jax_vjp(self, 
func, x64): - if x64: - bm.enable_x64() - - xs = bm.arange(-3, 3, 0.005) - primals, f_vjp = jax.vjp(func, xs) - grad2 = f_vjp(jax.numpy.ones_like(xs)) - self.assertTrue(grad2[0].size == xs.size) - - if x64: - bm.disable_x64() + def __init__(self, *args, platform='cpu', **kwargs): + super(TestOneInputGrad, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.named_parameters( + dict(testcase_name=f'{name}_x64={x64}', + func=getattr(one_input, name), + x64=x64) + for name in one_input.__all__ + for x64 in [True, False] + ) + def test_bm_grad(self, func, x64): + if x64: + bm.enable_x64() + + xs = bm.arange(-3, 3, 0.005) + grads = bm.vector_grad(func)(xs) + self.assertTrue(grads.size == xs.size) + + if x64: + bm.disable_x64() + + @parameterized.named_parameters( + dict(testcase_name=f'{name}_x64={x64}', + func=getattr(one_input, name), + x64=x64, ) + for name in one_input.__all__ + for x64 in [True, False] + ) + def test_jax_vjp(self, func, x64): + if x64: + bm.enable_x64() + + xs = bm.arange(-3, 3, 0.005) + primals, f_vjp = jax.vjp(func, xs) + grad2 = f_vjp(jax.numpy.ones_like(xs)) + self.assertTrue(grad2[0].size == xs.size) + + if x64: + bm.disable_x64() diff --git a/brainpy/_src/math/surrogate/tests/test_two_inputs.py b/brainpy/_src/math/surrogate/tests/test_two_inputs.py index d944ccd18..d08e8da18 100644 --- a/brainpy/_src/math/surrogate/tests/test_two_inputs.py +++ b/brainpy/_src/math/surrogate/tests/test_two_inputs.py @@ -8,44 +8,44 @@ class TestTwoInputsGrad(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(TestTwoInputsGrad, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=f'{name}_x64={x64}', - func=getattr(two_inputs, name), - x64=x64) - for name in two_inputs.__all__ - for x64 in [True, False] - ) - def test_bm_grad(self, func, x64): - if x64: - bm.enable_x64() - - xs = bm.arange(-3, 3, 0.005) - grads = 
bm.vector_grad(func)(xs[:-1], xs[1:]) - self.assertTrue(grads.size == xs.size-1) - - if x64: - bm.disable_x64() - - @parameterized.named_parameters( - dict(testcase_name=f'{name}_x64={x64}', - func=getattr(two_inputs, name), - x64=x64, ) - for name in two_inputs.__all__ - for x64 in [True, False] - ) - def test_jax_vjp(self, func, x64): - if x64: - bm.enable_x64() - - xs = bm.arange(-3, 3, 0.005) - primals, f_vjp = jax.vjp(func, xs[:-1], xs[1:]) - grad2 = f_vjp(jax.numpy.ones(xs.size-1)) - self.assertTrue(grad2[0].size == xs.size - 1) - - if x64: - bm.disable_x64() + def __init__(self, *args, platform='cpu', **kwargs): + super(TestTwoInputsGrad, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.named_parameters( + dict(testcase_name=f'{name}_x64={x64}', + func=getattr(two_inputs, name), + x64=x64) + for name in two_inputs.__all__ + for x64 in [True, False] + ) + def test_bm_grad(self, func, x64): + if x64: + bm.enable_x64() + + xs = bm.arange(-3, 3, 0.005) + grads = bm.vector_grad(func)(xs[:-1], xs[1:]) + self.assertTrue(grads.size == xs.size - 1) + + if x64: + bm.disable_x64() + + @parameterized.named_parameters( + dict(testcase_name=f'{name}_x64={x64}', + func=getattr(two_inputs, name), + x64=x64, ) + for name in two_inputs.__all__ + for x64 in [True, False] + ) + def test_jax_vjp(self, func, x64): + if x64: + bm.enable_x64() + + xs = bm.arange(-3, 3, 0.005) + primals, f_vjp = jax.vjp(func, xs[:-1], xs[1:]) + grad2 = f_vjp(jax.numpy.ones(xs.size - 1)) + self.assertTrue(grad2[0].size == xs.size - 1) + + if x64: + bm.disable_x64() diff --git a/brainpy/_src/math/tests/test_array_format.py b/brainpy/_src/math/tests/test_array_format.py index 7d87e528d..e8f621ec5 100644 --- a/brainpy/_src/math/tests/test_array_format.py +++ b/brainpy/_src/math/tests/test_array_format.py @@ -1,14 +1,11 @@ - - import brainpy.math as bm def test_format(): - print(bm.ones((5))) - print(bm.Variable(bm.ones((5)))) - 
print(bm.VariableView(bm.Variable(bm.ones((5))), bm.asarray([1, 2, 3]))) - - print(bm.ones((3, 4))) - print(bm.Variable(bm.ones((3, 4)))) - print(bm.VariableView(bm.Variable(bm.ones((3, 4))), bm.asarray([1, 2]))) + print(bm.ones((5))) + print(bm.Variable(bm.ones((5)))) + print(bm.VariableView(bm.Variable(bm.ones((5))), bm.asarray([1, 2, 3]))) + print(bm.ones((3, 4))) + print(bm.Variable(bm.ones((3, 4)))) + print(bm.VariableView(bm.Variable(bm.ones((3, 4))), bm.asarray([1, 2]))) diff --git a/brainpy/_src/math/tests/test_compat_pytorch.py b/brainpy/_src/math/tests/test_compat_pytorch.py index ce3dd161b..ec4fc58d1 100644 --- a/brainpy/_src/math/tests/test_compat_pytorch.py +++ b/brainpy/_src/math/tests/test_compat_pytorch.py @@ -1,78 +1,76 @@ # -*- coding: utf-8 -*- - - -import jax.numpy as jnp import unittest + +import brainpy._src.math.compat_pytorch as torch import brainpy.math as bm from brainpy._src.math import compat_pytorch -import brainpy._src.math.compat_pytorch as torch - -from absl .testing import parameterized class TestFlatten(unittest.TestCase): - def test1(self): - rng = bm.random.default_rng(113) - arr = rng.rand(3, 4, 5) - a2 = compat_pytorch.flatten(arr, 1, 2) - self.assertTrue(a2.shape == (3, 20)) - a2 = compat_pytorch.flatten(arr, 0, 1) - self.assertTrue(a2.shape == (12, 5)) + def test1(self): + rng = bm.random.default_rng(113) + arr = rng.rand(3, 4, 5) + a2 = compat_pytorch.flatten(arr, 1, 2) + self.assertTrue(a2.shape == (3, 20)) + a2 = compat_pytorch.flatten(arr, 0, 1) + self.assertTrue(a2.shape == (12, 5)) - def test2(self): - rng = bm.random.default_rng(234) - arr = rng.rand() - self.assertTrue(arr.ndim == 0) - arr = compat_pytorch.flatten(arr) - self.assertTrue(arr.ndim == 1) + def test2(self): + rng = bm.random.default_rng(234) + arr = rng.rand() + self.assertTrue(arr.ndim == 0) + arr = compat_pytorch.flatten(arr) + self.assertTrue(arr.ndim == 1) class TestUnsqueeze(unittest.TestCase): - def test1(self): - rng = bm.random.default_rng(999) - 
arr = rng.rand(3, 4, 5) - a = compat_pytorch.unsqueeze(arr, 0) - self.assertTrue(a.shape == (1, 3, 4, 5)) - a = compat_pytorch.unsqueeze(arr, -3) - self.assertTrue(a.shape == (3, 1, 4, 5)) + def test1(self): + rng = bm.random.default_rng(999) + arr = rng.rand(3, 4, 5) + a = compat_pytorch.unsqueeze(arr, 0) + self.assertTrue(a.shape == (1, 3, 4, 5)) + a = compat_pytorch.unsqueeze(arr, -3) + self.assertTrue(a.shape == (3, 1, 4, 5)) + class TestExpand(unittest.TestCase): - def test1(self): - rng = bm.random.default_rng(121) - arr = rng.rand(1, 4, 5) - a = compat_pytorch.Tensor(arr) - a = a.expand(1, 6, 4, -1) - self.assertTrue(a.shape == (1, 6, 4, 5)) + def test1(self): + rng = bm.random.default_rng(121) + arr = rng.rand(1, 4, 5) + a = compat_pytorch.Tensor(arr) + a = a.expand(1, 6, 4, -1) + self.assertTrue(a.shape == (1, 6, 4, 5)) + class TestMathOperators(unittest.TestCase): - def test_abs(self): - arr = compat_pytorch.Tensor([-1, -2, 3]) - a = compat_pytorch.abs(arr) - res = compat_pytorch.Tensor([1, 2, 3]) - b = compat_pytorch.absolute(arr) - self.assertTrue(bm.array_equal(a, res)) - self.assertTrue(bm.array_equal(b, res)) + def test_abs(self): + arr = compat_pytorch.Tensor([-1, -2, 3]) + a = compat_pytorch.abs(arr) + res = compat_pytorch.Tensor([1, 2, 3]) + b = compat_pytorch.absolute(arr) + self.assertTrue(bm.array_equal(a, res)) + self.assertTrue(bm.array_equal(b, res)) - def test_add(self): - a = compat_pytorch.Tensor([0.0202, 1.0985, 1.3506, -0.6056]) - a = compat_pytorch.add(a, 20) - res = compat_pytorch.Tensor([20.0202, 21.0985, 21.3506, 19.3944]) - self.assertTrue(bm.array_equal(a, res)) - b = compat_pytorch.Tensor([-0.9732, -0.3497, 0.6245, 0.4022]) - c = compat_pytorch.Tensor([[0.3743], [-1.7724], [-0.5811], [-0.8017]]) - b = compat_pytorch.add(b, c, alpha=10) - self.assertTrue(b.shape == (4, 4)) - print("b:", b) + def test_add(self): + a = compat_pytorch.Tensor([0.0202, 1.0985, 1.3506, -0.6056]) + a = compat_pytorch.add(a, 20) + res = 
compat_pytorch.Tensor([20.0202, 21.0985, 21.3506, 19.3944]) + self.assertTrue(bm.array_equal(a, res)) + b = compat_pytorch.Tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + c = compat_pytorch.Tensor([[0.3743], [-1.7724], [-0.5811], [-0.8017]]) + b = compat_pytorch.add(b, c, alpha=10) + self.assertTrue(b.shape == (4, 4)) + print("b:", b) - def test_addcdiv(self): - rng = bm.random.default_rng(999) - t = rng.rand(1, 3) - t1 = rng.randn(3, 1) - rng = bm.random.default_rng(199) - t2 = rng.randn(1, 3) - res = torch.addcdiv(t, t1, t2, value=0.1) - print("t + t1/t2 * value:", res) - res = torch.addcmul(t, t1, t2, value=0.1) - print("t + t1*t2 * value:", res) + def test_addcdiv(self): + rng = bm.random.default_rng(999) + t = rng.rand(1, 3) + t1 = rng.randn(3, 1) + rng = bm.random.default_rng(199) + t2 = rng.randn(1, 3) + res = torch.addcdiv(t, t1, t2, value=0.1) + print("t + t1/t2 * value:", res) + res = torch.addcmul(t, t1, t2, value=0.1) + print("t + t1*t2 * value:", res) diff --git a/brainpy/_src/math/tests/test_defaults.py b/brainpy/_src/math/tests/test_defaults.py index 9076829b7..009a7b370 100644 --- a/brainpy/_src/math/tests/test_defaults.py +++ b/brainpy/_src/math/tests/test_defaults.py @@ -4,33 +4,33 @@ class TestDefaults(unittest.TestCase): - def test_dt(self): - with bm.environment(dt=1.0): - self.assertEqual(bm.dt, 1.0) - self.assertEqual(bm.get_dt(), 1.0) - - def test_bool(self): - with bm.environment(bool_=bm.int32): - self.assertTrue(bm.bool_ == bm.int32) - self.assertTrue(bm.get_bool() == bm.int32) - - def test_int(self): - with bm.environment(int_=bm.int32): - self.assertTrue(bm.int == bm.int32) - self.assertTrue(bm.get_int() == bm.int32) - - def test_float(self): - with bm.environment(float_=bm.float32): - self.assertTrue(bm.float_ == bm.float32) - self.assertTrue(bm.get_float() == bm.float32) - - def test_complex(self): - with bm.environment(complex_=bm.complex64): - self.assertTrue(bm.complex_ == bm.complex64) - self.assertTrue(bm.get_complex() == 
bm.complex64) - - def test_mode(self): - mode = bm.TrainingMode() - with bm.environment(mode=mode): - self.assertTrue(bm.mode == mode) - self.assertTrue(bm.get_mode() == mode) + def test_dt(self): + with bm.environment(dt=1.0): + self.assertEqual(bm.dt, 1.0) + self.assertEqual(bm.get_dt(), 1.0) + + def test_bool(self): + with bm.environment(bool_=bm.int32): + self.assertTrue(bm.bool_ == bm.int32) + self.assertTrue(bm.get_bool() == bm.int32) + + def test_int(self): + with bm.environment(int_=bm.int32): + self.assertTrue(bm.int == bm.int32) + self.assertTrue(bm.get_int() == bm.int32) + + def test_float(self): + with bm.environment(float_=bm.float32): + self.assertTrue(bm.float_ == bm.float32) + self.assertTrue(bm.get_float() == bm.float32) + + def test_complex(self): + with bm.environment(complex_=bm.complex64): + self.assertTrue(bm.complex_ == bm.complex64) + self.assertTrue(bm.get_complex() == bm.complex64) + + def test_mode(self): + mode = bm.TrainingMode() + with bm.environment(mode=mode): + self.assertTrue(bm.mode == mode) + self.assertTrue(bm.get_mode() == mode) diff --git a/brainpy/_src/math/tests/test_delay_vars.py b/brainpy/_src/math/tests/test_delay_vars.py index 014830ffd..425b9f64a 100644 --- a/brainpy/_src/math/tests/test_delay_vars.py +++ b/brainpy/_src/math/tests/test_delay_vars.py @@ -9,134 +9,128 @@ class TestTimeDelay(unittest.TestCase): - def test_dim1(self): - bm.enable_x64() - - # linear interp - t0 = 0. 
- before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) - delay = bm.TimeDelay(jnp.zeros(10), - delay_len=1., - t0=t0, - dt=0.1, - before_t0=before_t0, - interp_method='linear_interp') - print(delay(t0 - 0.1)) - print(delay(t0 - 0.15)) - self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones(10) * 9.)) - self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones(10) * 8.5)) - print() - print(delay(t0 - 0.23)) - print(delay(t0 - 0.23) - 7.7) - self.assertTrue(bm.allclose(delay(t0 - 0.23), bm.ones(10) * 7.7)) - - # round interp - delay = bm.TimeDelay(jnp.zeros(10), - delay_len=1., - t0=t0, - dt=0.1, - before_t0=before_t0, - interp_method='round') - self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones(10) * 9)) - print(delay(t0 - 0.15)) - self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones(10) * 8)) - self.assertTrue(jnp.allclose(delay(t0 - 0.2), jnp.ones(10) * 8)) - - bm.disable_x64() - - - def test_dim2(self): - t0 = 0. - before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) - before_t0 = jnp.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) - delay = bm.TimeDelay(jnp.zeros((10, 5)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) - self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones((10, 5)) * 9)) - self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones((10, 5)) * 8.5)) - self.assertTrue(bm.allclose(delay(t0 - 0.23), bm.ones((10, 5)) * 7.7)) - - - def test_dim3(self): - t0 = 0. 
- before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) - before_t0 = jnp.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) - before_t0 = jnp.repeat(before_t0.reshape((10, 10, 5, 1)), 3, axis=3) - delay = bm.TimeDelay(jnp.zeros((10, 5, 3)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) - self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones((10, 5, 3)) * 9)) - self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones((10, 5, 3)) * 8.5)) - self.assertTrue(bm.allclose(delay(t0 - 0.23), bm.ones((10, 5, 3)) * 7.7)) - - - def test1(self): - print() - delay = bm.TimeDelay(jnp.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) - self.assertTrue(bm.allclose(delay(-0.2), bm.ones(3) * -0.2)) - delay = bm.TimeDelay(jnp.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t) - self.assertTrue(bm.allclose(delay(-0.6), bm.ones((3, 2)) * -0.6)) - delay = bm.TimeDelay(jnp.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t) - self.assertTrue(bm.allclose(delay(-0.8), jnp.ones((3, 2, 1)) * -0.8)) - - def test_current_time2(self): - print() - delay = bm.TimeDelay(jnp.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) - print(delay(0.)) - before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) - before_t0 = jnp.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) - delay = bm.TimeDelay(jnp.zeros((10, 5)), delay_len=1., dt=0.1, before_t0=before_t0) - print(delay(0.)) - - - # def test_prev_time_beyond_boundary(self): - # with self.assertRaises(ValueError): - # delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) - # delay(-1.2) + def test_dim1(self): + bm.enable_x64() + + # linear interp + t0 = 0. 
+ before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) + delay = bm.TimeDelay(jnp.zeros(10), + delay_len=1., + t0=t0, + dt=0.1, + before_t0=before_t0, + interp_method='linear_interp') + print(delay(t0 - 0.1)) + print(delay(t0 - 0.15)) + self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones(10) * 9.)) + self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones(10) * 8.5)) + print() + print(delay(t0 - 0.23)) + print(delay(t0 - 0.23) - 7.7) + self.assertTrue(bm.allclose(delay(t0 - 0.23), bm.ones(10) * 7.7)) + + # round interp + delay = bm.TimeDelay(jnp.zeros(10), + delay_len=1., + t0=t0, + dt=0.1, + before_t0=before_t0, + interp_method='round') + self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones(10) * 9)) + print(delay(t0 - 0.15)) + self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones(10) * 8)) + self.assertTrue(jnp.allclose(delay(t0 - 0.2), jnp.ones(10) * 8)) + + bm.disable_x64() + + def test_dim2(self): + t0 = 0. + before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) + before_t0 = jnp.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) + delay = bm.TimeDelay(jnp.zeros((10, 5)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) + self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones((10, 5)) * 9)) + self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones((10, 5)) * 8.5)) + self.assertTrue(bm.allclose(delay(t0 - 0.23), bm.ones((10, 5)) * 7.7)) + + def test_dim3(self): + t0 = 0. 
+ before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) + before_t0 = jnp.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) + before_t0 = jnp.repeat(before_t0.reshape((10, 10, 5, 1)), 3, axis=3) + delay = bm.TimeDelay(jnp.zeros((10, 5, 3)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) + self.assertTrue(jnp.allclose(delay(t0 - 0.1), jnp.ones((10, 5, 3)) * 9)) + self.assertTrue(jnp.allclose(delay(t0 - 0.15), jnp.ones((10, 5, 3)) * 8.5)) + self.assertTrue(bm.allclose(delay(t0 - 0.23), bm.ones((10, 5, 3)) * 7.7)) + + def test1(self): + print() + delay = bm.TimeDelay(jnp.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) + self.assertTrue(bm.allclose(delay(-0.2), bm.ones(3) * -0.2)) + delay = bm.TimeDelay(jnp.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t) + self.assertTrue(bm.allclose(delay(-0.6), bm.ones((3, 2)) * -0.6)) + delay = bm.TimeDelay(jnp.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t) + self.assertTrue(bm.allclose(delay(-0.8), jnp.ones((3, 2, 1)) * -0.8)) + + def test_current_time2(self): + print() + delay = bm.TimeDelay(jnp.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) + print(delay(0.)) + before_t0 = jnp.repeat(jnp.arange(10).reshape((-1, 1)), 10, axis=1) + before_t0 = jnp.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) + delay = bm.TimeDelay(jnp.zeros((10, 5)), delay_len=1., dt=0.1, before_t0=before_t0) + print(delay(0.)) + + # def test_prev_time_beyond_boundary(self): + # with self.assertRaises(ValueError): + # delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) + # delay(-1.2) class TestLengthDelay(unittest.TestCase): - def test1(self): - dim = 3 - for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: - delay = bm.LengthDelay(jnp.zeros(dim), 10, update_method=update_method) - print(delay(1)) - self.assertTrue(jnp.allclose(delay(1), jnp.zeros(dim))) - - delay = bm.jit(delay) - print(delay(1)) - self.assertTrue(jnp.allclose(delay(1), jnp.zeros(dim))) - - def test2(self): - 
dim = 3 - for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: - delay = bm.LengthDelay(jnp.zeros(dim), 10, - # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), - initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), - update_method=update_method) - print(delay(0)) - self.assertTrue(jnp.allclose(delay(0), jnp.zeros(dim))) - print(delay(1)) - self.assertTrue(jnp.allclose(delay(1), jnp.ones(dim) * 10)) - - delay = bm.jit(delay) - print(delay(0)) - self.assertTrue(jnp.allclose(delay(0), jnp.zeros(dim))) - print(delay(1)) - self.assertTrue(jnp.allclose(delay(1), jnp.ones(dim) * 10)) - - def test3(self): - dim = 3 - for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: - delay = bm.LengthDelay(jnp.zeros(dim), 10, - # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), - initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), - update_method=update_method) - print(delay(jnp.asarray([1, 2, 3]), - jnp.arange(3))) - self.assertTrue(bm.allclose(delay(jnp.asarray([1, 2, 3]), jnp.arange(3)), - bm.asarray([10., 9., 8.]))) - - delay = bm.jit(delay) - print(delay(jnp.asarray([1, 2, 3]), - jnp.arange(3))) - self.assertTrue(bm.allclose(delay(jnp.asarray([1, 2, 3]), jnp.arange(3)), - bm.asarray([10., 9., 8.]))) - - + def test1(self): + dim = 3 + for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: + delay = bm.LengthDelay(jnp.zeros(dim), 10, update_method=update_method) + print(delay(1)) + self.assertTrue(jnp.allclose(delay(1), jnp.zeros(dim))) + + delay = bm.jit(delay) + print(delay(1)) + self.assertTrue(jnp.allclose(delay(1), jnp.zeros(dim))) + + def test2(self): + dim = 3 + for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: + delay = bm.LengthDelay(jnp.zeros(dim), 10, + # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), + initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), + update_method=update_method) + print(delay(0)) + self.assertTrue(jnp.allclose(delay(0), jnp.zeros(dim))) + print(delay(1)) + self.assertTrue(jnp.allclose(delay(1), jnp.ones(dim) 
* 10)) + + delay = bm.jit(delay) + print(delay(0)) + self.assertTrue(jnp.allclose(delay(0), jnp.zeros(dim))) + print(delay(1)) + self.assertTrue(jnp.allclose(delay(1), jnp.ones(dim) * 10)) + + def test3(self): + dim = 3 + for update_method in [ROTATE_UPDATE, CONCAT_UPDATE]: + delay = bm.LengthDelay(jnp.zeros(dim), 10, + # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), + initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), + update_method=update_method) + print(delay(jnp.asarray([1, 2, 3]), + jnp.arange(3))) + self.assertTrue(bm.allclose(delay(jnp.asarray([1, 2, 3]), jnp.arange(3)), + bm.asarray([10., 9., 8.]))) + + delay = bm.jit(delay) + print(delay(jnp.asarray([1, 2, 3]), + jnp.arange(3))) + self.assertTrue(bm.allclose(delay(jnp.asarray([1, 2, 3]), jnp.arange(3)), + bm.asarray([10., 9., 8.]))) diff --git a/brainpy/_src/math/tests/test_einops.py b/brainpy/_src/math/tests/test_einops.py index 2f018d973..c6be1cd3d 100644 --- a/brainpy/_src/math/tests/test_einops.py +++ b/brainpy/_src/math/tests/test_einops.py @@ -8,324 +8,325 @@ REDUCTIONS = ("min", "max", "sum", "mean", "prod") identity_patterns = [ - "...->...", - "a b c d e-> a b c d e", - "a b c d e ...-> ... a b c d e", - "a b c d e ...-> a ... b c d e", - "... a b c d e -> ... a b c d e", - "a ... e-> a ... e", - "a ... -> a ... ", - "a ... c d e -> a (...) c d e", + "...->...", + "a b c d e-> a b c d e", + "a b c d e ...-> ... a b c d e", + "a b c d e ...-> a ... b c d e", + "... a b c d e -> ... a b c d e", + "a ... e-> a ... e", + "a ... -> a ... ", + "a ... c d e -> a (...) c d e", ] equivalent_rearrange_patterns = [ - ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "), - ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"), - ("a b c d e -> a b c d e", "... -> ... "), - ("a b c d e -> (a b c d e)", "... -> (...)"), - ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"), - ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"), + ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... 
"), + ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"), + ("a b c d e -> a b c d e", "... -> ... "), + ("a b c d e -> (a b c d e)", "... -> (...)"), + ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"), + ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"), ] equivalent_reduction_patterns = [ - ("a b c d e -> ", " ... -> "), - ("a b c d e -> (e a)", "a ... e -> (e a)"), - ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "), - ("a b c d e -> (a b)", " ... c d e -> (...) "), + ("a b c d e -> ", " ... -> "), + ("a b c d e -> (e a)", "a ... e -> (e a)"), + ("a b c d e -> d (a e)", " a b c d e ... -> d (a e) "), + ("a b c d e -> (a b)", " ... c d e -> (...) "), ] def test_collapsed_ellipsis_errors_out(): - x = numpy.zeros([1, 1, 1, 1, 1]) - ein_rearrange(x, "a b c d ... -> a b c ... d") - with pytest.raises(EinopsError): - ein_rearrange(x, "a b c d (...) -> a b c ... d") + x = numpy.zeros([1, 1, 1, 1, 1]) + ein_rearrange(x, "a b c d ... -> a b c ... d") + with pytest.raises(EinopsError): + ein_rearrange(x, "a b c d (...) -> a b c ... d") - ein_rearrange(x, "... -> (...)") - with pytest.raises(EinopsError): - ein_rearrange(x, "(...) -> (...)") + ein_rearrange(x, "... -> (...)") + with pytest.raises(EinopsError): + ein_rearrange(x, "(...) 
-> (...)") def test_ellipsis_ops_numpy(): - x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) - for pattern in identity_patterns: - assert numpy.array_equal(x, ein_rearrange(x, pattern)), pattern + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + for pattern in identity_patterns: + assert numpy.array_equal(x, ein_rearrange(x, pattern)), pattern - for pattern1, pattern2 in equivalent_rearrange_patterns: - assert numpy.array_equal(ein_rearrange(x, pattern1), ein_rearrange(x, pattern2)) + for pattern1, pattern2 in equivalent_rearrange_patterns: + assert numpy.array_equal(ein_rearrange(x, pattern1), ein_rearrange(x, pattern2)) - for reduction in ["min", "max", "sum"]: - for pattern1, pattern2 in equivalent_reduction_patterns: - assert numpy.array_equal(ein_reduce(x, pattern1, reduction=reduction), - ein_reduce(x, pattern2, reduction=reduction)) + for reduction in ["min", "max", "sum"]: + for pattern1, pattern2 in equivalent_reduction_patterns: + assert numpy.array_equal(ein_reduce(x, pattern1, reduction=reduction), + ein_reduce(x, pattern2, reduction=reduction)) - # now just check coincidence with numpy - all_rearrange_patterns = [*identity_patterns] - for pattern_pairs in equivalent_rearrange_patterns: - all_rearrange_patterns.extend(pattern_pairs) + # now just check coincidence with numpy + all_rearrange_patterns = [*identity_patterns] + for pattern_pairs in equivalent_rearrange_patterns: + all_rearrange_patterns.extend(pattern_pairs) def test_rearrange_consistency_numpy(): - shape = [1, 2, 3, 5, 7, 11] - x = numpy.arange(numpy.prod(shape)).reshape(shape) - for pattern in [ - "a b c d e f -> a b c d e f", - "b a c d e f -> a b d e f c", - "a b c d e f -> f e d c b a", - "a b c d e f -> (f e) d (c b a)", - "a b c d e f -> (f e d c b a)", - ]: - result = ein_rearrange(x, pattern) - assert len(numpy.setdiff1d(x, result)) == 0 - - result = ein_rearrange(x, "a b c d e f -> a (b) (c d e) f") - assert numpy.array_equal(x.flatten(), result.flatten()) 
- - result = ein_rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11") - assert numpy.array_equal(x, result) - - result1 = ein_rearrange(x, "a b c d e f -> f e d c b a") - result2 = ein_rearrange(x, "f e d c b a -> a b c d e f") - assert numpy.array_equal(result1, result2) - - result = ein_rearrange(ein_rearrange(x, "a b c d e f -> (f d) c (e b) a"), "(f d) c (e b) a -> a b c d e f", b=2, d=5) - assert numpy.array_equal(x, result) - - sizes = dict(zip("abcdef", shape)) - temp = ein_rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes) - result = ein_rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes) - assert numpy.array_equal(x, result) - - x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4]) - result = ein_rearrange(x2, "a b c -> b c a") - assert x2[1, 2, 3] == result[2, 3, 1] - assert x2[0, 1, 2] == result[1, 2, 0] + shape = [1, 2, 3, 5, 7, 11] + x = numpy.arange(numpy.prod(shape)).reshape(shape) + for pattern in [ + "a b c d e f -> a b c d e f", + "b a c d e f -> a b d e f c", + "a b c d e f -> f e d c b a", + "a b c d e f -> (f e) d (c b a)", + "a b c d e f -> (f e d c b a)", + ]: + result = ein_rearrange(x, pattern) + assert len(numpy.setdiff1d(x, result)) == 0 + + result = ein_rearrange(x, "a b c d e f -> a (b) (c d e) f") + assert numpy.array_equal(x.flatten(), result.flatten()) + + result = ein_rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11") + assert numpy.array_equal(x, result) + + result1 = ein_rearrange(x, "a b c d e f -> f e d c b a") + result2 = ein_rearrange(x, "f e d c b a -> a b c d e f") + assert numpy.array_equal(result1, result2) + + result = ein_rearrange(ein_rearrange(x, "a b c d e f -> (f d) c (e b) a"), "(f d) c (e b) a -> a b c d e f", b=2, + d=5) + assert numpy.array_equal(x, result) + + sizes = dict(zip("abcdef", shape)) + temp = ein_rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes) + result = ein_rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes) + assert numpy.array_equal(x, result) + + x2 = 
numpy.arange(2 * 3 * 4).reshape([2, 3, 4]) + result = ein_rearrange(x2, "a b c -> b c a") + assert x2[1, 2, 3] == result[2, 3, 1] + assert x2[0, 1, 2] == result[1, 2, 0] def test_rearrange_permutations_numpy(): - # tests random permutation of axes against two independent numpy ways - for n_axes in range(1, 10): - input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) - permutation = numpy.random.permutation(n_axes) - left_expression = " ".join("i" + str(axis) for axis in range(n_axes)) - right_expression = " ".join("i" + str(axis) for axis in permutation) - expression = left_expression + " -> " + right_expression - result = ein_rearrange(input, expression) - - for pick in numpy.random.randint(0, 2, [10, n_axes]): - assert input[tuple(pick)] == result[tuple(pick[permutation])] - - for n_axes in range(1, 10): - input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) - permutation = numpy.random.permutation(n_axes) - left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1]) - right_expression = " ".join("i" + str(axis) for axis in permutation[::-1]) - expression = left_expression + " -> " + right_expression - result = ein_rearrange(input, expression) - assert result.shape == input.shape - expected_result = numpy.zeros_like(input) - for original_axis, result_axis in enumerate(permutation): - expected_result |= ((input >> original_axis) & 1) << result_axis - - assert numpy.array_equal(result, expected_result) + # tests random permutation of axes against two independent numpy ways + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = " ".join("i" + str(axis) for axis in range(n_axes)) + right_expression = " ".join("i" + str(axis) for axis in permutation) + expression = left_expression + " -> " + right_expression + result = ein_rearrange(input, expression) + + for pick in numpy.random.randint(0, 2, [10, n_axes]): + assert input[tuple(pick)] == 
result[tuple(pick[permutation])] + + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1]) + right_expression = " ".join("i" + str(axis) for axis in permutation[::-1]) + expression = left_expression + " -> " + right_expression + result = ein_rearrange(input, expression) + assert result.shape == input.shape + expected_result = numpy.zeros_like(input) + for original_axis, result_axis in enumerate(permutation): + expected_result |= ((input >> original_axis) & 1) << result_axis + + assert numpy.array_equal(result, expected_result) def test_reduction_imperatives(): - for reduction in REDUCTIONS: - # slight redundancy for simpler order - numpy version is evaluated multiple times - input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6]) - if reduction in ["mean", "prod"]: - input = input / input.astype("float64").mean() - test_cases = [ - ["a b c d e -> ", {}, getattr(input, reduction)()], - ["a ... -> ", {}, getattr(input, reduction)()], - ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), getattr(input, reduction)()], - [ - "a b c d e -> (e c) a", - {}, - getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), - ], - [ - "a ... c d e -> (e c) a", - {}, - getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), - ], - [ - "a b c d e ... -> (e c) a", - {}, - getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), - ], - ["a b c d e -> (e c a)", {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])], - ["(a a2) ... 
-> (a2 a) ...", dict(a2=1), input], - ] - for pattern, axes_lengths, expected_result in test_cases: - result = ein_reduce(bm.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths) - result = bm.as_numpy(result) - print(reduction, pattern, expected_result, result) - assert numpy.allclose(result, expected_result), f"Failed at {pattern}" + for reduction in REDUCTIONS: + # slight redundancy for simpler order - numpy version is evaluated multiple times + input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype="int64").reshape([2, 3, 4, 5, 6]) + if reduction in ["mean", "prod"]: + input = input / input.astype("float64").mean() + test_cases = [ + ["a b c d e -> ", {}, getattr(input, reduction)()], + ["a ... -> ", {}, getattr(input, reduction)()], + ["(a1 a2) ... (e1 e2) -> ", dict(a1=1, e2=2), getattr(input, reduction)()], + [ + "a b c d e -> (e c) a", + {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), + ], + [ + "a ... c d e -> (e c) a", + {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), + ], + [ + "a b c d e ... -> (e c) a", + {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2]), + ], + ["a b c d e -> (e c a)", {}, getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])], + ["(a a2) ... 
-> (a2 a) ...", dict(a2=1), input], + ] + for pattern, axes_lengths, expected_result in test_cases: + result = ein_reduce(bm.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths) + result = bm.as_numpy(result) + print(reduction, pattern, expected_result, result) + assert numpy.allclose(result, expected_result), f"Failed at {pattern}" def test_enumerating_directions(): - for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]: - x = numpy.arange(numpy.prod(shape)).reshape(shape) - axes1 = _enumerate_directions(x) - axes2 = _enumerate_directions(bm.from_numpy(x)) - assert len(axes1) == len(axes2) == len(shape) - for ax1, ax2 in zip(axes1, axes2): - ax2 = bm.as_numpy(ax2) - assert ax1.shape == ax2.shape - assert numpy.allclose(ax1, ax2) + for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]: + x = numpy.arange(numpy.prod(shape)).reshape(shape) + axes1 = _enumerate_directions(x) + axes2 = _enumerate_directions(bm.from_numpy(x)) + assert len(axes1) == len(axes2) == len(shape) + for ax1, ax2 in zip(axes1, axes2): + ax2 = bm.as_numpy(ax2) + assert ax1.shape == ax2.shape + assert numpy.allclose(ax1, ax2) def test_concatenations_and_stacking(): - for n_arrays in [1, 2, 5]: - shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] - for shape in shapes: - arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)] - arrays2 = [bm.from_numpy(array) for array in arrays1] - result0 = numpy.asarray(arrays1) - result1 = ein_rearrange(arrays1, "...->...") - result2 = ein_rearrange(arrays2, "...->...") - assert numpy.array_equal(result0, result1) - assert numpy.array_equal(result1, bm.as_numpy(result2)) - - result1 = ein_rearrange(arrays1, "b ... -> ... b") - result2 = ein_rearrange(arrays2, "b ... -> ... 
b") - assert numpy.array_equal(result1, bm.as_numpy(result2)) + for n_arrays in [1, 2, 5]: + shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] + for shape in shapes: + arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)] + arrays2 = [bm.from_numpy(array) for array in arrays1] + result0 = numpy.asarray(arrays1) + result1 = ein_rearrange(arrays1, "...->...") + result2 = ein_rearrange(arrays2, "...->...") + assert numpy.array_equal(result0, result1) + assert numpy.array_equal(result1, bm.as_numpy(result2)) + + result1 = ein_rearrange(arrays1, "b ... -> ... b") + result2 = ein_rearrange(arrays2, "b ... -> ... b") + assert numpy.array_equal(result1, bm.as_numpy(result2)) def test_gradients_imperatives(): - # lazy - just checking reductions - for reduction in REDUCTIONS: - if reduction in ("any", "all"): - continue # non-differentiable ops - x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32") - y0 = bm.from_numpy(x) - if not hasattr(y0, "grad"): - continue + # lazy - just checking reductions + for reduction in REDUCTIONS: + if reduction in ("any", "all"): + continue # non-differentiable ops + x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype("float32") + y0 = bm.from_numpy(x) + if not hasattr(y0, "grad"): + continue - y1 = ein_reduce(y0, "a b c -> c a", reduction=reduction) - y2 = ein_reduce(y1, "c a -> a c", reduction=reduction) - y3 = ein_reduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2) - y4 = ein_reduce(y3, "... -> ", reduction=reduction) + y1 = ein_reduce(y0, "a b c -> c a", reduction=reduction) + y2 = ein_reduce(y1, "c a -> a c", reduction=reduction) + y3 = ein_reduce(y2, "a (c1 c2) -> a", reduction=reduction, c1=2) + y4 = ein_reduce(y3, "... 
-> ", reduction=reduction) - y4.backward() - grad = bm.as_numpy(y0.grad) + y4.backward() + grad = bm.as_numpy(y0.grad) def test_tiling_imperatives(): - input = numpy.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5]) - test_cases = [ - (1, 1, 1, 1, 1), - (1, 2, 1, 3, 1), - (3, 1, 1, 4, 1), - ] - for repeats in test_cases: - expected = numpy.tile(input, repeats) - converted = bm.from_numpy(input) - repeated = bm.tile(converted, repeats) - result = bm.as_numpy(repeated) - assert numpy.array_equal(result, expected) + input = numpy.arange(2 * 3 * 5, dtype="int64").reshape([2, 1, 3, 1, 5]) + test_cases = [ + (1, 1, 1, 1, 1), + (1, 2, 1, 3, 1), + (3, 1, 1, 4, 1), + ] + for repeats in test_cases: + expected = numpy.tile(input, repeats) + converted = bm.from_numpy(input) + repeated = bm.tile(converted, repeats) + result = bm.as_numpy(repeated) + assert numpy.array_equal(result, expected) repeat_test_cases = [ - # all assume that input has shape [2, 3, 5] - ("a b c -> c a b", dict()), - ("a b c -> (c copy a b)", dict(copy=2, a=2, b=3, c=5)), - ("a b c -> (a copy) b c ", dict(copy=1)), - ("a b c -> (c a) (copy1 b copy2)", dict(a=2, copy1=1, copy2=2)), - ("a ... -> a ... copy", dict(copy=4)), - ("... c -> ... (copy1 c copy2)", dict(copy1=1, copy2=2)), - ("... -> ... ", dict()), - (" ... -> copy1 ... copy2 ", dict(copy1=2, copy2=3)), - ("a b c -> copy1 a copy2 b c () ", dict(copy1=2, copy2=1)), + # all assume that input has shape [2, 3, 5] + ("a b c -> c a b", dict()), + ("a b c -> (c copy a b)", dict(copy=2, a=2, b=3, c=5)), + ("a b c -> (a copy) b c ", dict(copy=1)), + ("a b c -> (c a) (copy1 b copy2)", dict(a=2, copy1=1, copy2=2)), + ("a ... -> a ... copy", dict(copy=4)), + ("... c -> ... (copy1 c copy2)", dict(copy1=1, copy2=2)), + ("... -> ... ", dict()), + (" ... -> copy1 ... 
copy2 ", dict(copy1=2, copy2=3)), + ("a b c -> copy1 a copy2 b c () ", dict(copy1=2, copy2=1)), ] def check_reversion(x, repeat_pattern, **sizes): - """Checks repeat pattern by running reduction""" - left, right = repeat_pattern.split("->") - reduce_pattern = right + "->" + left - repeated = ein_repeat(x, repeat_pattern, **sizes) - reduced_min = ein_reduce(repeated, reduce_pattern, reduction="min", **sizes) - reduced_max = ein_reduce(repeated, reduce_pattern, reduction="max", **sizes) - assert numpy.array_equal(x, reduced_min) - assert numpy.array_equal(x, reduced_max) + """Checks repeat pattern by running reduction""" + left, right = repeat_pattern.split("->") + reduce_pattern = right + "->" + left + repeated = ein_repeat(x, repeat_pattern, **sizes) + reduced_min = ein_reduce(repeated, reduce_pattern, reduction="min", **sizes) + reduced_max = ein_reduce(repeated, reduce_pattern, reduction="max", **sizes) + assert numpy.array_equal(x, reduced_min) + assert numpy.array_equal(x, reduced_max) def test_repeat_numpy(): - # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well - x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) - x1 = ein_repeat(x, "a b c -> copy a b c ", copy=1) - assert numpy.array_equal(x[None], x1) - for pattern, axis_dimensions in repeat_test_cases: - check_reversion(x, pattern, **axis_dimensions) + # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well + x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) + x1 = ein_repeat(x, "a b c -> copy a b c ", copy=1) + assert numpy.array_equal(x[None], x1) + for pattern, axis_dimensions in repeat_test_cases: + check_reversion(x, pattern, **axis_dimensions) test_cases_repeat_anonymous = [ - # all assume that input has shape [1, 2, 4, 6] - ("a b c d -> c a d b", dict()), - ("a b c d -> (c 2 d a b)", dict(a=1, c=4, d=6)), - ("1 b c d -> (d copy 1) 3 b c ", dict(copy=3)), - ("1 ... -> 3 ... ", dict()), - ("() ... d -> 1 (copy1 d copy2) ... 
", dict(copy1=2, copy2=3)), - ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", dict()), + # all assume that input has shape [1, 2, 4, 6] + ("a b c d -> c a d b", dict()), + ("a b c d -> (c 2 d a b)", dict(a=1, c=4, d=6)), + ("1 b c d -> (d copy 1) 3 b c ", dict(copy=3)), + ("1 ... -> 3 ... ", dict()), + ("() ... d -> 1 (copy1 d copy2) ... ", dict(copy1=2, copy2=3)), + ("1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)", dict()), ] def test_anonymous_axes(): - x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6]) - for pattern, axis_dimensions in test_cases_repeat_anonymous: - check_reversion(x, pattern, **axis_dimensions) + x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6]) + for pattern, axis_dimensions in test_cases_repeat_anonymous: + check_reversion(x, pattern, **axis_dimensions) def test_list_inputs(): - x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) - - assert numpy.array_equal( - ein_rearrange(list(x), "... -> (...)"), - ein_rearrange(x, "... -> (...)"), - ) - assert numpy.array_equal( - ein_reduce(list(x), "a ... e -> (...)", "min"), - ein_reduce(x, "a ... e -> (...)", "min"), - ) - assert numpy.array_equal( - ein_repeat(list(x), "... -> b (...)", b=3), - ein_repeat(x, "... -> b (...)", b=3), - ) + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + + assert numpy.array_equal( + ein_rearrange(list(x), "... -> (...)"), + ein_rearrange(x, "... -> (...)"), + ) + assert numpy.array_equal( + ein_reduce(list(x), "a ... e -> (...)", "min"), + ein_reduce(x, "a ... e -> (...)", "min"), + ) + assert numpy.array_equal( + ein_repeat(list(x), "... -> b (...)", b=3), + ein_repeat(x, "... 
-> b (...)", b=3), + ) def bit_count(x): - return sum((x >> i) & 1 for i in range(20)) + return sum((x >> i) & 1 for i in range(20)) def test_reduction_imperatives_booleans(): - """Checks that any/all reduction works in all frameworks""" - x_np = numpy.asarray([(bit_count(x) % 2) == 0 for x in range(2 ** 6)]).reshape([2] * 6) - - for axis in range(6): - expected_result_any = numpy.any(x_np, axis=axis, keepdims=True) - expected_result_all = numpy.all(x_np, axis=axis, keepdims=True) - assert not numpy.array_equal(expected_result_any, expected_result_all) - - axes = list("abcdef") - axes_in = list(axes) - axes_out = list(axes) - axes_out[axis] = "1" - pattern = (" ".join(axes_in)) + " -> " + (" ".join(axes_out)) - + """Checks that any/all reduction works in all frameworks""" + x_np = numpy.asarray([(bit_count(x) % 2) == 0 for x in range(2 ** 6)]).reshape([2] * 6) + + for axis in range(6): + expected_result_any = numpy.any(x_np, axis=axis, keepdims=True) + expected_result_all = numpy.all(x_np, axis=axis, keepdims=True) + assert not numpy.array_equal(expected_result_any, expected_result_all) + + axes = list("abcdef") + axes_in = list(axes) + axes_out = list(axes) + axes_out[axis] = "1" + pattern = (" ".join(axes_in)) + " -> " + (" ".join(axes_out)) + + res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any") + res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all") + + assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any)) + assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all)) + + # expected result: any/all + expected_result_any = numpy.any(x_np, axis=(0, 1), keepdims=True) + expected_result_all = numpy.all(x_np, axis=(0, 1), keepdims=True) + pattern = "a b ... -> 1 1 ..." 
res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any") res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all") - assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any)) assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all)) - - # expected result: any/all - expected_result_any = numpy.any(x_np, axis=(0, 1), keepdims=True) - expected_result_all = numpy.all(x_np, axis=(0, 1), keepdims=True) - pattern = "a b ... -> 1 1 ..." - res_any = ein_reduce(bm.from_numpy(x_np), pattern, reduction="any") - res_all = ein_reduce(bm.from_numpy(x_np), pattern, reduction="all") - assert numpy.array_equal(expected_result_any, bm.as_numpy(res_any)) - assert numpy.array_equal(expected_result_all, bm.as_numpy(res_all)) diff --git a/brainpy/_src/math/tests/test_einops_parsing.py b/brainpy/_src/math/tests/test_einops_parsing.py index 069c7bbac..00bbdeebc 100644 --- a/brainpy/_src/math/tests/test_einops_parsing.py +++ b/brainpy/_src/math/tests/test_einops_parsing.py @@ -4,108 +4,108 @@ class AnonymousAxisPlaceholder: - def __init__(self, value: int): - self.value = value - assert isinstance(self.value, int) + def __init__(self, value: int): + self.value = value + assert isinstance(self.value, int) - def __eq__(self, other): - return isinstance(other, AnonymousAxis) and self.value == other.value + def __eq__(self, other): + return isinstance(other, AnonymousAxis) and self.value == other.value def test_anonymous_axes(): - a, b = AnonymousAxis('2'), AnonymousAxis('2') - assert a != b - c, d = AnonymousAxisPlaceholder(2), AnonymousAxisPlaceholder(3) - assert a == c and b == c - assert a != d and b != d - assert [a, 2, b] == [c, 2, c] + a, b = AnonymousAxis('2'), AnonymousAxis('2') + assert a != b + c, d = AnonymousAxisPlaceholder(2), AnonymousAxisPlaceholder(3) + assert a == c and b == c + assert a != d and b != d + assert [a, 2, b] == [c, 2, c] def test_elementary_axis_name(): - for name in ['a', 'b', 'h', 'dx', 'h1', 'zz', 'i9123', 
'somelongname', - 'Alex', 'camelCase', 'u_n_d_e_r_score', 'unreasonablyLongAxisName']: - assert ParsedExpression.check_axis_name(name) + for name in ['a', 'b', 'h', 'dx', 'h1', 'zz', 'i9123', 'somelongname', + 'Alex', 'camelCase', 'u_n_d_e_r_score', 'unreasonablyLongAxisName']: + assert ParsedExpression.check_axis_name(name) - for name in ['', '2b', '12', '_startWithUnderscore', 'endWithUnderscore_', '_', '...', _ellipsis]: - assert not ParsedExpression.check_axis_name(name) + for name in ['', '2b', '12', '_startWithUnderscore', 'endWithUnderscore_', '_', '...', _ellipsis]: + assert not ParsedExpression.check_axis_name(name) def test_invalid_expressions(): - # double ellipsis should raise an error - ParsedExpression('... a b c d') - with pytest.raises(EinopsError): - ParsedExpression('... a b c d ...') - with pytest.raises(EinopsError): - ParsedExpression('... a b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(... a) b c (d ...)') - - # double/missing/enclosed parenthesis - ParsedExpression('(a) b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a)) b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a) (()) b c (d ...)') - with pytest.raises(EinopsError): - ParsedExpression('(a) ((b c) (d ...))') - - # invalid identifiers - ParsedExpression('camelCase under_scored cApiTaLs ß ...') - with pytest.raises(EinopsError): - ParsedExpression('1a') - with pytest.raises(EinopsError): - ParsedExpression('_pre') - with pytest.raises(EinopsError): - ParsedExpression('...pre') - with pytest.raises(EinopsError): - ParsedExpression('pre...') + # double ellipsis should raise an error + ParsedExpression('... a b c d') + with pytest.raises(EinopsError): + ParsedExpression('... a b c d ...') + with pytest.raises(EinopsError): + ParsedExpression('... a b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(... 
a) b c (d ...)') + + # double/missing/enclosed parenthesis + ParsedExpression('(a) b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a)) b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a) (()) b c (d ...)') + with pytest.raises(EinopsError): + ParsedExpression('(a) ((b c) (d ...))') + + # invalid identifiers + ParsedExpression('camelCase under_scored cApiTaLs ß ...') + with pytest.raises(EinopsError): + ParsedExpression('1a') + with pytest.raises(EinopsError): + ParsedExpression('_pre') + with pytest.raises(EinopsError): + ParsedExpression('...pre') + with pytest.raises(EinopsError): + ParsedExpression('pre...') def test_parse_expression(): - parsed = ParsedExpression('a1 b1 c1 d1') - assert parsed.identifiers == {'a1', 'b1', 'c1', 'd1'} - assert parsed.composition == [['a1'], ['b1'], ['c1'], ['d1']] - assert not parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - parsed = ParsedExpression('() () () ()') - assert parsed.identifiers == set() - assert parsed.composition == [[], [], [], []] - assert not parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - parsed = ParsedExpression('1 1 1 ()') - assert parsed.identifiers == set() - assert parsed.composition == [[], [], [], []] - assert not parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - aap = AnonymousAxisPlaceholder - - parsed = ParsedExpression('5 (3 4)') - assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5} - assert parsed.composition == [[aap(5)], [aap(3), aap(4)]] - assert parsed.has_non_unitary_anonymous_axes - assert not parsed.has_ellipsis - - parsed = ParsedExpression('5 1 (1 4) 1') - assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5} - assert parsed.composition == [[aap(5)], [], [aap(4)], []] - - parsed = ParsedExpression('name1 ... 
a1 12 (name2 14)') - assert len(parsed.identifiers) == 6 - assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 - assert parsed.composition == [['name1'], _ellipsis, ['a1'], [aap(12)], ['name2', aap(14)]] - assert parsed.has_non_unitary_anonymous_axes - assert parsed.has_ellipsis - assert not parsed.has_ellipsis_parenthesized - - parsed = ParsedExpression('(name1 ... a1 12) name2 14') - assert len(parsed.identifiers) == 6 - assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 - assert parsed.composition == [['name1', _ellipsis, 'a1', aap(12)], ['name2'], [aap(14)]] - assert parsed.has_non_unitary_anonymous_axes - assert parsed.has_ellipsis - assert parsed.has_ellipsis_parenthesized + parsed = ParsedExpression('a1 b1 c1 d1') + assert parsed.identifiers == {'a1', 'b1', 'c1', 'd1'} + assert parsed.composition == [['a1'], ['b1'], ['c1'], ['d1']] + assert not parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + parsed = ParsedExpression('() () () ()') + assert parsed.identifiers == set() + assert parsed.composition == [[], [], [], []] + assert not parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + parsed = ParsedExpression('1 1 1 ()') + assert parsed.identifiers == set() + assert parsed.composition == [[], [], [], []] + assert not parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + aap = AnonymousAxisPlaceholder + + parsed = ParsedExpression('5 (3 4)') + assert len(parsed.identifiers) == 3 and {i.value for i in parsed.identifiers} == {3, 4, 5} + assert parsed.composition == [[aap(5)], [aap(3), aap(4)]] + assert parsed.has_non_unitary_anonymous_axes + assert not parsed.has_ellipsis + + parsed = ParsedExpression('5 1 (1 4) 1') + assert len(parsed.identifiers) == 2 and {i.value for i in parsed.identifiers} == {4, 5} + assert parsed.composition == [[aap(5)], [], [aap(4)], []] + + parsed = ParsedExpression('name1 ... 
a1 12 (name2 14)') + assert len(parsed.identifiers) == 6 + assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 + assert parsed.composition == [['name1'], _ellipsis, ['a1'], [aap(12)], ['name2', aap(14)]] + assert parsed.has_non_unitary_anonymous_axes + assert parsed.has_ellipsis + assert not parsed.has_ellipsis_parenthesized + + parsed = ParsedExpression('(name1 ... a1 12) name2 14') + assert len(parsed.identifiers) == 6 + assert parsed.identifiers.difference({'name1', _ellipsis, 'a1', 'name2'}).__len__() == 2 + assert parsed.composition == [['name1', _ellipsis, 'a1', aap(12)], ['name2'], [aap(14)]] + assert parsed.has_non_unitary_anonymous_axes + assert parsed.has_ellipsis + assert parsed.has_ellipsis_parenthesized diff --git a/brainpy/_src/math/tests/test_environment.py b/brainpy/_src/math/tests/test_environment.py index 96fe2db46..1c8a0c0f5 100644 --- a/brainpy/_src/math/tests/test_environment.py +++ b/brainpy/_src/math/tests/test_environment.py @@ -6,13 +6,13 @@ class TestEnvironment(unittest.TestCase): - def test_numpy_func_return(self): - # Reset random state to ensure clean state between tests - bm.random.seed() - - with bm.environment(numpy_func_return='jax_array'): - a = bm.random.randn(3, 3) - self.assertTrue(isinstance(a, jax.Array)) - with bm.environment(numpy_func_return='bp_array'): - a = bm.random.randn(3, 3) - self.assertTrue(isinstance(a, bm.Array)) + def test_numpy_func_return(self): + # Reset random state to ensure clean state between tests + bm.random.seed() + + with bm.environment(numpy_func_return='jax_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, jax.Array)) + with bm.environment(numpy_func_return='bp_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, bm.Array)) diff --git a/brainpy/_src/math/tests/test_math_deprecations.py b/brainpy/_src/math/tests/test_math_deprecations.py index 7a72653a7..a105c1989 100644 --- a/brainpy/_src/math/tests/test_math_deprecations.py 
+++ b/brainpy/_src/math/tests/test_math_deprecations.py @@ -1,16 +1,14 @@ -import unittest from absl.testing import parameterized import brainpy.math as bm - deprecated_names = list(bm.__deprecations.keys()) class Test(parameterized.TestCase): - @parameterized.product( - name=deprecated_names - ) - def test(self, name): - with self.assertWarns(DeprecationWarning): - getattr(bm, name) + @parameterized.product( + name=deprecated_names + ) + def test(self, name): + with self.assertWarns(DeprecationWarning): + getattr(bm, name) diff --git a/brainpy/_src/math/tests/test_ndarray.py b/brainpy/_src/math/tests/test_ndarray.py index c42237d1c..008d4e005 100644 --- a/brainpy/_src/math/tests/test_ndarray.py +++ b/brainpy/_src/math/tests/test_ndarray.py @@ -2,9 +2,8 @@ import unittest -import jax -import brainpy as bp +import jax import jax.numpy as jnp import numpy as np from jax.tree_util import tree_flatten, tree_unflatten @@ -14,114 +13,113 @@ class TestJaxArray(unittest.TestCase): - def test_tree(self): - structured = {'a': Variable(jnp.zeros(1)), - 'b': (Variable(jnp.ones(2)), - Variable(jnp.ones(2) * 2))} - flat, tree = tree_flatten(structured) - unflattened = tree_unflatten(tree, flat) - print("\nstructured={}\n\n flat={}\n\n tree={}\n\n unflattened={}".format( - structured, flat, tree, unflattened)) - - def test_none(self): - # https://github.com/PKU-NIP-Lab/BrainPy/issues/144 - a = None - b = bm.zeros(10) - with self.assertRaises(TypeError): - bb = a + b - - c = bm.Variable(bm.zeros(10)) - with self.assertRaises(TypeError): - cc = a + c - - d = bm.Parameter(bm.zeros(10)) - with self.assertRaises(TypeError): - dd = a + d - - e = bm.TrainVar(bm.zeros(10)) - with self.assertRaises(TypeError): - ee = a + e - - def test_operation_with_numpy_array(self): - rng = bm.random.RandomState(123) - add = lambda: bm.asarray(rng.rand(10)) + np.zeros(1) - self.assertTrue(isinstance(add(), bm.Array)) - self.assertTrue(isinstance(bm.jit(add)(), bm.Array)) + def test_tree(self): + 
structured = {'a': Variable(jnp.zeros(1)), + 'b': (Variable(jnp.ones(2)), + Variable(jnp.ones(2) * 2))} + flat, tree = tree_flatten(structured) + unflattened = tree_unflatten(tree, flat) + print("\nstructured={}\n\n flat={}\n\n tree={}\n\n unflattened={}".format( + structured, flat, tree, unflattened)) + + def test_none(self): + # https://github.com/PKU-NIP-Lab/BrainPy/issues/144 + a = None + b = bm.zeros(10) + with self.assertRaises(TypeError): + bb = a + b + + c = bm.Variable(bm.zeros(10)) + with self.assertRaises(TypeError): + cc = a + c + + d = bm.Parameter(bm.zeros(10)) + with self.assertRaises(TypeError): + dd = a + d + + e = bm.TrainVar(bm.zeros(10)) + with self.assertRaises(TypeError): + ee = a + e + + def test_operation_with_numpy_array(self): + rng = bm.random.RandomState(123) + add = lambda: bm.asarray(rng.rand(10)) + np.zeros(1) + self.assertTrue(isinstance(add(), bm.Array)) + self.assertTrue(isinstance(bm.jit(add)(), bm.Array)) class TestTracerError(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - self.a = bm.zeros((10, 2)) - self.f = jax.jit(self._f) + self.a = bm.zeros((10, 2)) + self.f = jax.jit(self._f) - def _f(self, b): - self.a[:] = bm.zeros_like(self.a) - return b + 1. + def _f(self, b): + self.a[:] = bm.zeros_like(self.a) + return b + 1. 
- def test_tracing(self): - print(self.f(1.)) - with self.assertRaises(jax.errors.UnexpectedTracerError): - print(self.f(bm.ones(10))) + def test_tracing(self): + print(self.f(1.)) + with self.assertRaises(jax.errors.UnexpectedTracerError): + print(self.f(bm.ones(10))) class TestVariable(unittest.TestCase): - def test_variable_init(self): - self.assertTrue( - bm.array_equal(bm.Variable(bm.zeros(10)), - bm.Variable(10)) - ) - bm.random.seed(123) - self.assertTrue( - not bm.array_equal(bm.Variable(bm.random.rand(10)), - bm.Variable(10)) - ) + def test_variable_init(self): + self.assertTrue( + bm.array_equal(bm.Variable(bm.zeros(10)), + bm.Variable(10)) + ) + bm.random.seed(123) + self.assertTrue( + not bm.array_equal(bm.Variable(bm.random.rand(10)), + bm.Variable(10)) + ) class TestVariableView(unittest.TestCase): - def test_update(self): - bm.random.seed() + def test_update(self): + bm.random.seed() - origin = bm.Variable(bm.zeros(10)) - view = bm.VariableView(origin, slice(0, 5, None)) + origin = bm.Variable(bm.zeros(10)) + view = bm.VariableView(origin, slice(0, 5, None)) - view.update(bm.ones(5)) - self.assertTrue( - bm.array_equal(origin, bm.concatenate([bm.ones(5), bm.zeros(5)])) - ) + view.update(bm.ones(5)) + self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.ones(5), bm.zeros(5)])) + ) - view.value = bm.arange(5.) - self.assertTrue( - bm.array_equal(origin, bm.concatenate([bm.arange(5), bm.zeros(5)])) - ) + view.value = bm.arange(5.) 
+ self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.arange(5), bm.zeros(5)])) + ) - view += 10 - self.assertTrue( - bm.array_equal(origin, bm.concatenate([bm.arange(10, 15), bm.zeros(5)])) - ) + view += 10 + self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.arange(10, 15), bm.zeros(5)])) + ) - bm.random.shuffle(view) - print(view) - print(origin) + bm.random.shuffle(view) + print(view) + print(origin) - view.sort() - self.assertTrue( - bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)])) - ) + view.sort() + self.assertTrue( + bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)])) + ) - self.assertTrue(view.sum() == bm.sum(bm.arange(5) + 10)) + self.assertTrue(view.sum() == bm.sum(bm.arange(5) + 10)) class TestArrayPriority(unittest.TestCase): - def test1(self): - a = bm.Array(bm.zeros(10)) - assert isinstance(a + bm.ones(1).value, bm.Array) - assert isinstance(a + np.ones(1), bm.Array) - assert isinstance(a * np.ones(1), bm.Array) - assert isinstance(np.ones(1) + a, bm.Array) - assert isinstance(np.ones(1) * a, bm.Array) - b = bm.Variable(bm.zeros(10)) - assert isinstance(b + bm.ones(1).value, bm.Array) - assert isinstance(b + np.ones(1), bm.Array) - + def test1(self): + a = bm.Array(bm.zeros(10)) + assert isinstance(a + bm.ones(1).value, bm.Array) + assert isinstance(a + np.ones(1), bm.Array) + assert isinstance(a * np.ones(1), bm.Array) + assert isinstance(np.ones(1) + a, bm.Array) + assert isinstance(np.ones(1) * a, bm.Array) + b = bm.Variable(bm.zeros(10)) + assert isinstance(b + bm.ones(1).value, bm.Array) + assert isinstance(b + np.ones(1), bm.Array) diff --git a/brainpy/_src/math/tests/test_numpy_einsum.py b/brainpy/_src/math/tests/test_numpy_einsum.py index b9673ce6f..ccbcd35ac 100644 --- a/brainpy/_src/math/tests/test_numpy_einsum.py +++ b/brainpy/_src/math/tests/test_numpy_einsum.py @@ -16,8 +16,8 @@ import pytest -pytest.skip("No need to test.", allow_module_level=True) +pytest.skip("No need to 
test.", allow_module_level=True) import itertools from collections import defaultdict @@ -37,333 +37,333 @@ def is_sequence(x): - try: - iter(x) - except TypeError: - return False - else: - return True + try: + iter(x) + except TypeError: + return False + else: + return True class EinsumTest(jtu.JaxTestCase): - def _check(self, s, *ops): - a = np.einsum(s, *ops) - b = bm.einsum(s, *ops, precision=lax.Precision.HIGHEST) - print(f'a.shape = {a.shape}') - print(f'b.shape = {b.shape}') - - self.assertTrue(bm) - self.assertAllClose(a, b, atol=1e-4, rtol=1e-4) - - def test_three_operands_1(self): - r = bm.random.RandomState() - x = r.randn(3) - y = r.randn(4) - z = r.randn(5) - s = 'i,j,k->ijk' - self._check(s, x, y, z) - - def test_three_operands_2(self): - r = bm.random.RandomState() - x = r.randn(3) - y = r.randn(4) - z = r.randn(5) - s = 'i,j,k->ijk' - self._check(s, x, y, z) - - def test_two_operands_1(self): - r = bm.random.RandomState() - x = r.randn(3, 4) - y = r.randn(4) - s = 'ij,j->i' - self._check(s, x, y) - - def test_two_operands_2(self): - r = bm.random.RandomState() - x = r.randn(3, 4, 5) - y = r.randn(4) - s = 'ijk,j->i' - self._check(s, x, y) - - def test_two_operands_3(self): - r = bm.random.RandomState() - x = r.randn(3, 4, 3) - y = r.randn(3) - s = 'iji,i->j' - self._check(s, x, y) - - def test_two_operands_4(self): - r = bm.random.RandomState() - x = r.randn(3, 4) - y = r.randn(3, 4) - s = 'ij,ij->' - self._check(s, x, y) - - def test_two_operands_5(self): - r = bm.random.RandomState() - x = r.randn(10, 2, 3) - y = r.randn(3, 4) - s = 'nij,jk->nik' - self._check(s, x, y) - - def test_two_operands_6(self): - # based on https://github.com/google/jax/issues/37#issuecomment-448572187 - r = bm.random.RandomState() - x = r.randn(2, 1) - y = r.randn(2, 3, 4) - s = 'sa,shb->shab' - self._check(s, x, y) - - def test_one_operand_1(self): - r = bm.random.RandomState() - x = r.randn(3, 4, 5) - s = 'ijk->j' - self._check(s, x) - - def test_one_operand_2(self): - 
r = bm.random.RandomState() - x = r.randn(3, 4, 5) - s = 'ijk->kij' - self._check(s, x) - - def test_one_operand_3(self): - r = bm.random.RandomState() - x = r.randn(3, 4, 5) - s = 'ijk->ki' - self._check(s, x) - - def test_one_operand_4(self): - r = bm.random.RandomState() - x = r.randn(3, 4, 5) - s = 'ijk->ki' - self._check(s, x) - - def test_one_operand_5(self): - r = bm.random.RandomState() - x = r.randn(2, 3, 4, 5) - s = '...ijk->...ki' - self._check(s, x) - - def test_one_operand_6(self): - r = bm.random.RandomState() - x = r.randn(3, 4, 5) - s = '...ijk->ki' - self._check(s, x) - - def test_one_operand_7(self): - r = bm.random.RandomState() - x = r.randn(3, 3) - s = 'ii->' - self._check(s, x) - - def test_one_operand_8(self): - r = bm.random.RandomState() - x = r.randn(3, 3) - s = 'ij->' - self._check(s, x) - - def test_one_operand_9(self): - r = bm.random.RandomState() - x = r.randn(3, 3, 3) - s = 'iii->' - self._check(s, x) - - def test_one_operand_10(self): - r = bm.random.RandomState() - x = r.randn(3, 3) - s = 'ii->i' - self._check(s, x) - - def test_one_operand_11(self): - r = bm.random.RandomState() - x = r.randn(3, 3, 4) - s = 'iij->i' - self._check(s, x) - - def test_one_operand_12(self): - r = bm.random.RandomState() - x = r.randn(3, 3, 3) - s = 'iii->i' - self._check(s, x) - - def test_one_operand_13(self): - r = bm.random.RandomState() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkk->i' - self._check(s, x) - - def test_one_operand_14(self): - r = bm.random.RandomState() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkk->ik' - self._check(s, x) - - def test_one_operand_15(self): - r = bm.random.RandomState() - x = r.randn(3, 3, 5, 4, 4) - s = 'iijkl->il' - self._check(s, x) - - def test_one_operand_16(self): - r = bm.random.RandomState() - x = r.randn(3, 3) - s = 'ij->ij' - self._check(s, x) - - def test_tf_unsupported_1(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = bm.random.RandomState() - x = r.randn(2, 3, 5, 1) - y = r.randn(3, 
4, 5, 1) - s = 'ij...,jk...->ik...' - self._check(s, x, y) - - def test_tf_unsupported_2(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = bm.random.RandomState() - x = r.randn(2, 3, 3) - y = r.randn(4) - s = 'ijj,k->ik' - self._check(s, x, y) - - def test_tf_unsupported_3(self): - # from https://www.tensorflow.org/api_docs/python/tf/einsum - r = bm.random.RandomState() - x = r.randn(2, 3) - y = r.randn(2, 3) - z = r.randn(3, 4) - s = 'ij,ij,jk->ik' - self._check(s, x, y, z) - - # these tests are based on https://github.com/dask/dask/pull/3412/files - @parameterized.named_parameters( - {"testcase_name": "_{}_dtype={}".format(einstr, dtype.__name__), "einstr": einstr, - "dtype": dtype} - for einstr in [ - 'abc,bad->abcd', - 'abcdef,bcdfg->abcdeg', - 'ea,fb,abcd,gc,hd->efgh', - 'ab,b', - 'aa', - 'a,a->', - 'a,a->a', - 'a,a', - 'a,b', - 'a,b,c', - 'a', - 'ba,b', - 'ba,b->', - 'defab,fedbc->defac', - 'ab...,bc...->ac...', - 'a...a', - 'abc...->cba...', - '...ab->...a', - 'a...a->a...', - # Following 2 from # https://stackoverflow.com/a/19203475/1611416 - '...abc,...abcd->...d', - 'ab...,b->ab...', - # https://github.com/dask/dask/pull/3412#discussion_r182413444 - 'aa->a', - 'ab,ab,c->c', - 'aab,bc->ac', - 'aab,bcc->ac', - 'fdf,cdd,ccd,afe->ae', - 'fff,fae,bef,def->abd', - ] - for dtype in [jnp.float32, jnp.int32, jnp.complex64, jnp.bool_]) - def test_from_dask(self, einstr, dtype): - r = jtu.rand_default(bm.random.RandomState()) - if '->' in einstr: - input_str, result_names = einstr.split('->') - else: - input_str = einstr - input_names = input_str.split(',') - - dims = itertools.cycle([2, 3, 4]) - shapes = defaultdict(lambda: next(dims)) - input_shapes = [tuple(shapes[c] for c in names.replace('...', '01')) - for names in input_names] - operands = [r(shape, dtype) for shape in input_shapes] - - self._check(einstr, *operands) - - def test_ordered_front_batch_dim_case(self): - x = np.ones((1, 8, 20, 4)) - y = np.ones((1, 8, 20, 4)) - s = 
'ijkl,ijml->ijkm' - self._check(s, x, y) - - def test_einsum_path(self): - # just check examples from np.einsum_path docstring - a = bm.random.RandomState().rand(2, 2) - b = bm.random.RandomState().rand(2, 5) - c = bm.random.RandomState().rand(5, 2) - - path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') - self.assertEqual(str(path_info[0]), "['einsum_path', (1, 2), (0, 1)]") - self.assertEqual(path_info[1].split('\n')[0], - ' Complete contraction: ij,jk,kl->il') - - # check this doesn't crash - I = bm.random.RandomState().rand(10, 10, 10, 10) - C = bm.random.RandomState().rand(10, 10) - np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, optimize='greedy') - - def test_einsum_kpmurphy_example(self): - # code from an email with @murphyk - N, C, D, K, T = 2, 3, 4, 5, 6 - r = np.random.RandomState(123) - S = r.randn(N, T, K) - W = r.randn(K, D) - V = r.randn(D, C) - L = np.zeros((N, C)) - for n in range(N): - for c in range(C): - s = 0 - for d in range(D): - for k in range(K): - for t in range(T): - s += S[n, t, k] * W[k, d] * V[d, c] - L[n, c] = s - - path = bm.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0] - rtol = 1e-2 if jtu.device_under_test() == "tpu" else None - self.assertAllClose(L, - bm.einsum('ntk,kd,dc->nc', S, W, V, optimize=path), - check_dtypes=False, rtol=rtol) - - def test_contraction_broadcasting(self): - r = bm.random.RandomState() - x = r.randn(3, 4, 5) - y = r.randn(3, 1, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - def test_batch_broadcasting(self): - r = bm.random.RandomState() - x = r.randn(1, 4, 5) - y = r.randn(3, 5, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - def test_batch_and_contraction_broadcasting(self): - r = bm.random.RandomState() - x = r.randn(1, 4, 5) - y = r.randn(3, 1, 6) - s = 'cij,cjk->cik' - self._check(s, x, y) - - def test_broadcasting_issue_2189(self): - r = bm.random.RandomState() - x = r.randn(2, 1, 3, 3) - y = r.randn(2, 4, 3) - s = '...ij,...j' - self._check(s, x, y) - 
- def test_no_unnecessary_transpose(self): - r = bm.random.RandomState() - x = r.randn(2, 2, 2) - y = r.randn(2, 2) - jaxpr = jax.make_jaxpr(partial(bm.einsum, "ijk,kl->ijl"))(x, y) - self.assertNotIn('transpose', str(jaxpr)) + def _check(self, s, *ops): + a = np.einsum(s, *ops) + b = bm.einsum(s, *ops, precision=lax.Precision.HIGHEST) + print(f'a.shape = {a.shape}') + print(f'b.shape = {b.shape}') + + self.assertTrue(bm) + self.assertAllClose(a, b, atol=1e-4, rtol=1e-4) + + def test_three_operands_1(self): + r = bm.random.RandomState() + x = r.randn(3) + y = r.randn(4) + z = r.randn(5) + s = 'i,j,k->ijk' + self._check(s, x, y, z) + + def test_three_operands_2(self): + r = bm.random.RandomState() + x = r.randn(3) + y = r.randn(4) + z = r.randn(5) + s = 'i,j,k->ijk' + self._check(s, x, y, z) + + def test_two_operands_1(self): + r = bm.random.RandomState() + x = r.randn(3, 4) + y = r.randn(4) + s = 'ij,j->i' + self._check(s, x, y) + + def test_two_operands_2(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 5) + y = r.randn(4) + s = 'ijk,j->i' + self._check(s, x, y) + + def test_two_operands_3(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 3) + y = r.randn(3) + s = 'iji,i->j' + self._check(s, x, y) + + def test_two_operands_4(self): + r = bm.random.RandomState() + x = r.randn(3, 4) + y = r.randn(3, 4) + s = 'ij,ij->' + self._check(s, x, y) + + def test_two_operands_5(self): + r = bm.random.RandomState() + x = r.randn(10, 2, 3) + y = r.randn(3, 4) + s = 'nij,jk->nik' + self._check(s, x, y) + + def test_two_operands_6(self): + # based on https://github.com/google/jax/issues/37#issuecomment-448572187 + r = bm.random.RandomState() + x = r.randn(2, 1) + y = r.randn(2, 3, 4) + s = 'sa,shb->shab' + self._check(s, x, y) + + def test_one_operand_1(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 5) + s = 'ijk->j' + self._check(s, x) + + def test_one_operand_2(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 5) + s = 'ijk->kij' + self._check(s, 
x) + + def test_one_operand_3(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 5) + s = 'ijk->ki' + self._check(s, x) + + def test_one_operand_4(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 5) + s = 'ijk->ki' + self._check(s, x) + + def test_one_operand_5(self): + r = bm.random.RandomState() + x = r.randn(2, 3, 4, 5) + s = '...ijk->...ki' + self._check(s, x) + + def test_one_operand_6(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 5) + s = '...ijk->ki' + self._check(s, x) + + def test_one_operand_7(self): + r = bm.random.RandomState() + x = r.randn(3, 3) + s = 'ii->' + self._check(s, x) + + def test_one_operand_8(self): + r = bm.random.RandomState() + x = r.randn(3, 3) + s = 'ij->' + self._check(s, x) + + def test_one_operand_9(self): + r = bm.random.RandomState() + x = r.randn(3, 3, 3) + s = 'iii->' + self._check(s, x) + + def test_one_operand_10(self): + r = bm.random.RandomState() + x = r.randn(3, 3) + s = 'ii->i' + self._check(s, x) + + def test_one_operand_11(self): + r = bm.random.RandomState() + x = r.randn(3, 3, 4) + s = 'iij->i' + self._check(s, x) + + def test_one_operand_12(self): + r = bm.random.RandomState() + x = r.randn(3, 3, 3) + s = 'iii->i' + self._check(s, x) + + def test_one_operand_13(self): + r = bm.random.RandomState() + x = r.randn(3, 3, 5, 4, 4) + s = 'iijkk->i' + self._check(s, x) + + def test_one_operand_14(self): + r = bm.random.RandomState() + x = r.randn(3, 3, 5, 4, 4) + s = 'iijkk->ik' + self._check(s, x) + + def test_one_operand_15(self): + r = bm.random.RandomState() + x = r.randn(3, 3, 5, 4, 4) + s = 'iijkl->il' + self._check(s, x) + + def test_one_operand_16(self): + r = bm.random.RandomState() + x = r.randn(3, 3) + s = 'ij->ij' + self._check(s, x) + + def test_tf_unsupported_1(self): + # from https://www.tensorflow.org/api_docs/python/tf/einsum + r = bm.random.RandomState() + x = r.randn(2, 3, 5, 1) + y = r.randn(3, 4, 5, 1) + s = 'ij...,jk...->ik...' 
+ self._check(s, x, y) + + def test_tf_unsupported_2(self): + # from https://www.tensorflow.org/api_docs/python/tf/einsum + r = bm.random.RandomState() + x = r.randn(2, 3, 3) + y = r.randn(4) + s = 'ijj,k->ik' + self._check(s, x, y) + + def test_tf_unsupported_3(self): + # from https://www.tensorflow.org/api_docs/python/tf/einsum + r = bm.random.RandomState() + x = r.randn(2, 3) + y = r.randn(2, 3) + z = r.randn(3, 4) + s = 'ij,ij,jk->ik' + self._check(s, x, y, z) + + # these tests are based on https://github.com/dask/dask/pull/3412/files + @parameterized.named_parameters( + {"testcase_name": "_{}_dtype={}".format(einstr, dtype.__name__), "einstr": einstr, + "dtype": dtype} + for einstr in [ + 'abc,bad->abcd', + 'abcdef,bcdfg->abcdeg', + 'ea,fb,abcd,gc,hd->efgh', + 'ab,b', + 'aa', + 'a,a->', + 'a,a->a', + 'a,a', + 'a,b', + 'a,b,c', + 'a', + 'ba,b', + 'ba,b->', + 'defab,fedbc->defac', + 'ab...,bc...->ac...', + 'a...a', + 'abc...->cba...', + '...ab->...a', + 'a...a->a...', + # Following 2 from # https://stackoverflow.com/a/19203475/1611416 + '...abc,...abcd->...d', + 'ab...,b->ab...', + # https://github.com/dask/dask/pull/3412#discussion_r182413444 + 'aa->a', + 'ab,ab,c->c', + 'aab,bc->ac', + 'aab,bcc->ac', + 'fdf,cdd,ccd,afe->ae', + 'fff,fae,bef,def->abd', + ] + for dtype in [jnp.float32, jnp.int32, jnp.complex64, jnp.bool_]) + def test_from_dask(self, einstr, dtype): + r = jtu.rand_default(bm.random.RandomState()) + if '->' in einstr: + input_str, result_names = einstr.split('->') + else: + input_str = einstr + input_names = input_str.split(',') + + dims = itertools.cycle([2, 3, 4]) + shapes = defaultdict(lambda: next(dims)) + input_shapes = [tuple(shapes[c] for c in names.replace('...', '01')) + for names in input_names] + operands = [r(shape, dtype) for shape in input_shapes] + + self._check(einstr, *operands) + + def test_ordered_front_batch_dim_case(self): + x = np.ones((1, 8, 20, 4)) + y = np.ones((1, 8, 20, 4)) + s = 'ijkl,ijml->ijkm' + self._check(s, x, y) + 
+ def test_einsum_path(self): + # just check examples from np.einsum_path docstring + a = bm.random.RandomState().rand(2, 2) + b = bm.random.RandomState().rand(2, 5) + c = bm.random.RandomState().rand(5, 2) + + path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy') + self.assertEqual(str(path_info[0]), "['einsum_path', (1, 2), (0, 1)]") + self.assertEqual(path_info[1].split('\n')[0], + ' Complete contraction: ij,jk,kl->il') + + # check this doesn't crash + I = bm.random.RandomState().rand(10, 10, 10, 10) + C = bm.random.RandomState().rand(10, 10) + np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, optimize='greedy') + + def test_einsum_kpmurphy_example(self): + # code from an email with @murphyk + N, C, D, K, T = 2, 3, 4, 5, 6 + r = np.random.RandomState(123) + S = r.randn(N, T, K) + W = r.randn(K, D) + V = r.randn(D, C) + L = np.zeros((N, C)) + for n in range(N): + for c in range(C): + s = 0 + for d in range(D): + for k in range(K): + for t in range(T): + s += S[n, t, k] * W[k, d] * V[d, c] + L[n, c] = s + + path = bm.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0] + rtol = 1e-2 if jtu.device_under_test() == "tpu" else None + self.assertAllClose(L, + bm.einsum('ntk,kd,dc->nc', S, W, V, optimize=path), + check_dtypes=False, rtol=rtol) + + def test_contraction_broadcasting(self): + r = bm.random.RandomState() + x = r.randn(3, 4, 5) + y = r.randn(3, 1, 6) + s = 'cij,cjk->cik' + self._check(s, x, y) + + def test_batch_broadcasting(self): + r = bm.random.RandomState() + x = r.randn(1, 4, 5) + y = r.randn(3, 5, 6) + s = 'cij,cjk->cik' + self._check(s, x, y) + + def test_batch_and_contraction_broadcasting(self): + r = bm.random.RandomState() + x = r.randn(1, 4, 5) + y = r.randn(3, 1, 6) + s = 'cij,cjk->cik' + self._check(s, x, y) + + def test_broadcasting_issue_2189(self): + r = bm.random.RandomState() + x = r.randn(2, 1, 3, 3) + y = r.randn(2, 4, 3) + s = '...ij,...j' + self._check(s, x, y) + + def test_no_unnecessary_transpose(self): 
+ r = bm.random.RandomState() + x = r.randn(2, 2, 2) + y = r.randn(2, 2) + jaxpr = jax.make_jaxpr(partial(bm.einsum, "ijk,kl->ijl"))(x, y) + self.assertNotIn('transpose', str(jaxpr)) diff --git a/brainpy/_src/math/tests/test_numpy_indexing.py b/brainpy/_src/math/tests/test_numpy_indexing.py index f6832dd58..6861bcfec 100644 --- a/brainpy/_src/math/tests/test_numpy_indexing.py +++ b/brainpy/_src/math/tests/test_numpy_indexing.py @@ -16,6 +16,7 @@ import pytest + pytest.skip("No need to test.", allow_module_level=True) import enum @@ -29,7 +30,6 @@ import jax import numpy as np from absl.testing import parameterized -from jax import lax from jax import numpy as jnp from jax._src import dtypes from jax._src import test_util as jtu @@ -54,1233 +54,1235 @@ class IndexSpec(typing.NamedTuple): - shape: Tuple[int, ...] - indexer: Any - out_shape: Optional[Tuple[int, ...]] = None + shape: Tuple[int, ...] + indexer: Any + out_shape: Optional[Tuple[int, ...]] = None def check_grads(f, args, order, atol=None, rtol=None, eps=None): - # TODO(mattjj,dougalm): add higher-order check - default_tol = 1e-6 if config.x64_enabled else 1e-2 - atol = atol or default_tol - rtol = rtol or default_tol - eps = eps or default_tol - jtu.check_jvp(f, partial(jax.jvp, f), args, atol, rtol, eps) - jtu.check_vjp(f, partial(jax.vjp, f), args, atol, rtol, eps) + # TODO(mattjj,dougalm): add higher-order check + default_tol = 1e-6 if config.x64_enabled else 1e-2 + atol = atol or default_tol + rtol = rtol or default_tol + eps = eps or default_tol + jtu.check_jvp(f, partial(jax.jvp, f), args, atol, rtol, eps) + jtu.check_vjp(f, partial(jax.vjp, f), args, atol, rtol, eps) STATIC_INDEXING_TESTS = [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=1, out_shape=()), - IndexSpec(shape=(3, 3), indexer=0, out_shape=(3,)), - IndexSpec(shape=(3, 4, 5), indexer=2, out_shape=(4, 5)), - IndexSpec(shape=(3,), indexer=-1, out_shape=()), - IndexSpec(shape=(3,), indexer=-2, out_shape=()), - ]), - ("TwoIntIndices", [ 
- IndexSpec(shape=(3, 3), indexer=(2, 1), out_shape=()), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2), out_shape=(5,)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2), out_shape=(5,)), - ]), - ("ThreeIntIndices", [ - IndexSpec(shape=(3, 4, 5), indexer=(1, 2, 3), out_shape=()), - ]), - ("OneSliceIndex", [ - IndexSpec(shape=(10,), indexer=slice(1, 3), out_shape=(2,)), - IndexSpec(shape=(10,), indexer=slice(1, -1), out_shape=(8,)), - IndexSpec(shape=(10,), indexer=slice(None, -1), out_shape=(9,)), - IndexSpec(shape=(10,), indexer=slice(None, None, None), out_shape=(10,)), - IndexSpec(shape=(10, 8), indexer=slice(1, 3), out_shape=(2, 8)), - IndexSpec(shape=(10, 8), indexer=slice(1, None), out_shape=(9, 8)), - IndexSpec(shape=(10, 8), indexer=slice(None, 3), out_shape=(3, 8)), - IndexSpec(shape=(10, 8), indexer=slice(-3, None), out_shape=(3, 8)), - ]), - ("OneSliceIndexNegativeStride", [ - IndexSpec(shape=(10,), indexer=slice(3, 1, -1), out_shape=(2,)), - IndexSpec(shape=(10,), indexer=slice(1, 8, -1), out_shape=(0,)), - IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)), - IndexSpec(shape=(10,), indexer=slice(None, None, -1), out_shape=(10,)), - IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1), out_shape=(2, 8)), - IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1), out_shape=(0, 8)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, -1), out_shape=(10, 8)), - ]), - ("OneSliceIndexNonUnitStride", [ - IndexSpec(shape=(10,), indexer=slice(0, 8, 2), out_shape=(4,)), - IndexSpec(shape=(10,), indexer=slice(0, 8, 3), out_shape=(3,)), - IndexSpec(shape=(10,), indexer=slice(1, 3, 2), out_shape=(1,)), - IndexSpec(shape=(10,), indexer=slice(1, None, 2), out_shape=(5,)), - IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)), - IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3), out_shape=(3, 8)), - IndexSpec(shape=(10, 8), indexer=slice(None, None, 2), out_shape=(5, 8)), - IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2), out_shape=(4, 8)), - 
IndexSpec(shape=(10, 8), indexer=slice(None, None, -2), out_shape=(5, 8)), - ]), - ("TwoSliceIndices", [ - IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2)), - out_shape=(2, 2)), - IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2)), - out_shape=(9, 2)), - IndexSpec(shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2)), - out_shape=(10, 2)), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2)), - out_shape=(2, 2, 3)), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None)), - out_shape=(2, 8, 3)), - IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2)), - out_shape=(9, 2, 3)), - ]), - ("OneColonIndex", [ - IndexSpec(shape=(3,), indexer=slice(None), out_shape=(3,)), - IndexSpec(shape=(3, 4), indexer=slice(None), out_shape=(3, 4)), - ]), - ("MultipleColonIndices", [ - IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None)), - out_shape=(3, 4)), - IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None)), - out_shape=(3, 4, 5)), - ]), - ("MixedSliceIndices", [ - IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 2)), - out_shape=(10, 2)), - IndexSpec(shape=(10, 4), indexer=(1, slice(None)), - out_shape=(4,)), - ]), - ("EllipsisIndex", [ - IndexSpec(shape=(3,), indexer=Ellipsis, out_shape=(3,)), - IndexSpec(shape=(3, 4), indexer=Ellipsis, out_shape=(3, 4)), - IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis), out_shape=(4, 5)), - IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3), out_shape=(3,)), - ]), - ("NoneIndex", [ - IndexSpec(shape=(), indexer=None, out_shape=(1,)), - IndexSpec(shape=(), indexer=(None, None), out_shape=(1, 1)), - IndexSpec(shape=(), indexer=(Ellipsis, None), out_shape=(1,)), - IndexSpec(shape=(3,), indexer=None, out_shape=(1, 3)), - IndexSpec(shape=(3, 4), indexer=None, out_shape=(1, 3, 4)), - IndexSpec(shape=(3, 4), indexer=(Ellipsis, None), out_shape=(3, 4, 1)), - IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis), out_shape=(1, 4)), - IndexSpec(shape=(3, 4, 5), 
indexer=(1, None, Ellipsis), out_shape=(1, 4, 5)), - ]), - ("EmptyIndex", [ - IndexSpec(shape=(), indexer=(), out_shape=()), - IndexSpec(shape=(3,), indexer=(), out_shape=(3,)), - IndexSpec(shape=(3, 4), indexer=(), out_shape=(3, 4)), - ]), - ("TupleOfIntAndSliceAndIntArray", [ - IndexSpec(shape=(3, 2, 3), indexer=(0, slice(None), np.arange(3)), - out_shape=(3, 2)), - IndexSpec(shape=(3, 2, 3), indexer=(np.int32(1), slice(None), np.arange(3)), - out_shape=(3, 2)), - IndexSpec(shape=(3, 2, 3), indexer=(np.array(2), slice(None), np.arange(3)), - out_shape=(3, 2)), - ]), + ("OneIntIndex", [ + IndexSpec(shape=(3,), indexer=1, out_shape=()), + IndexSpec(shape=(3, 3), indexer=0, out_shape=(3,)), + IndexSpec(shape=(3, 4, 5), indexer=2, out_shape=(4, 5)), + IndexSpec(shape=(3,), indexer=-1, out_shape=()), + IndexSpec(shape=(3,), indexer=-2, out_shape=()), + ]), + ("TwoIntIndices", [ + IndexSpec(shape=(3, 3), indexer=(2, 1), out_shape=()), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2), out_shape=(5,)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2), out_shape=(5,)), + ]), + ("ThreeIntIndices", [ + IndexSpec(shape=(3, 4, 5), indexer=(1, 2, 3), out_shape=()), + ]), + ("OneSliceIndex", [ + IndexSpec(shape=(10,), indexer=slice(1, 3), out_shape=(2,)), + IndexSpec(shape=(10,), indexer=slice(1, -1), out_shape=(8,)), + IndexSpec(shape=(10,), indexer=slice(None, -1), out_shape=(9,)), + IndexSpec(shape=(10,), indexer=slice(None, None, None), out_shape=(10,)), + IndexSpec(shape=(10, 8), indexer=slice(1, 3), out_shape=(2, 8)), + IndexSpec(shape=(10, 8), indexer=slice(1, None), out_shape=(9, 8)), + IndexSpec(shape=(10, 8), indexer=slice(None, 3), out_shape=(3, 8)), + IndexSpec(shape=(10, 8), indexer=slice(-3, None), out_shape=(3, 8)), + ]), + ("OneSliceIndexNegativeStride", [ + IndexSpec(shape=(10,), indexer=slice(3, 1, -1), out_shape=(2,)), + IndexSpec(shape=(10,), indexer=slice(1, 8, -1), out_shape=(0,)), + IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)), + 
IndexSpec(shape=(10,), indexer=slice(None, None, -1), out_shape=(10,)), + IndexSpec(shape=(10, 8), indexer=slice(3, 1, -1), out_shape=(2, 8)), + IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1), out_shape=(0, 8)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, -1), out_shape=(10, 8)), + ]), + ("OneSliceIndexNonUnitStride", [ + IndexSpec(shape=(10,), indexer=slice(0, 8, 2), out_shape=(4,)), + IndexSpec(shape=(10,), indexer=slice(0, 8, 3), out_shape=(3,)), + IndexSpec(shape=(10,), indexer=slice(1, 3, 2), out_shape=(1,)), + IndexSpec(shape=(10,), indexer=slice(1, None, 2), out_shape=(5,)), + IndexSpec(shape=(10,), indexer=slice(None, 1, -2), out_shape=(4,)), + IndexSpec(shape=(10, 8), indexer=slice(1, 8, 3), out_shape=(3, 8)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, 2), out_shape=(5, 8)), + IndexSpec(shape=(10, 8), indexer=slice(None, 1, -2), out_shape=(4, 8)), + IndexSpec(shape=(10, 8), indexer=slice(None, None, -2), out_shape=(5, 8)), + ]), + ("TwoSliceIndices", [ + IndexSpec(shape=(10, 8), indexer=(slice(1, 3), slice(0, 2)), + out_shape=(2, 2)), + IndexSpec(shape=(10, 8), indexer=(slice(1, None), slice(None, 2)), + out_shape=(9, 2)), + IndexSpec(shape=(10, 8), indexer=(slice(None, None, -1), slice(None, 2)), + out_shape=(10, 2)), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, 2)), + out_shape=(2, 2, 3)), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, 3), slice(0, None)), + out_shape=(2, 8, 3)), + IndexSpec(shape=(10, 8, 3), indexer=(slice(1, None), slice(0, 2)), + out_shape=(9, 2, 3)), + ]), + ("OneColonIndex", [ + IndexSpec(shape=(3,), indexer=slice(None), out_shape=(3,)), + IndexSpec(shape=(3, 4), indexer=slice(None), out_shape=(3, 4)), + ]), + ("MultipleColonIndices", [ + IndexSpec(shape=(3, 4), indexer=(slice(None), slice(None)), + out_shape=(3, 4)), + IndexSpec(shape=(3, 4, 5), indexer=(slice(None), slice(None)), + out_shape=(3, 4, 5)), + ]), + ("MixedSliceIndices", [ + IndexSpec(shape=(10, 4), indexer=(slice(None), slice(0, 
2)), + out_shape=(10, 2)), + IndexSpec(shape=(10, 4), indexer=(1, slice(None)), + out_shape=(4,)), + ]), + ("EllipsisIndex", [ + IndexSpec(shape=(3,), indexer=Ellipsis, out_shape=(3,)), + IndexSpec(shape=(3, 4), indexer=Ellipsis, out_shape=(3, 4)), + IndexSpec(shape=(3, 4, 5), indexer=(0, Ellipsis), out_shape=(4, 5)), + IndexSpec(shape=(3, 4, 5), indexer=(Ellipsis, 2, 3), out_shape=(3,)), + ]), + ("NoneIndex", [ + IndexSpec(shape=(), indexer=None, out_shape=(1,)), + IndexSpec(shape=(), indexer=(None, None), out_shape=(1, 1)), + IndexSpec(shape=(), indexer=(Ellipsis, None), out_shape=(1,)), + IndexSpec(shape=(3,), indexer=None, out_shape=(1, 3)), + IndexSpec(shape=(3, 4), indexer=None, out_shape=(1, 3, 4)), + IndexSpec(shape=(3, 4), indexer=(Ellipsis, None), out_shape=(3, 4, 1)), + IndexSpec(shape=(3, 4), indexer=(0, None, Ellipsis), out_shape=(1, 4)), + IndexSpec(shape=(3, 4, 5), indexer=(1, None, Ellipsis), out_shape=(1, 4, 5)), + ]), + ("EmptyIndex", [ + IndexSpec(shape=(), indexer=(), out_shape=()), + IndexSpec(shape=(3,), indexer=(), out_shape=(3,)), + IndexSpec(shape=(3, 4), indexer=(), out_shape=(3, 4)), + ]), + ("TupleOfIntAndSliceAndIntArray", [ + IndexSpec(shape=(3, 2, 3), indexer=(0, slice(None), np.arange(3)), + out_shape=(3, 2)), + IndexSpec(shape=(3, 2, 3), indexer=(np.int32(1), slice(None), np.arange(3)), + out_shape=(3, 2)), + IndexSpec(shape=(3, 2, 3), indexer=(np.array(2), slice(None), np.arange(3)), + out_shape=(3, 2)), + ]), ] STATIC_INDEXING_OUT_OF_BOUNDS_TESTS = [ - ("OneIntIndex", [ - IndexSpec(shape=(3,), indexer=-4, out_shape=()), - IndexSpec(shape=(3, 3), indexer=3, out_shape=(3,)), - IndexSpec(shape=(3, 4, 5), indexer=4, out_shape=(4, 5)), - ]), - ("TwoIntIndices", [ - IndexSpec(shape=(3, 3), indexer=(2, -4), out_shape=()), - IndexSpec(shape=(3, 4, 5), indexer=(3, 2), out_shape=()), - IndexSpec(shape=(3, 4, 5), indexer=(-4, 4), out_shape=(5,)), - ]), + ("OneIntIndex", [ + IndexSpec(shape=(3,), indexer=-4, out_shape=()), + 
IndexSpec(shape=(3, 3), indexer=3, out_shape=(3,)), + IndexSpec(shape=(3, 4, 5), indexer=4, out_shape=(4, 5)), + ]), + ("TwoIntIndices", [ + IndexSpec(shape=(3, 3), indexer=(2, -4), out_shape=()), + IndexSpec(shape=(3, 4, 5), indexer=(3, 2), out_shape=()), + IndexSpec(shape=(3, 4, 5), indexer=(-4, 4), out_shape=(5,)), + ]), ] ADVANCED_INDEXING_TESTS = [ - ("One1DIntArrayIndex", [ - IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), - IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1]), out_shape=(3, 3)), - IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1]), - out_shape=(4, 4, 5)), - IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), - IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), - IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), - out_shape=(0,)), - ]), - ("One2DIntArrayIndex", [ - IndexSpec(shape=(3,), indexer=np.array([[0, 0]]), out_shape=(1, 2)), - IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1], [0, 1, -1]]), - out_shape=(2, 3, 3)), - IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1], [-1, -2, 1, 0]]), - out_shape=(2, 4, 4, 5)), - ]), - ("Two1DIntArrayIndicesNoBroadcasting", [ - IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), - out_shape=(2,)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2, 0, 1]), np.array([-1, 0, -1, 2])), - out_shape=(4, 5)), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", [ - IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), - out_shape=(1, 2)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([[0, 2, 0, 1]]), np.array([-1, 0, -1, 2])), - out_shape=(1, 4, 5)), - ]), - ("ArrayOfInts", [ - IndexSpec(shape=(3,), indexer=np.array([0, 1, 0]), out_shape=(3,)), - IndexSpec(shape=(3, 4, 5), indexer=np.array([0, -1]), out_shape=(2, 4, 5)), - ]), - ("TupleOfListsOfPythonInts", [ - IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]]), - 
out_shape=(2, 4, 5)), - ]), - ("TupleOfPythonIntsAndIntArrays", [ - IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0, 3]])), - out_shape=(1, 4)), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", [ - IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), - out_shape=(2, 5)), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0, 3]])), - out_shape=(2, 4, 5)), - ]), + ("One1DIntArrayIndex", [ + IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), + IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1]), out_shape=(3, 3)), + IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1]), + out_shape=(4, 4, 5)), + IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), + IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), + IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), + out_shape=(0,)), + ]), + ("One2DIntArrayIndex", [ + IndexSpec(shape=(3,), indexer=np.array([[0, 0]]), out_shape=(1, 2)), + IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1], [0, 1, -1]]), + out_shape=(2, 3, 3)), + IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1], [-1, -2, 1, 0]]), + out_shape=(2, 4, 4, 5)), + ]), + ("Two1DIntArrayIndicesNoBroadcasting", [ + IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), + out_shape=(2,)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2, 0, 1]), np.array([-1, 0, -1, 2])), + out_shape=(4, 5)), + ]), + ("Two1DIntArrayIndicesWithBroadcasting", [ + IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), + out_shape=(1, 2)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([[0, 2, 0, 1]]), np.array([-1, 0, -1, 2])), + out_shape=(1, 4, 5)), + ]), + ("ArrayOfInts", [ + IndexSpec(shape=(3,), indexer=np.array([0, 1, 0]), out_shape=(3,)), + IndexSpec(shape=(3, 4, 5), indexer=np.array([0, -1]), out_shape=(2, 4, 5)), + ]), + ("TupleOfListsOfPythonInts", [ + 
IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0, 3]]), + out_shape=(2, 4, 5)), + ]), + ("TupleOfPythonIntsAndIntArrays", [ + IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0, 3]])), + out_shape=(1, 4)), + ]), + ("TupleOfListsOfPythonIntsAndIntArrays", [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), + out_shape=(2, 5)), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0, 3]])), + out_shape=(2, 4, 5)), + ]), ] ADVANCED_INDEXING_TESTS_NO_REPEATS = [ - ("One1DIntArrayIndex", [ - IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), - IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 0]), out_shape=(3, 3)), - IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 1]), - out_shape=(3, 4, 5)), - IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), - IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), - IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)), - ]), - ("One2DIntArrayIndex", [ - IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)), - IndexSpec(shape=(6, 6), indexer=np.array([[1, 2, 0], [3, 4, -1]]), - out_shape=(2, 3, 6)), - ]), - ("Two1DIntArrayIndicesNoBroadcasting", [ - IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), out_shape=(2,)), - IndexSpec(shape=(4, 5, 6), indexer=(np.array([0, 2, 1, 3]), np.array([-1, 0, -2, 1])), - out_shape=(4, 6)), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", [ - IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), out_shape=(1, 2)), - IndexSpec(shape=(4, 5, 6), - indexer=(np.array([[0, 2, -1, 1]]), np.array([-1, 0, -2, 2])), out_shape=(1, 4, 6)), - ]), - ("ArrayOfInts", [ - IndexSpec(shape=(3,), indexer=np.array([0, 2, 1]), out_shape=(3,)), - IndexSpec(shape=(3, 4, 5), indexer=np.array([0, -1]), out_shape=(2, 4, 
5)), - ]), - ("TupleOfListsOfPythonInts", [ - IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]]), out_shape=(2, 3, 5)), - ]), - ("TupleOfPythonIntsAndIntArrays", [ - IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0]])), out_shape=(1, 3)), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", [ - IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), - out_shape=(2, 5)), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0]])), - out_shape=(2, 3, 5)), - ]), + ("One1DIntArrayIndex", [ + IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), + IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 0]), out_shape=(3, 3)), + IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 1]), + out_shape=(3, 4, 5)), + IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), + IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), + IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)), + ]), + ("One2DIntArrayIndex", [ + IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)), + IndexSpec(shape=(6, 6), indexer=np.array([[1, 2, 0], [3, 4, -1]]), + out_shape=(2, 3, 6)), + ]), + ("Two1DIntArrayIndicesNoBroadcasting", [ + IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), out_shape=(2,)), + IndexSpec(shape=(4, 5, 6), indexer=(np.array([0, 2, 1, 3]), np.array([-1, 0, -2, 1])), + out_shape=(4, 6)), + ]), + ("Two1DIntArrayIndicesWithBroadcasting", [ + IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), out_shape=(1, 2)), + IndexSpec(shape=(4, 5, 6), + indexer=(np.array([[0, 2, -1, 1]]), np.array([-1, 0, -2, 2])), out_shape=(1, 4, 6)), + ]), + ("ArrayOfInts", [ + IndexSpec(shape=(3,), indexer=np.array([0, 2, 1]), out_shape=(3,)), + IndexSpec(shape=(3, 4, 5), indexer=np.array([0, -1]), out_shape=(2, 4, 5)), + 
]), + ("TupleOfListsOfPythonInts", [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[2, 3, 0]]), out_shape=(2, 3, 5)), + ]), + ("TupleOfPythonIntsAndIntArrays", [ + IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[2, 3, 0]])), out_shape=(1, 3)), + ]), + ("TupleOfListsOfPythonIntsAndIntArrays", [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), + out_shape=(2, 5)), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[2, 3, 0]])), + out_shape=(2, 3, 5)), + ]), ] ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED = [ - ("One1DIntArrayIndex", [ - IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), - IndexSpec(shape=(3, 3), indexer=np.array([0, 1, 2]), out_shape=(3, 3)), - IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 1, 2]), out_shape=(3, 4, 5)), - IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), - IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), - IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)), - ]), - ("One2DIntArrayIndex", [ - IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)), - IndexSpec(shape=(6, 6), indexer=np.array([[-1, 0, 1], - [2, 3, 4]]), out_shape=(2, 3, 6)), - ]), - ("Two1DIntArrayIndicesNoBroadcasting", [ - IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), - out_shape=(2,)), - IndexSpec(shape=(4, 5, 6), - indexer=(np.array([0, 1, 2, 3]), np.array([-2, -1, 0, 1])), - out_shape=(4, 6)), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", [ - IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), - out_shape=(1, 2)), - IndexSpec(shape=(4, 5, 6), - indexer=(np.array([[-1, 0, 1, 2]]), np.array([-2, -1, 0, 2])), - out_shape=(1, 4, 6)), - ]), - ("TupleOfListsOfPythonInts", [ - IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), - IndexSpec(shape=(3, 
4, 5), indexer=([[0], [-1]], [[0, 2, 3]]), - out_shape=(2, 3, 5)), - ]), - ("TupleOfPythonIntsAndIntArrays", [ - IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), out_shape=(2, 5)), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[0, 2, 3]])), - out_shape=(1, 3)), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", [ - IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), - out_shape=(2, 5)), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[0, 2, 3]])), - out_shape=(2, 3, 5)), - ]), + ("One1DIntArrayIndex", [ + IndexSpec(shape=(3,), indexer=np.array([0, 1]), out_shape=(2,)), + IndexSpec(shape=(3, 3), indexer=np.array([0, 1, 2]), out_shape=(3, 3)), + IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 1, 2]), out_shape=(3, 4, 5)), + IndexSpec(shape=(3,), indexer=np.array([-1, 1]), out_shape=(2,)), + IndexSpec(shape=(3,), indexer=np.array([-2, -1]), out_shape=(2,)), + IndexSpec(shape=(0,), indexer=np.array([], dtype=np.int32), out_shape=(0,)), + ]), + ("One2DIntArrayIndex", [ + IndexSpec(shape=(3,), indexer=np.array([[0, 1]]), out_shape=(1, 2)), + IndexSpec(shape=(6, 6), indexer=np.array([[-1, 0, 1], + [2, 3, 4]]), out_shape=(2, 3, 6)), + ]), + ("Two1DIntArrayIndicesNoBroadcasting", [ + IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), np.array([1, 2])), + out_shape=(2,)), + IndexSpec(shape=(4, 5, 6), + indexer=(np.array([0, 1, 2, 3]), np.array([-2, -1, 0, 1])), + out_shape=(4, 6)), + ]), + ("Two1DIntArrayIndicesWithBroadcasting", [ + IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), np.array([1, 2])), + out_shape=(1, 2)), + IndexSpec(shape=(4, 5, 6), + indexer=(np.array([[-1, 0, 1, 2]]), np.array([-2, -1, 0, 2])), + out_shape=(1, 4, 6)), + ]), + ("TupleOfListsOfPythonInts", [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1],), out_shape=(2, 4, 5)), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], [[0, 2, 3]]), + out_shape=(2, 3, 5)), + ]), + ("TupleOfPythonIntsAndIntArrays", [ + IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1])), 
out_shape=(2, 5)), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, np.array([[0, 2, 3]])), + out_shape=(1, 3)), + ]), + ("TupleOfListsOfPythonIntsAndIntArrays", [ + IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0])), + out_shape=(2, 5)), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], np.array([[0, 2, 3]])), + out_shape=(2, 3, 5)), + ]), ] MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS = [ - ("SlicesAndOneIntArrayIndex", [ - IndexSpec(shape=(2, 3), indexer=(np.array([0, 1]), slice(1, 2)), - out_shape=(2, 1)), - IndexSpec(shape=(2, 3), indexer=(slice(0, 2), np.array([0, 2])), - out_shape=(2, 2)), - IndexSpec(shape=(3, 4, 5), - indexer=(Ellipsis, np.array([0, 2]), slice(None)), - out_shape=(3, 2, 5)), - IndexSpec(shape=(3, 4, 5), - indexer=(Ellipsis, np.array([[0, 2], [1, 3]]), slice(None)), - out_shape=(3, 2, 2, 5)), - ]), - ("SlicesAndTwoIntArrayIndices", [ - IndexSpec(shape=(3, 4, 5), - indexer=(Ellipsis, np.array([0, 2]), np.array([-1, 2])), - out_shape=(3, 2)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2]), Ellipsis, np.array([-1, 2])), - out_shape=(2, 4)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2]), np.array([-1, 2]), Ellipsis), - out_shape=(2, 5)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2]), np.array([-1, 2]), slice(1, 3)), - out_shape=(2, 2)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2]), slice(1, 3), np.array([-1, 2])), - out_shape=(2, 2)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2, -2]), slice(None, None, 2), - np.array([-1, 2, 1])), - out_shape=(3, 2)), - ]), - ("NonesAndIntArrayIndices", [ - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2]), None, np.array([-1, 2])), - out_shape=(2, 1, 5)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2]), None, None, np.array([-1, 2])), - out_shape=(2, 1, 1, 5)), - IndexSpec(shape=(3, 4, 5), - indexer=(Ellipsis, np.array([0, 2]), None, None, - np.array([-1, 2])), - out_shape=(2, 3, 1, 1)), - ]), - ("IntArrayWithInt32Type", [ - 
IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)), - out_shape=(3,)), - ]), + ("SlicesAndOneIntArrayIndex", [ + IndexSpec(shape=(2, 3), indexer=(np.array([0, 1]), slice(1, 2)), + out_shape=(2, 1)), + IndexSpec(shape=(2, 3), indexer=(slice(0, 2), np.array([0, 2])), + out_shape=(2, 2)), + IndexSpec(shape=(3, 4, 5), + indexer=(Ellipsis, np.array([0, 2]), slice(None)), + out_shape=(3, 2, 5)), + IndexSpec(shape=(3, 4, 5), + indexer=(Ellipsis, np.array([[0, 2], [1, 3]]), slice(None)), + out_shape=(3, 2, 2, 5)), + ]), + ("SlicesAndTwoIntArrayIndices", [ + IndexSpec(shape=(3, 4, 5), + indexer=(Ellipsis, np.array([0, 2]), np.array([-1, 2])), + out_shape=(3, 2)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2]), Ellipsis, np.array([-1, 2])), + out_shape=(2, 4)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2]), np.array([-1, 2]), Ellipsis), + out_shape=(2, 5)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2]), np.array([-1, 2]), slice(1, 3)), + out_shape=(2, 2)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2]), slice(1, 3), np.array([-1, 2])), + out_shape=(2, 2)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2, -2]), slice(None, None, 2), + np.array([-1, 2, 1])), + out_shape=(3, 2)), + ]), + ("NonesAndIntArrayIndices", [ + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2]), None, np.array([-1, 2])), + out_shape=(2, 1, 5)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2]), None, None, np.array([-1, 2])), + out_shape=(2, 1, 1, 5)), + IndexSpec(shape=(3, 4, 5), + indexer=(Ellipsis, np.array([0, 2]), None, None, + np.array([-1, 2])), + out_shape=(2, 3, 1, 1)), + ]), + ("IntArrayWithInt32Type", [ + IndexSpec(shape=(3, 4), indexer=(Ellipsis, np.array(1, dtype=np.int32)), + out_shape=(3,)), + ]), ] MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [ - ("SlicesAndOneIntArrayIndex", [ - IndexSpec(shape=(3, 4, 5), - indexer=(Ellipsis, np.array([[0, 2], [1, 1]]), slice(None)), - 
out_shape=(3, 2, 2, 5)), - ]), - ("SlicesAndTwoIntArrayIndices", [ - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([0, 2, -2]), slice(None, None, 2), - np.array([-1, 2, -1])), - out_shape=(3, 2)), - IndexSpec(shape=(3, 4, 5), - indexer=(np.array([[0, 2], [2, 0]]), Ellipsis, - np.array([[1, 0], [1, 0]])), - out_shape=(2, 2, 4)), - ]), + ("SlicesAndOneIntArrayIndex", [ + IndexSpec(shape=(3, 4, 5), + indexer=(Ellipsis, np.array([[0, 2], [1, 1]]), slice(None)), + out_shape=(3, 2, 2, 5)), + ]), + ("SlicesAndTwoIntArrayIndices", [ + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([0, 2, -2]), slice(None, None, 2), + np.array([-1, 2, -1])), + out_shape=(3, 2)), + IndexSpec(shape=(3, 4, 5), + indexer=(np.array([[0, 2], [2, 0]]), Ellipsis, + np.array([[1, 0], [1, 0]])), + out_shape=(2, 2, 4)), + ]), ] MODES = ["clip", "drop", "promise_in_bounds"] + @pytest.mark.skipif(True, reason="No longer need to test.") class IndexingTest(jtu.JaxTestCase): - """Tests for Numpy indexing translation rules.""" - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_indexer={}".format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, - "dtype": dtype, - "indexer": indexer} - for name, index_specs in STATIC_INDEXING_TESTS - for shape, indexer, _ in index_specs - for dtype in all_dtypes) - ) - def testStaticIndexing(self, shape, dtype, indexer): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda x: np.asarray(x)[indexer] - jnp_fun = lambda x: bm.asarray(x)[indexer] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - # Tests x.at[...].get(...) as well. 
- jnp_fun = lambda x: bm.asarray(x).at[indexer].get() - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.named_parameters( - jtu.cases_from_list({"testcase_name": f"_{funcname}", "funcname": funcname} - for funcname in - ["negative", "sin", "cos", "square", "sqrt", "log", "exp"]) - ) - def testIndexApply(self, funcname, size=10, dtype='float32'): - if not hasattr(jnp.zeros(1).at[0], 'apply'): - self.skipTest('Has not apply() function') - - rng = jtu.rand_default(self.rng()) - idx_rng = jtu.rand_int(self.rng(), -size, size) - np_func = getattr(np, funcname) - jnp_func = getattr(jnp, funcname) - - @jtu.ignore_warning(category=RuntimeWarning) - def np_op(x, idx): - y = x.copy() - np_func.at(y, idx) - return y - - def jnp_op(x, idx): - return bm.asarray(x).at[idx].apply(jnp_func) - - args_maker = lambda: [rng(size, dtype), idx_rng(size, int)] - self._CheckAgainstNumpy(np_op, jnp_op, args_maker) - self._CompileAndCheck(jnp_op, args_maker) - - @parameterized.named_parameters({ - "testcase_name": - f"{jtu.format_shape_dtype_string(shape, dtype)}_inshape={name}" - f"_indexer={indexer}_mode={mode}", - "shape": shape, "dtype": dtype, "indexer": indexer, "mode": mode - } - for mode in MODES - for name, index_specs in ( - STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else - STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS) - for shape, indexer, _ in index_specs - for dtype in float_dtypes) - def testStaticIndexingGrads(self, shape, dtype, indexer, mode): - rng = jtu.rand_default(self.rng()) - tol = 1e-2 if bm.finfo(dtype).bits == 32 else None - arg = rng(shape, dtype) - # Use an arbitrary finite fill_value, since NaNs won't work in a numerical - # gradient test. 
- fun = lambda x: bm.asarray(x).at[indexer].get(mode=mode, fill_value=7) ** 2 - check_grads(fun, (arg,), 2, tol, tol, tol) - - def _ReplaceSlicesWithTuples(self, idx): - """Helper method to replace slices with tuples for dynamic indexing args.""" - if isinstance(idx, slice): - triple = idx.start, idx.stop, idx.step - isnone = [i for i, elt in enumerate(triple) if elt is None] - zeros = itertools.repeat(0) - nones = itertools.repeat(None) - out = util.subvals(triple, zip(isnone, zeros)) - return out, lambda out: slice(*util.subvals(out, zip(isnone, nones))) - elif isinstance(idx, (tuple, list)) and idx: - t = type(idx) - elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) - return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) - else: - return idx, lambda x: x - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "indexer": indexer} - for name, index_specs in [ - ("OneSliceIndex", - [IndexSpec(shape=(5,), indexer=slice(1, 3)), - IndexSpec(shape=(5, 4), indexer=slice(1, 3))]), - ("TwoSliceIndices", - [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]), - ("NonUnitStrides", [ - IndexSpec(shape=(3,), indexer=slice(None, None, -1)), - IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), - IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) - ]), - ("OnlyStartOrStopDynamic", [ - IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), - IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) - ]), - ] - for shape, indexer, _ in index_specs - for dtype in all_dtypes) - def testDynamicIndexingWithSlicesErrors(self, shape, dtype, indexer): - rng = jtu.rand_default(self.rng()) - unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) - - @jax.jit - def fun(x, unpacked_indexer): - indexer = 
pack_indexer(unpacked_indexer) - return x[indexer] - - args_maker = lambda: [rng(shape, dtype), unpacked_indexer] - self.assertRaises(IndexError, lambda: fun(*args_maker())) - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "indexer": indexer} - for name, index_specs in [ - ("OneIntIndex", - [IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2)]), - ("TwoIntIndices", - [IndexSpec(shape=(3, 3), indexer=(2, 1)), - IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]), - ("ThreeIntIndices", - [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ] - for shape, indexer, _ in index_specs - for dtype in all_dtypes) - def testDynamicIndexingWithIntegers(self, shape, dtype, indexer): - rng = jtu.rand_default(self.rng()) - unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) - - def np_fun(x, unpacked_indexer): - indexer = pack_indexer(unpacked_indexer) - return np.asarray(x)[indexer] - - def jnp_fun(x, unpacked_indexer): - indexer = pack_indexer(unpacked_indexer) - return bm.array(x)[indexer] - - args_maker = lambda: [rng(shape, dtype), unpacked_indexer] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "indexer": indexer} - for name, index_specs in [ - ("OneIntIndex", - [IndexSpec(shape=(3,), indexer=1), - IndexSpec(shape=(3, 3), indexer=0), - IndexSpec(shape=(3, 4, 5), indexer=2), - IndexSpec(shape=(3,), indexer=-1), - IndexSpec(shape=(3,), indexer=-2), - ]), - ("TwoIntIndices", - [IndexSpec(shape=(3, 3), indexer=(2, 1)), 
- IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), - IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), - ]), - ("ThreeIntIndices", - [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), - ] - for shape, indexer, _ in index_specs - for dtype in float_dtypes) - def testDynamicIndexingWithIntegersGrads(self, shape, dtype, indexer): - rng = jtu.rand_default(self.rng()) - tol = 1e-2 if bm.finfo(dtype).bits == 32 else None - unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) - - @jax.jit - def fun(unpacked_indexer, x): - indexer = pack_indexer(unpacked_indexer) - return x[indexer] - - arr = rng(shape, dtype) - check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol) - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "indexer": indexer} - for name, index_specs in ADVANCED_INDEXING_TESTS - for shape, indexer, _ in index_specs - for dtype in all_dtypes) - def testAdvancedIntegerIndexing(self, shape, dtype, indexer): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), indexer] - np_fun = lambda x, idx: np.asarray(x)[idx] - jnp_fun = lambda x, idx: bm.asarray(x)[idx] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "indexer": indexer} - for name, index_specs in [ - ("One1DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=np.array([0, 1])), - IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1])), - IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1])), - IndexSpec(shape=(3,), indexer=np.array([-1, 1])), - IndexSpec(shape=(3,), indexer=np.array([-2, -1])), - ]), - ("One2DIntArrayIndex", - [IndexSpec(shape=(3,), indexer=np.array([[0, 0]])), - IndexSpec(shape=(3, 
3), indexer=np.array([[1, 2, 1], - [0, 1, -1]])), - IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1], - [-1, -2, 1, 0]])), - ]), - ("Two1DIntArrayIndicesNoBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), - np.array([1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, 0, 1]), - np.array([-1, 0, -1, 2]))), - ]), - ("Two1DIntArrayIndicesWithBroadcasting", - [IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), - np.array([1, 2]))), - IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2, 0, 1]]), - np.array([-1, 0, -1, 2]))), - ]), - ("TupleOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1]))), - IndexSpec(shape=(3, 4, 5), indexer=(0, 1, - np.array([[2, 3, 0, 3]]))), - ]), - ("TupleOfListsOfPythonIntsAndIntArrays", - [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0]))), - IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], - np.array([[2, 3, 0, 3]]))), - ]), - ] - for shape, indexer, _ in index_specs - for dtype in float_dtypes) - def testAdvancedIntegerIndexingGrads(self, shape, dtype, indexer): - rng = jtu.rand_default(self.rng()) - tol = 1e-2 if bm.finfo(dtype).bits == 32 else None - arg = rng(shape, dtype) - fun = lambda x: bm.asarray(x)[indexer] - check_grads(fun, (arg,), 2, tol, tol, eps=1.) 
- - @parameterized.named_parameters( - {"testcase_name": "{}_inshape={}_indexer={}" - .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), - "shape": shape, "dtype": dtype, "indexer": indexer} - for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS - for shape, indexer, _ in index_specs - for dtype in all_dtypes) - def testMixedAdvancedIntegerIndexing(self, shape, dtype, indexer): - rng = jtu.rand_default(self.rng()) - indexer_with_dummies = [e if isinstance(e, np.ndarray) else () - for e in indexer] - substitutes = [(i, e) for i, e in enumerate(indexer) - if not isinstance(e, np.ndarray)] - args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] - - def jnp_fun(x, indexer_with_dummies): - idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) - return bm.asarray(x)[idx] - - def np_fun(x, indexer_with_dummies): - idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) - return np.asarray(x)[idx] - - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testAdvancedIndexingManually(self): - x = self.rng().randn(3, 4, 5) - index_array = np.array([0, 2, -1, 0]) - - op = lambda x, index_array: x[..., index_array, :] - cop = jax.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2) - - op = lambda x, index_array: x[..., index_array, :, index_array, None] - cop = jax.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2) - - op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] - cop = jax.jit(op) - - a1 = op(x, index_array) - a2 = cop(x, index_array) - - self.assertAllClose(a1, a2) - - def testUnpacking(self): - - def foo(x): - a, b, c = x - return a + b + c - - cfoo = jax.jit(foo) - - a1 = foo(np.arange(3)) - a2 = cfoo(np.arange(3)) - - self.assertAllClose(a1, a2) - - def testBooleanIndexingArray1D(self): - idx = np.array([True, True, False]) - x = 
jax.device_put(np.arange(3)) - ans = x[idx] - expected = np.arange(3)[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingList1D(self): - idx = [True, True, False] - x = jax.device_put(np.arange(3)) - with self.assertRaisesRegex(TypeError, ARRAY_MSG): - x[idx] - - def testBooleanIndexingArray2DBroadcast(self): - idx = np.array([True, True, False, True]) - x = np.arange(8).reshape(4, 2) - ans = jax.device_put(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingList2DBroadcast(self): - idx = [True, True, False, True] - x = np.arange(8).reshape(4, 2) - with self.assertRaisesRegex(TypeError, ARRAY_MSG): - jax.device_put(x)[idx] - - def testBooleanIndexingArray2D(self): - idx = np.array([[True, False], - [False, True], - [False, False], - [True, True]]) - x = np.arange(8).reshape(4, 2) - ans = jax.device_put(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBoolean1DIndexingWithEllipsis(self): - # Regression test for https://github.com/google/jax/issues/8412 - x = np.arange(24).reshape(4, 3, 2) - idx = (..., np.array([True, False])) - ans = bm.array(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBoolean1DIndexingWithEllipsis2(self): - # Regression test for https://github.com/google/jax/issues/9050 - x = np.arange(3) - idx = (..., np.array([True, False, True])) - ans = bm.array(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBoolean1DIndexingWithEllipsis3(self): - x = np.arange(6).reshape(2, 3) - idx = (0, ..., np.array([True, False, True])) - ans = bm.array(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBoolean2DIndexingWithEllipsis(self): - x = np.arange(24).reshape(4, 3, 2) - idx = (..., np.array([[True, False], [True, False], [False, False]])) - ans = bm.array(x)[idx] - expected 
= x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBoolean1DIndexingWithTrailingEllipsis(self): - x = np.arange(24).reshape(4, 3, 2) - idx = (np.array([True, False, True, False]), ...) - ans = bm.array(x)[idx] - expected = x[idx] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingDynamicShapeError(self): - x = np.zeros(3) - i = np.array([True, True, False]) - self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i)) - - def testScalarBooleanIndexingNotImplemented(self): - msg = "JAX arrays do not support boolean scalar indices" - with self.assertRaisesRegex(TypeError, msg): - bm.arange(4)[True] - with self.assertRaisesRegex(TypeError, msg): - bm.arange(4)[False] - with self.assertRaisesRegex(TypeError, msg): - bm.arange(4)[..., True] - - def testIssue187(self): - x = bm.ones((5, 5)) - x[[0, 2, 4], [0, 2, 4]] # doesn't crash - - x = np.arange(25).reshape((5, 5)) - ans = jax.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) - expected = x[[0, 2, 4], [0, 2, 4]] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testJVPOfGradOfIndexing(self): - # Should return a value, even though we didn't pass a symbolic zero as the - # index tangent. 
- x = bm.ones((3, 4), bm.float32) - i = bm.ones((3,), bm.int32).value - f = lambda x, i: bm.sum(x[i]) - primals, tangents = jax.jvp(jax.grad(f), (x, i), - (x, np.zeros(i.shape, dtypes.float0))) - expected = np.broadcast_to( - np.array([0, 3, 0], dtype=np.float32)[:, None], (3, 4)) - self.assertAllClose(expected, primals) - self.assertAllClose(np.zeros_like(x), tangents) - - def testIndexingEmptyDimension(self): - # Issue 2671: XLA error when indexing into dimension of size 0 - x = bm.ones((2, 0)) - # The following work, even on axis 1 of size 0 - with jax.numpy_rank_promotion('allow'): - _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2] - - with self.assertRaisesRegex(IndexError, - "index .* is out of bounds for axis .* with size 0"): - _ = np.ones((2, 0))[0, 0] # The numpy error - with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): - _ = x[0, 0] # JAX indexing - with self.assertRaisesRegex(IndexError, - "index is out of bounds for axis .* with size 0"): - jax.jit(lambda i: x[0, i])(0) # JAX indexing under jit - - def testBooleanIndexingWithEmptyResult(self): - # based on a TensorFlow Probability test that started failing after #1622 - x = bm.array([-1]) - mask = bm.array([False]) - ans = x[mask] # doesn't crash - - expected = np.array([-1])[np.array([False])] - self.assertAllClose(ans, expected, check_dtypes=False) - - def testBooleanIndexingShapeMismatch(self): - # Regression test for https://github.com/google/jax/issues/7329 - x = bm.arange(4) - idx = bm.array([True, False]) - with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"): - x[idx] - - def testNontrivialBooleanIndexing(self): - # Test nontrivial corner case in boolean indexing shape validation - rng = jtu.rand_default(self.rng()) - index = (rng((2, 3), np.bool_), rng((6,), np.bool_)) - - args_maker = lambda: [rng((2, 3, 6), np.int32)] - np_fun = lambda x: np.asarray(x)[index] - jnp_fun = lambda x: bm.asarray(x)[index] - - 
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testFloatIndexingError(self): - BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type" - with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): - bm.zeros(2)[0.] - with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): - bm.zeros((2, 2))[(0, 0.)] - with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): - bm.zeros((2, 2))[(0, 0.)] - with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): - jax.jit(lambda idx: bm.zeros((2, 2))[idx])((0, 0.)) - with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): - bm.zeros(2).at[0.].add(1.) - with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): - bm.zeros(2).at[0.].set(1.) - - def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 - x = bm.arange(5, dtype=bm.int32) + 1 - self.assertAllClose(x, x[:10]) - - idx = bm.array([-10, -6, -5, -4, 0, 3, 4, 5, 6, 100]).value - self.assertArraysEqual( - x.at[idx].get(mode="clip"), - bm.array([1, 1, 1, 2, 1, 4, 5, 5, 5, 5], bm.int32)) - nan = np.nan - self.assertArraysEqual( - x.astype(bm.float32).at[idx].get(mode="fill"), - bm.array([nan, nan, 1, 2, 1, 4, 5, nan, nan, nan], bm.float32)) - imin = np.iinfo(np.int32).min - self.assertArraysEqual( - x.at[idx].get(mode="fill"), - bm.array([imin, imin, 1, 2, 1, 4, 5, imin, imin, imin], bm.int32)) - umax = np.iinfo(np.uint32).max - self.assertArraysEqual( - x.astype(np.uint32).at[idx].get(mode="fill"), - bm.array([umax, umax, 1, 2, 1, 4, 5, umax, umax, umax], bm.uint32)) - self.assertArraysEqual( - x.at[idx].get(mode="fill", fill_value=7), - bm.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], bm.int32)) - - def testIndexingWeakTypes(self): - x = lax_internal._convert_element_type(bm.arange(5), int, weak_type=True) - - a = x.at[0].set(1.0) - self.assertEqual(a.dtype, x.dtype) - self.assertTrue(dtypes.is_weakly_typed(a)) - - b = x.at[0].add(1.0) - 
self.assertEqual(b.dtype, x.dtype) - self.assertTrue(dtypes.is_weakly_typed(b)) - - c = x.at[0].mul(1.0) - self.assertEqual(c.dtype, x.dtype) - self.assertTrue(dtypes.is_weakly_typed(c)) + """Tests for Numpy indexing translation rules.""" + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_indexer={}".format(name, jtu.format_shape_dtype_string(shape, dtype), + indexer), + "shape": shape, + "dtype": dtype, + "indexer": indexer} + for name, index_specs in STATIC_INDEXING_TESTS + for shape, indexer, _ in index_specs + for dtype in all_dtypes) + ) + def testStaticIndexing(self, shape, dtype, indexer): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda x: np.asarray(x)[indexer] + jnp_fun = lambda x: bm.asarray(x)[indexer] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + # Tests x.at[...].get(...) as well. + jnp_fun = lambda x: bm.asarray(x).at[indexer].get() + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters( + jtu.cases_from_list({"testcase_name": f"_{funcname}", "funcname": funcname} + for funcname in + ["negative", "sin", "cos", "square", "sqrt", "log", "exp"]) + ) + def testIndexApply(self, funcname, size=10, dtype='float32'): + if not hasattr(jnp.zeros(1).at[0], 'apply'): + self.skipTest('Has not apply() function') + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), -size, size) + np_func = getattr(np, funcname) + jnp_func = getattr(jnp, funcname) + + @jtu.ignore_warning(category=RuntimeWarning) + def np_op(x, idx): + y = x.copy() + np_func.at(y, idx) + return y + + def jnp_op(x, idx): + return bm.asarray(x).at[idx].apply(jnp_func) + + args_maker = lambda: [rng(size, dtype), idx_rng(size, int)] + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + self._CompileAndCheck(jnp_op, args_maker) + + 
@parameterized.named_parameters({ + "testcase_name": + f"{jtu.format_shape_dtype_string(shape, dtype)}_inshape={name}" + f"_indexer={indexer}_mode={mode}", + "shape": shape, "dtype": dtype, "indexer": indexer, "mode": mode + } + for mode in MODES + for name, index_specs in ( + STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else + STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS) + for shape, indexer, _ in index_specs + for dtype in float_dtypes) + def testStaticIndexingGrads(self, shape, dtype, indexer, mode): + rng = jtu.rand_default(self.rng()) + tol = 1e-2 if bm.finfo(dtype).bits == 32 else None + arg = rng(shape, dtype) + # Use an arbitrary finite fill_value, since NaNs won't work in a numerical + # gradient test. + fun = lambda x: bm.asarray(x).at[indexer].get(mode=mode, fill_value=7) ** 2 + check_grads(fun, (arg,), 2, tol, tol, tol) + + def _ReplaceSlicesWithTuples(self, idx): + """Helper method to replace slices with tuples for dynamic indexing args.""" + if isinstance(idx, slice): + triple = idx.start, idx.stop, idx.step + isnone = [i for i, elt in enumerate(triple) if elt is None] + zeros = itertools.repeat(0) + nones = itertools.repeat(None) + out = util.subvals(triple, zip(isnone, zeros)) + return out, lambda out: slice(*util.subvals(out, zip(isnone, nones))) + elif isinstance(idx, (tuple, list)) and idx: + t = type(idx) + elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) + return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) + else: + return idx, lambda x: x + + @parameterized.named_parameters( + {"testcase_name": "{}_inshape={}_indexer={}" + .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), + "shape": shape, "dtype": dtype, "indexer": indexer} + for name, index_specs in [ + ("OneSliceIndex", + [IndexSpec(shape=(5,), indexer=slice(1, 3)), + IndexSpec(shape=(5, 4), indexer=slice(1, 3))]), + ("TwoSliceIndices", + [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), + 
IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]), + ("NonUnitStrides", [ + IndexSpec(shape=(3,), indexer=slice(None, None, -1)), + IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), + IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) + ]), + ("OnlyStartOrStopDynamic", [ + IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), + IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) + ]), + ] + for shape, indexer, _ in index_specs + for dtype in all_dtypes) + def testDynamicIndexingWithSlicesErrors(self, shape, dtype, indexer): + rng = jtu.rand_default(self.rng()) + unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) + + @jax.jit + def fun(x, unpacked_indexer): + indexer = pack_indexer(unpacked_indexer) + return x[indexer] + + args_maker = lambda: [rng(shape, dtype), unpacked_indexer] + self.assertRaises(IndexError, lambda: fun(*args_maker())) + + @parameterized.named_parameters( + {"testcase_name": "{}_inshape={}_indexer={}" + .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), + "shape": shape, "dtype": dtype, "indexer": indexer} + for name, index_specs in [ + ("OneIntIndex", + [IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2)]), + ("TwoIntIndices", + [IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]), + ("ThreeIntIndices", + [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ] + for shape, indexer, _ in index_specs + for dtype in all_dtypes) + def testDynamicIndexingWithIntegers(self, shape, dtype, indexer): + rng = jtu.rand_default(self.rng()) + unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) + + def np_fun(x, unpacked_indexer): + indexer = pack_indexer(unpacked_indexer) + return np.asarray(x)[indexer] + + def jnp_fun(x, unpacked_indexer): + indexer = 
pack_indexer(unpacked_indexer) + return bm.array(x)[indexer] + + args_maker = lambda: [rng(shape, dtype), unpacked_indexer] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters( + {"testcase_name": "{}_inshape={}_indexer={}" + .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), + "shape": shape, "dtype": dtype, "indexer": indexer} + for name, index_specs in [ + ("OneIntIndex", + [IndexSpec(shape=(3,), indexer=1), + IndexSpec(shape=(3, 3), indexer=0), + IndexSpec(shape=(3, 4, 5), indexer=2), + IndexSpec(shape=(3,), indexer=-1), + IndexSpec(shape=(3,), indexer=-2), + ]), + ("TwoIntIndices", + [IndexSpec(shape=(3, 3), indexer=(2, 1)), + IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), + IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), + ]), + ("ThreeIntIndices", + [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), + ] + for shape, indexer, _ in index_specs + for dtype in float_dtypes) + def testDynamicIndexingWithIntegersGrads(self, shape, dtype, indexer): + rng = jtu.rand_default(self.rng()) + tol = 1e-2 if bm.finfo(dtype).bits == 32 else None + unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) + + @jax.jit + def fun(unpacked_indexer, x): + indexer = pack_indexer(unpacked_indexer) + return x[indexer] + + arr = rng(shape, dtype) + check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol) + + @parameterized.named_parameters( + {"testcase_name": "{}_inshape={}_indexer={}" + .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), + "shape": shape, "dtype": dtype, "indexer": indexer} + for name, index_specs in ADVANCED_INDEXING_TESTS + for shape, indexer, _ in index_specs + for dtype in all_dtypes) + def testAdvancedIntegerIndexing(self, shape, dtype, indexer): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), indexer] + np_fun = lambda x, idx: np.asarray(x)[idx] + jnp_fun = lambda x, idx: bm.asarray(x)[idx] + 
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters( + {"testcase_name": "{}_inshape={}_indexer={}" + .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), + "shape": shape, "dtype": dtype, "indexer": indexer} + for name, index_specs in [ + ("One1DIntArrayIndex", + [IndexSpec(shape=(3,), indexer=np.array([0, 1])), + IndexSpec(shape=(3, 3), indexer=np.array([1, 2, 1])), + IndexSpec(shape=(3, 4, 5), indexer=np.array([0, 2, 0, 1])), + IndexSpec(shape=(3,), indexer=np.array([-1, 1])), + IndexSpec(shape=(3,), indexer=np.array([-2, -1])), + ]), + ("One2DIntArrayIndex", + [IndexSpec(shape=(3,), indexer=np.array([[0, 0]])), + IndexSpec(shape=(3, 3), indexer=np.array([[1, 2, 1], + [0, 1, -1]])), + IndexSpec(shape=(3, 4, 5), indexer=np.array([[0, 2, 0, 1], + [-1, -2, 1, 0]])), + ]), + ("Two1DIntArrayIndicesNoBroadcasting", + [IndexSpec(shape=(3, 3), indexer=(np.array([0, 1]), + np.array([1, 2]))), + IndexSpec(shape=(3, 4, 5), indexer=(np.array([0, 2, 0, 1]), + np.array([-1, 0, -1, 2]))), + ]), + ("Two1DIntArrayIndicesWithBroadcasting", + [IndexSpec(shape=(3, 3), indexer=(np.array([[0, 1]]), + np.array([1, 2]))), + IndexSpec(shape=(3, 4, 5), indexer=(np.array([[0, 2, 0, 1]]), + np.array([-1, 0, -1, 2]))), + ]), + ("TupleOfPythonIntsAndIntArrays", + [IndexSpec(shape=(3, 4, 5), indexer=(0, np.array([0, 1]))), + IndexSpec(shape=(3, 4, 5), indexer=(0, 1, + np.array([[2, 3, 0, 3]]))), + ]), + ("TupleOfListsOfPythonIntsAndIntArrays", + [IndexSpec(shape=(3, 4, 5), indexer=([0, 1], np.array([0]))), + IndexSpec(shape=(3, 4, 5), indexer=([[0], [-1]], + np.array([[2, 3, 0, 3]]))), + ]), + ] + for shape, indexer, _ in index_specs + for dtype in float_dtypes) + def testAdvancedIntegerIndexingGrads(self, shape, dtype, indexer): + rng = jtu.rand_default(self.rng()) + tol = 1e-2 if bm.finfo(dtype).bits == 32 else None + arg = rng(shape, dtype) + fun = lambda x: bm.asarray(x)[indexer] + 
check_grads(fun, (arg,), 2, tol, tol, eps=1.) + + @parameterized.named_parameters( + {"testcase_name": "{}_inshape={}_indexer={}" + .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), + "shape": shape, "dtype": dtype, "indexer": indexer} + for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS + for shape, indexer, _ in index_specs + for dtype in all_dtypes) + def testMixedAdvancedIntegerIndexing(self, shape, dtype, indexer): + rng = jtu.rand_default(self.rng()) + indexer_with_dummies = [e if isinstance(e, np.ndarray) else () + for e in indexer] + substitutes = [(i, e) for i, e in enumerate(indexer) + if not isinstance(e, np.ndarray)] + args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] + + def jnp_fun(x, indexer_with_dummies): + idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) + return bm.asarray(x)[idx] + + def np_fun(x, indexer_with_dummies): + idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) + return np.asarray(x)[idx] + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testAdvancedIndexingManually(self): + x = self.rng().randn(3, 4, 5) + index_array = np.array([0, 2, -1, 0]) + + op = lambda x, index_array: x[..., index_array, :] + cop = jax.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2) + + op = lambda x, index_array: x[..., index_array, :, index_array, None] + cop = jax.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2) + + op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] + cop = jax.jit(op) + + a1 = op(x, index_array) + a2 = cop(x, index_array) + + self.assertAllClose(a1, a2) + + def testUnpacking(self): + + def foo(x): + a, b, c = x + return a + b + c + + cfoo = jax.jit(foo) + + a1 = foo(np.arange(3)) + a2 = cfoo(np.arange(3)) + + self.assertAllClose(a1, a2) + + def testBooleanIndexingArray1D(self): + idx = 
np.array([True, True, False]) + x = jax.device_put(np.arange(3)) + ans = x[idx] + expected = np.arange(3)[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList1D(self): + idx = [True, True, False] + x = jax.device_put(np.arange(3)) + with self.assertRaisesRegex(TypeError, ARRAY_MSG): + x[idx] + + def testBooleanIndexingArray2DBroadcast(self): + idx = np.array([True, True, False, True]) + x = np.arange(8).reshape(4, 2) + ans = jax.device_put(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingList2DBroadcast(self): + idx = [True, True, False, True] + x = np.arange(8).reshape(4, 2) + with self.assertRaisesRegex(TypeError, ARRAY_MSG): + jax.device_put(x)[idx] + + def testBooleanIndexingArray2D(self): + idx = np.array([[True, False], + [False, True], + [False, False], + [True, True]]) + x = np.arange(8).reshape(4, 2) + ans = jax.device_put(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBoolean1DIndexingWithEllipsis(self): + # Regression test for https://github.com/google/jax/issues/8412 + x = np.arange(24).reshape(4, 3, 2) + idx = (..., np.array([True, False])) + ans = bm.array(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBoolean1DIndexingWithEllipsis2(self): + # Regression test for https://github.com/google/jax/issues/9050 + x = np.arange(3) + idx = (..., np.array([True, False, True])) + ans = bm.array(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBoolean1DIndexingWithEllipsis3(self): + x = np.arange(6).reshape(2, 3) + idx = (0, ..., np.array([True, False, True])) + ans = bm.array(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBoolean2DIndexingWithEllipsis(self): + x = np.arange(24).reshape(4, 3, 2) + idx = (..., np.array([[True, False], [True, False], [False, False]])) 
+ ans = bm.array(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBoolean1DIndexingWithTrailingEllipsis(self): + x = np.arange(24).reshape(4, 3, 2) + idx = (np.array([True, False, True, False]), ...) + ans = bm.array(x)[idx] + expected = x[idx] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingDynamicShapeError(self): + x = np.zeros(3) + i = np.array([True, True, False]) + self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i)) + + def testScalarBooleanIndexingNotImplemented(self): + msg = "JAX arrays do not support boolean scalar indices" + with self.assertRaisesRegex(TypeError, msg): + bm.arange(4)[True] + with self.assertRaisesRegex(TypeError, msg): + bm.arange(4)[False] + with self.assertRaisesRegex(TypeError, msg): + bm.arange(4)[..., True] + + def testIssue187(self): + x = bm.ones((5, 5)) + x[[0, 2, 4], [0, 2, 4]] # doesn't crash + + x = np.arange(25).reshape((5, 5)) + ans = jax.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) + expected = x[[0, 2, 4], [0, 2, 4]] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testJVPOfGradOfIndexing(self): + # Should return a value, even though we didn't pass a symbolic zero as the + # index tangent. 
+ x = bm.ones((3, 4), bm.float32) + i = bm.ones((3,), bm.int32).value + f = lambda x, i: bm.sum(x[i]) + primals, tangents = jax.jvp(jax.grad(f), (x, i), + (x, np.zeros(i.shape, dtypes.float0))) + expected = np.broadcast_to( + np.array([0, 3, 0], dtype=np.float32)[:, None], (3, 4)) + self.assertAllClose(expected, primals) + self.assertAllClose(np.zeros_like(x), tangents) + + def testIndexingEmptyDimension(self): + # Issue 2671: XLA error when indexing into dimension of size 0 + x = bm.ones((2, 0)) + # The following work, even on axis 1 of size 0 + with jax.numpy_rank_promotion('allow'): + _ = x[0, :] + x[0, None] + x[0, 1:] + x[0, 1:3:2] + + with self.assertRaisesRegex(IndexError, + "index .* is out of bounds for axis .* with size 0"): + _ = np.ones((2, 0))[0, 0] # The numpy error + with self.assertRaisesRegex(IndexError, + "index is out of bounds for axis .* with size 0"): + _ = x[0, 0] # JAX indexing + with self.assertRaisesRegex(IndexError, + "index is out of bounds for axis .* with size 0"): + jax.jit(lambda i: x[0, i])(0) # JAX indexing under jit + + def testBooleanIndexingWithEmptyResult(self): + # based on a TensorFlow Probability test that started failing after #1622 + x = bm.array([-1]) + mask = bm.array([False]) + ans = x[mask] # doesn't crash + + expected = np.array([-1])[np.array([False])] + self.assertAllClose(ans, expected, check_dtypes=False) + + def testBooleanIndexingShapeMismatch(self): + # Regression test for https://github.com/google/jax/issues/7329 + x = bm.arange(4) + idx = bm.array([True, False]) + with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"): + x[idx] + + def testNontrivialBooleanIndexing(self): + # Test nontrivial corner case in boolean indexing shape validation + rng = jtu.rand_default(self.rng()) + index = (rng((2, 3), np.bool_), rng((6,), np.bool_)) + + args_maker = lambda: [rng((2, 3, 6), np.int32)] + np_fun = lambda x: np.asarray(x)[index] + jnp_fun = lambda x: bm.asarray(x)[index] + + 
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testFloatIndexingError(self): + BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type" + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + bm.zeros(2)[0.] + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + bm.zeros((2, 2))[(0, 0.)] + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + bm.zeros((2, 2))[(0, 0.)] + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + jax.jit(lambda idx: bm.zeros((2, 2))[idx])((0, 0.)) + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + bm.zeros(2).at[0.].add(1.) + with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): + bm.zeros(2).at[0.].set(1.) + + def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 + x = bm.arange(5, dtype=bm.int32) + 1 + self.assertAllClose(x, x[:10]) + + idx = bm.array([-10, -6, -5, -4, 0, 3, 4, 5, 6, 100]).value + self.assertArraysEqual( + x.at[idx].get(mode="clip"), + bm.array([1, 1, 1, 2, 1, 4, 5, 5, 5, 5], bm.int32)) + nan = np.nan + self.assertArraysEqual( + x.astype(bm.float32).at[idx].get(mode="fill"), + bm.array([nan, nan, 1, 2, 1, 4, 5, nan, nan, nan], bm.float32)) + imin = np.iinfo(np.int32).min + self.assertArraysEqual( + x.at[idx].get(mode="fill"), + bm.array([imin, imin, 1, 2, 1, 4, 5, imin, imin, imin], bm.int32)) + umax = np.iinfo(np.uint32).max + self.assertArraysEqual( + x.astype(np.uint32).at[idx].get(mode="fill"), + bm.array([umax, umax, 1, 2, 1, 4, 5, umax, umax, umax], bm.uint32)) + self.assertArraysEqual( + x.at[idx].get(mode="fill", fill_value=7), + bm.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], bm.int32)) + + def testIndexingWeakTypes(self): + x = lax_internal._convert_element_type(bm.arange(5), int, weak_type=True) + + a = x.at[0].set(1.0) + self.assertEqual(a.dtype, x.dtype) + self.assertTrue(dtypes.is_weakly_typed(a)) + + b = x.at[0].add(1.0) + 
self.assertEqual(b.dtype, x.dtype) + self.assertTrue(dtypes.is_weakly_typed(b)) + + c = x.at[0].mul(1.0) + self.assertEqual(c.dtype, x.dtype) + self.assertTrue(dtypes.is_weakly_typed(c)) def _broadcastable_shapes(shape): - """Returns all shapes that broadcast to `shape`.""" + """Returns all shapes that broadcast to `shape`.""" - def f(rshape): - yield [] - if rshape: - for s in f(rshape[1:]): - yield rshape[0:1] + s - if rshape[0] != 1: - for s in f(rshape[1:]): - yield [1] + s + def f(rshape): + yield [] + if rshape: + for s in f(rshape[1:]): + yield rshape[0:1] + s + if rshape[0] != 1: + for s in f(rshape[1:]): + yield [1] + s - for x in f(list(reversed(shape))): - yield list(reversed(x)) + for x in f(list(reversed(shape))): + yield list(reversed(x)) class UpdateOps(enum.Enum): - UPDATE = 0 - ADD = 1 - MUL = 2 - DIV = 3 - POW = 4 - MIN = 5 - MAX = 6 - - def np_fn(op, indexer, x, y): - x = x.copy() - x[indexer] = { - UpdateOps.UPDATE: lambda: y, - UpdateOps.ADD: lambda: x[indexer] + y, - UpdateOps.MUL: lambda: x[indexer] * y, - UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] / y.astype(x.dtype)), - UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)( - lambda: x[indexer] ** y.astype(x.dtype)), - UpdateOps.MIN: lambda: np.minimum(x[indexer], y), - UpdateOps.MAX: lambda: np.maximum(x[indexer], y), - }[op]() - return x - - def jax_fn(op, indexer, x, y, indices_are_sorted=False, - unique_indices=False, mode=None): - x = bm.array(x) - return { - UpdateOps.UPDATE: x.at[indexer].set, - UpdateOps.ADD: x.at[indexer].add, - UpdateOps.MUL: x.at[indexer].multiply, - UpdateOps.DIV: x.at[indexer].divide, - UpdateOps.POW: x.at[indexer].power, - UpdateOps.MIN: x.at[indexer].min, - UpdateOps.MAX: x.at[indexer].max, - }[op](y, indices_are_sorted=indices_are_sorted, - unique_indices=unique_indices, mode=mode) - - def dtypes(op): - if op == UpdateOps.UPDATE: - return all_dtypes - elif op == UpdateOps.DIV or op == UpdateOps.POW: - return 
jtu.dtypes.inexact - else: - return default_dtypes + UPDATE = 0 + ADD = 1 + MUL = 2 + DIV = 3 + POW = 4 + MIN = 5 + MAX = 6 + + def np_fn(op, indexer, x, y): + x = x.copy() + x[indexer] = { + UpdateOps.UPDATE: lambda: y, + UpdateOps.ADD: lambda: x[indexer] + y, + UpdateOps.MUL: lambda: x[indexer] * y, + UpdateOps.DIV: jtu.ignore_warning(category=RuntimeWarning)( + lambda: x[indexer] / y.astype(x.dtype)), + UpdateOps.POW: jtu.ignore_warning(category=RuntimeWarning)( + lambda: x[indexer] ** y.astype(x.dtype)), + UpdateOps.MIN: lambda: np.minimum(x[indexer], y), + UpdateOps.MAX: lambda: np.maximum(x[indexer], y), + }[op]() + return x + + def jax_fn(op, indexer, x, y, indices_are_sorted=False, + unique_indices=False, mode=None): + x = bm.array(x) + return { + UpdateOps.UPDATE: x.at[indexer].set, + UpdateOps.ADD: x.at[indexer].add, + UpdateOps.MUL: x.at[indexer].multiply, + UpdateOps.DIV: x.at[indexer].divide, + UpdateOps.POW: x.at[indexer].power, + UpdateOps.MIN: x.at[indexer].min, + UpdateOps.MAX: x.at[indexer].max, + }[op](y, indices_are_sorted=indices_are_sorted, + unique_indices=unique_indices, mode=mode) + + def dtypes(op): + if op == UpdateOps.UPDATE: + return all_dtypes + elif op == UpdateOps.DIV or op == UpdateOps.POW: + return jtu.dtypes.inexact + else: + return default_dtypes def _update_tol(op): - if op == UpdateOps.POW: - tol = {np.complex64: 1e-4 if jtu.device_under_test() == "tpu" else 1e-5, - np.complex128: 1e-14} - else: - tol = {np.complex128: 1e-14} - return tol + if op == UpdateOps.POW: + tol = {np.complex64: 1e-4 if jtu.device_under_test() == "tpu" else 1e-5, + np.complex128: 1e-14} + else: + tol = {np.complex128: 1e-14} + return tol @pytest.mark.skipif(True, reason="No longer need to test.") @jtu.with_config(jax_numpy_dtype_promotion='standard') class IndexedUpdateTest(jtu.JaxTestCase): - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": - f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}" - 
f"_indexer={indexer}" - f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}" - f"_op={op.name}", - "shape": shape, "dtype": dtype, "indexer": indexer, - "update_shape": update_shape, "update_dtype": update_dtype, - "op": op, "mode": mode, - } for name, index_specs in s(STATIC_INDEXING_TESTS) - for shape, indexer, update_shape in s(index_specs) - for op in s(UpdateOps) - for dtype in s(UpdateOps.dtypes(op)) - for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes) - for mode in s(MODES)))) - def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, - indexer, op, mode): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode) - self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) - self._CompileAndCheck(jax_fn, args_maker) - - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( - name, jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, - "dtype": dtype, - "indexer": indexer, - "update_shape": update_shape, - "update_dtype": update_dtype, - "op": op - } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS) - for shape, indexer, update_shape in s(index_specs) - for op in s(UpdateOps) - for dtype in s(UpdateOps.dtypes(op)) - for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) - def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, - indexer, op): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, 
y: UpdateOps.np_fn(op, indexer, x, y) - jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, - unique_indices=True) - self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) - self._CompileAndCheck(jax_fn, args_maker) - - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( - name, jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "indexer": indexer, - "update_shape": update_shape, "update_dtype": update_dtype, - "op": op - } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED) - for shape, indexer, update_shape in s(index_specs) - for op in s(UpdateOps) - for dtype in s(UpdateOps.dtypes(op)) - for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) - def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype, - indexer, op): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - jax_fn = lambda x, y: UpdateOps.jax_fn( - op, indexer, x, y, indices_are_sorted=True, unique_indices=True) - self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True, - tol=_update_tol(op)) - self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) - - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( - name, jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "indexer": indexer, - "update_shape": update_shape, "update_dtype": update_dtype, - "op": op - } for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS) - for 
shape, indexer, update_shape in s(index_specs) - for op in s(UpdateOps) - for dtype in s(UpdateOps.dtypes(op)) - for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) - def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, - indexer, op): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] - np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) - jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) - self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) - self._CompileAndCheck(jax_fn, args_maker) - - @parameterized.named_parameters(jtu.cases_from_list({ - "testcase_name": - f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}" - f"_indexer={indexer}" - f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}" - f"_op={op.name}_mode={mode}", - "shape": shape, "dtype": dtype, "indexer": indexer, - "update_shape": update_shape, "update_dtype": update_dtype, - "op": op, "mode": mode, - } for mode in [None] + MODES - for name, index_specs in ( - STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else - STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS) - for shape, indexer, update_shape in index_specs - for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] - for dtype in float_dtypes - for update_shape in _broadcastable_shapes(update_shape) - for update_dtype in - ([dtype] if op == UpdateOps.ADD else float_dtypes))) - def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, - indexer, op, mode): - rng = jtu.rand_default(self.rng()) - jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode, - unique_indices=True) - x = rng(shape, dtype) - y = rng(update_shape, update_dtype) - check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) 
- - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( - name, jtu.format_shape_dtype_string(shape, dtype), indexer, - jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), - "shape": shape, "dtype": dtype, "indexer": indexer, - "update_shape": update_shape, "update_dtype": update_dtype, - "op": op, "unique_indices": unique_indices, - } for unique_indices in s([False, True]) - for name, index_specs in s( - ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices - else ADVANCED_INDEXING_TESTS) - for shape, indexer, update_shape in s(index_specs) - for op in s( - [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices - else [UpdateOps.ADD]) - for dtype in s(float_dtypes) - for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes)))) - def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype, - indexer, op, unique_indices): - rng = jtu.rand_default(self.rng()) - jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, - unique_indices=unique_indices) - x = rng(shape, dtype) - y = rng(update_shape, update_dtype) - check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) - - def testIndexMulGradFailsIfNotUnique(self): - y = bm.ones((10,), bm.int32).value - f = lambda x, z: x.at[y].mul(z.value) - - x = bm.ones((100,), bm.float32) - z = bm.ones((10,), bm.float32) - with self.assertRaises(NotImplementedError, - msg="scatter_mul gradients are only implemented if " - "`unique_indices=True`"): - jax.jvp(f, (x, z), (x, z)) - - def testSegmentSumBehavior(self): - # testAdvancedIndexing compares against NumPy, and as a result doesn't check - # repeated indices. 
This test is just a simple manual check, based on - # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum - data = np.array([5, 1, 7, 2, 3, 4, 1, 3]) - segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3]) - - ans = bm.zeros(np.max(segment_ids) + 1).at[segment_ids].add(data) - expected = np.array([13, 2, 7, 4]) - self.assertAllClose(ans, expected, check_dtypes=False) - - def testSegmentSum(self): - data = bm.array([5, 1, 7, 2, 3, 4, 1, 3]) - segment_ids = bm.array([0, 0, 0, 1, 2, 2, 3, 3]) - - # test with explicit num_segments - ans = bm.segment_sum(data, segment_ids, num_segments=4) - expected = bm.array([13, 2, 7, 4]) - self.assertAllClose(ans, expected, check_dtypes=False) - - # test with explicit num_segments larger than the higher index. - ans = bm.segment_sum(data, segment_ids, num_segments=5) - expected = bm.array([13, 2, 7, 4, 0]) - self.assertAllClose(ans, expected, check_dtypes=False) - - # test without explicit num_segments - ans = bm.segment_sum(data, segment_ids) - expected = bm.array([13, 2, 7, 4]) - self.assertAllClose(ans, expected, check_dtypes=False) - - # test with negative segment ids and segment ids larger than num_segments, - # that will be wrapped with the `mod`. - segment_ids = bm.array([0, 4, 8, 1, 2, -6, -1, 3]) - ans = bm.segment_sum(data, segment_ids, num_segments=4) - expected = bm.array([5, 2, 3, 3]) - self.assertAllClose(ans, expected, check_dtypes=False) - - # test with negative segment ids and without without explicit num_segments - # such as num_segments is defined by the smaller index. 
- segment_ids = bm.array([3, 3, 3, 4, 5, 5, -7, -6]) - ans = bm.segment_sum(data, segment_ids) - expected = bm.array([0, 0, 0, 13, 2, 7]) - self.assertAllClose(ans, expected, check_dtypes=False) - - def testSegmentSumOutOfBounds(self): - def fn(data, segment_ids): - return bm.segment_sum(data, segment_ids, num_segments).sum() - - data = np.array([0, 0], dtype=np.float32) - num_segments = 2 - segment_ids = np.array([2, 3]) - val, grad = jax.value_and_grad(fn)(data, segment_ids) - self.assertAllClose(val, np.array(0., np.float32)) - self.assertAllClose(grad, np.array([0., 0.], np.float32)) - - def testIndexDtypeError(self): - # https://github.com/google/jax/issues/2795 - bm.array(1) # get rid of startup warning - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("error") - bm.zeros(5).at[::2].set(1) - self.assertLen(w, 0) - - @contextmanager - def assertNoWarnings(self): - with warnings.catch_warnings(record=True) as caught_warnings: - yield - self.assertEmpty(caught_warnings) - - @parameterized.named_parameters(jtu.cases_from_list({ - "testcase_name": "idx={}".format(idx), "idx": idx, - "idx_type": idx_type} - for idx, idx_type in [ - ([0], "array"), - ([0, 0], "array"), - ([[0, 0]], "tuple"), - ([0, [0, 1]], "tuple"), - ([0, np.arange(2)], "tuple"), - ([0, None], "tuple"), - ([0, slice(None)], "tuple"), - ])) - def testIndexSequenceDeprecation(self, idx, idx_type): - normalize = {"array": np.array, "tuple": tuple}[idx_type] - msg = {"array": ARRAY_MSG, "tuple": TUPLE_MSG}[idx_type] - x = bm.arange(6).reshape(3, 2) - - with self.assertRaisesRegex(TypeError, msg): - x[idx] - with self.assertNoWarnings(): - x[normalize(idx)] - - with self.assertRaisesRegex(TypeError, msg): - x.at[idx].set(0) - with self.assertNoWarnings(): - x.at[normalize(idx)].set(0) - - def testIndexedUpdateAliasingBug(self): - # https://github.com/google/jax/issues/7461 - fn = lambda x: x.at[1:].set(1 + x[:-1]) - y = bm.zeros(8) - self.assertArraysEqual(fn(y), jax.jit(fn)(y)) 
+ @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": + f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}" + f"_indexer={indexer}" + f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}" + f"_op={op.name}", + "shape": shape, "dtype": dtype, "indexer": indexer, + "update_shape": update_shape, "update_dtype": update_dtype, + "op": op, "mode": mode, + } for name, index_specs in s(STATIC_INDEXING_TESTS) + for shape, indexer, update_shape in s(index_specs) + for op in s(UpdateOps) + for dtype in s(UpdateOps.dtypes(op)) + for update_shape in s(_broadcastable_shapes(update_shape)) + for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes) + for mode in s(MODES)))) + def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, + indexer, op, mode): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) + jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode) + self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) + self._CompileAndCheck(jax_fn, args_maker) + + @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), + "shape": shape, + "dtype": dtype, + "indexer": indexer, + "update_shape": update_shape, + "update_dtype": update_dtype, + "op": op + } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS) + for shape, indexer, update_shape in s(index_specs) + for op in s(UpdateOps) + for dtype in s(UpdateOps.dtypes(op)) + for update_shape in s(_broadcastable_shapes(update_shape)) + for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) + def testAdvancedIndexing(self, shape, dtype, update_shape, 
update_dtype, + indexer, op): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) + jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, + unique_indices=True) + self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) + self._CompileAndCheck(jax_fn, args_maker) + + @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), + "shape": shape, "dtype": dtype, "indexer": indexer, + "update_shape": update_shape, "update_dtype": update_dtype, + "op": op + } for name, index_specs in s(ADVANCED_INDEXING_TESTS_NO_REPEATS_SORTED) + for shape, indexer, update_shape in s(index_specs) + for op in s(UpdateOps) + for dtype in s(UpdateOps.dtypes(op)) + for update_shape in s(_broadcastable_shapes(update_shape)) + for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) + def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype, + indexer, op): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) + jax_fn = lambda x, y: UpdateOps.jax_fn( + op, indexer, x, y, indices_are_sorted=True, unique_indices=True) + self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, check_dtypes=True, + tol=_update_tol(op)) + self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), + "shape": shape, "dtype": dtype, "indexer": 
indexer, + "update_shape": update_shape, "update_dtype": update_dtype, + "op": op + } for name, index_specs in s(MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS) + for shape, indexer, update_shape in s(index_specs) + for op in s(UpdateOps) + for dtype in s(UpdateOps.dtypes(op)) + for update_shape in s(_broadcastable_shapes(update_shape)) + for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) + def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, + indexer, op): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] + np_fn = lambda x, y: UpdateOps.np_fn(op, indexer, x, y) + jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) + self._CheckAgainstNumpy(np_fn, jax_fn, args_maker, tol=_update_tol(op)) + self._CompileAndCheck(jax_fn, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list({ + "testcase_name": + f"{name}_inshape={jtu.format_shape_dtype_string(shape, dtype)}" + f"_indexer={indexer}" + f"_update={jtu.format_shape_dtype_string(update_shape, update_dtype)}" + f"_op={op.name}_mode={mode}", + "shape": shape, "dtype": dtype, "indexer": indexer, + "update_shape": update_shape, "update_dtype": update_dtype, + "op": op, "mode": mode, + } for mode in [None] + MODES + for name, index_specs in ( + STATIC_INDEXING_TESTS if mode == "promise_in_bounds" else + STATIC_INDEXING_TESTS + STATIC_INDEXING_OUT_OF_BOUNDS_TESTS) + for shape, indexer, update_shape in index_specs + for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] + for dtype in float_dtypes + for update_shape in _broadcastable_shapes(update_shape) + for update_dtype in + ([dtype] if op == UpdateOps.ADD else float_dtypes))) + def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, + indexer, op, mode): + rng = jtu.rand_default(self.rng()) + jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, mode=mode, + unique_indices=True) + x = rng(shape, dtype) + y = 
rng(update_shape, update_dtype) + check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) + + @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( + name, jtu.format_shape_dtype_string(shape, dtype), indexer, + jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), + "shape": shape, "dtype": dtype, "indexer": indexer, + "update_shape": update_shape, "update_dtype": update_dtype, + "op": op, "unique_indices": unique_indices, + } for unique_indices in s([False, True]) + for name, index_specs in s( + ADVANCED_INDEXING_TESTS_NO_REPEATS if unique_indices + else ADVANCED_INDEXING_TESTS) + for shape, indexer, update_shape in s(index_specs) + for op in s( + [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] if unique_indices + else [UpdateOps.ADD]) + for dtype in s(float_dtypes) + for update_shape in s(_broadcastable_shapes(update_shape)) + for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes)))) + def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype, + indexer, op, unique_indices): + rng = jtu.rand_default(self.rng()) + jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y, + unique_indices=unique_indices) + x = rng(shape, dtype) + y = rng(update_shape, update_dtype) + check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) + + def testIndexMulGradFailsIfNotUnique(self): + y = bm.ones((10,), bm.int32).value + f = lambda x, z: x.at[y].mul(z.value) + + x = bm.ones((100,), bm.float32) + z = bm.ones((10,), bm.float32) + with self.assertRaises(NotImplementedError, + msg="scatter_mul gradients are only implemented if " + "`unique_indices=True`"): + jax.jvp(f, (x, z), (x, z)) + + def testSegmentSumBehavior(self): + # testAdvancedIndexing compares against NumPy, and as a result doesn't check + # repeated indices. 
This test is just a simple manual check, based on + # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum + data = np.array([5, 1, 7, 2, 3, 4, 1, 3]) + segment_ids = np.array([0, 0, 0, 1, 2, 2, 3, 3]) + + ans = bm.zeros(np.max(segment_ids) + 1).at[segment_ids].add(data) + expected = np.array([13, 2, 7, 4]) + self.assertAllClose(ans, expected, check_dtypes=False) + + def testSegmentSum(self): + data = bm.array([5, 1, 7, 2, 3, 4, 1, 3]) + segment_ids = bm.array([0, 0, 0, 1, 2, 2, 3, 3]) + + # test with explicit num_segments + ans = bm.segment_sum(data, segment_ids, num_segments=4) + expected = bm.array([13, 2, 7, 4]) + self.assertAllClose(ans, expected, check_dtypes=False) + + # test with explicit num_segments larger than the higher index. + ans = bm.segment_sum(data, segment_ids, num_segments=5) + expected = bm.array([13, 2, 7, 4, 0]) + self.assertAllClose(ans, expected, check_dtypes=False) + + # test without explicit num_segments + ans = bm.segment_sum(data, segment_ids) + expected = bm.array([13, 2, 7, 4]) + self.assertAllClose(ans, expected, check_dtypes=False) + + # test with negative segment ids and segment ids larger than num_segments, + # that will be wrapped with the `mod`. + segment_ids = bm.array([0, 4, 8, 1, 2, -6, -1, 3]) + ans = bm.segment_sum(data, segment_ids, num_segments=4) + expected = bm.array([5, 2, 3, 3]) + self.assertAllClose(ans, expected, check_dtypes=False) + + # test with negative segment ids and without without explicit num_segments + # such as num_segments is defined by the smaller index. 
+ segment_ids = bm.array([3, 3, 3, 4, 5, 5, -7, -6]) + ans = bm.segment_sum(data, segment_ids) + expected = bm.array([0, 0, 0, 13, 2, 7]) + self.assertAllClose(ans, expected, check_dtypes=False) + + def testSegmentSumOutOfBounds(self): + def fn(data, segment_ids): + return bm.segment_sum(data, segment_ids, num_segments).sum() + + data = np.array([0, 0], dtype=np.float32) + num_segments = 2 + segment_ids = np.array([2, 3]) + val, grad = jax.value_and_grad(fn)(data, segment_ids) + self.assertAllClose(val, np.array(0., np.float32)) + self.assertAllClose(grad, np.array([0., 0.], np.float32)) + + def testIndexDtypeError(self): + # https://github.com/google/jax/issues/2795 + bm.array(1) # get rid of startup warning + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("error") + bm.zeros(5).at[::2].set(1) + self.assertLen(w, 0) + + @contextmanager + def assertNoWarnings(self): + with warnings.catch_warnings(record=True) as caught_warnings: + yield + self.assertEmpty(caught_warnings) + + @parameterized.named_parameters(jtu.cases_from_list({ + "testcase_name": "idx={}".format(idx), "idx": idx, + "idx_type": idx_type} + for idx, idx_type in [ + ([0], "array"), + ([0, 0], "array"), + ([[0, 0]], "tuple"), + ([0, [0, 1]], "tuple"), + ([0, np.arange(2)], "tuple"), + ([0, None], "tuple"), + ([0, slice(None)], "tuple"), + ])) + def testIndexSequenceDeprecation(self, idx, idx_type): + normalize = {"array": np.array, "tuple": tuple}[idx_type] + msg = {"array": ARRAY_MSG, "tuple": TUPLE_MSG}[idx_type] + x = bm.arange(6).reshape(3, 2) + + with self.assertRaisesRegex(TypeError, msg): + x[idx] + with self.assertNoWarnings(): + x[normalize(idx)] + + with self.assertRaisesRegex(TypeError, msg): + x.at[idx].set(0) + with self.assertNoWarnings(): + x.at[normalize(idx)].set(0) + + def testIndexedUpdateAliasingBug(self): + # https://github.com/google/jax/issues/7461 + fn = lambda x: x.at[1:].set(1 + x[:-1]) + y = bm.zeros(8) + self.assertArraysEqual(fn(y), jax.jit(fn)(y)) 
diff --git a/brainpy/_src/math/tests/test_numpy_ops.py b/brainpy/_src/math/tests/test_numpy_ops.py index 60615656e..d06235004 100644 --- a/brainpy/_src/math/tests/test_numpy_ops.py +++ b/brainpy/_src/math/tests/test_numpy_ops.py @@ -14,12 +14,12 @@ import pytest + pytest.skip("No need to test.", allow_module_level=True) import collections import functools from functools import partial -import inspect import io import itertools import operator @@ -32,10 +32,11 @@ from absl.testing import parameterized import numpy as np + try: - import numpy_dispatch + import numpy_dispatch except ImportError: - numpy_dispatch = None + numpy_dispatch = None import jax import jax.ops @@ -52,10 +53,10 @@ from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps from jax._src.util import prod, safe_zip -import brainpy as bp import brainpy.math as bm from jax.config import config + config.parse_flags_with_absl() FLAGS = config.FLAGS @@ -64,7 +65,7 @@ nonempty_nonscalar_array_shapes = [(4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] nonempty_array_shapes = [()] + nonempty_nonscalar_array_shapes one_dim_array_shapes = [(1,), (6,), (12,)] -empty_array_shapes = [(0,), (0, 4), (3, 0),] +empty_array_shapes = [(0,), (0, 4), (3, 0), ] scalar_shapes = [jtu.NUMPY_SCALAR_SHAPE, jtu.PYTHON_SCALAR_SHAPE] array_shapes = nonempty_array_shapes + empty_array_shapes @@ -82,58 +83,67 @@ number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes all_dtypes = number_dtypes + bool_dtypes - python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_] # uint64 is problematic because with any uint type it promotes to float: int_dtypes_no_uint64 = [d for d in int_dtypes + unsigned_dtypes if d != np.uint64] + def _indexer_with_default_outputs(indexer, use_defaults=True): - """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" - class Indexer: - @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults) - def __getitem__(self, *args): - return 
indexer.__getitem__(*args) - return Indexer() + """Like jtu.with_jax_dtype_defaults, but for __getitem__ APIs""" + + class Indexer: + @partial(jtu.with_jax_dtype_defaults, use_defaults=use_defaults) + def __getitem__(self, *args): + return indexer.__getitem__(*args) + + return Indexer() + def _valid_dtypes_for_shape(shape, dtypes): - # Not all (shape, dtype) pairs are valid. In particular, Python scalars only - # have one type in each category (float, bool, etc.) - if shape is jtu.PYTHON_SCALAR_SHAPE: - return [t for t in dtypes if t in python_scalar_dtypes] - return dtypes + # Not all (shape, dtype) pairs are valid. In particular, Python scalars only + # have one type in each category (float, bool, etc.) + if shape is jtu.PYTHON_SCALAR_SHAPE: + return [t for t in dtypes if t in python_scalar_dtypes] + return dtypes + def _shape_and_dtypes(shapes, dtypes): - for shape in shapes: - for dtype in _valid_dtypes_for_shape(shape, dtypes): - yield (shape, dtype) + for shape in shapes: + for dtype in _valid_dtypes_for_shape(shape, dtypes): + yield (shape, dtype) + def _compatible_shapes(shape): - if shape in scalar_shapes or np.ndim(shape) == 0: - return [shape] - return (shape[n:] for n in range(len(shape) + 1)) + if shape in scalar_shapes or np.ndim(shape) == 0: + return [shape] + return (shape[n:] for n in range(len(shape) + 1)) + def _get_y_shapes(y_dtype, shape, rowvar): - # Helper function for testCov. - if y_dtype is None: - return [None] - if len(shape) == 1: - return [shape] - elif rowvar or shape[0] == 1: - return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])] - return [(shape[0], 1), (shape[0], 2), (shape[0], 5)] + # Helper function for testCov. 
+ if y_dtype is None: + return [None] + if len(shape) == 1: + return [shape] + elif rowvar or shape[0] == 1: + return [(1, shape[-1]), (2, shape[-1]), (5, shape[-1])] + return [(shape[0], 1), (shape[0], 2), (shape[0], 5)] + OpRecord = collections.namedtuple( - "OpRecord", - ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", - "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) + "OpRecord", + ["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes", + "test_name", "check_dtypes", "tolerance", "inexact", "kwargs"]) + def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, test_name=None, check_dtypes=True, tolerance=None, inexact=False, kwargs=None): - test_name = test_name or name - return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, - test_name, check_dtypes, tolerance, inexact, kwargs) + test_name = test_name or name + return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes, + test_name, check_dtypes, tolerance, inexact, kwargs) + JAX_ONE_TO_ONE_OP_RECORDS = [ op_record("abs", 1, all_dtypes, @@ -218,10 +228,10 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, ] JAX_TEST_RECORDS = [ - op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes, + op_record("divmod", 2, int_dtypes + float_dtypes, all_shapes, jtu.rand_nonzero, []), - op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []), - op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes, + op_record("modf", 1, float_dtypes, all_shapes, jtu.rand_default, []), + op_record("modf", 1, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), ] @@ -373,12 +383,12 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []), ] if numpy_version >= (1, 22): # initial & where keywords added in numpy 1.22 - JAX_REDUCER_INITIAL_RECORDS += [ - op_record("nanprod", 1, inexact_dtypes, all_shapes, 
jtu.rand_small_positive, []), - op_record("nansum", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - op_record("nanmax", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - op_record("nanmin", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), - ] + JAX_REDUCER_INITIAL_RECORDS += [ + op_record("nanprod", 1, inexact_dtypes, all_shapes, jtu.rand_small_positive, []), + op_record("nansum", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + op_record("nanmax", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + op_record("nanmin", 1, inexact_dtypes, all_shapes, jtu.rand_default, []), + ] JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [ op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []), @@ -391,14 +401,14 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, inexact=True), ] if numpy_version >= (1, 22): # where keyword added in numpy 1.22 - JAX_REDUCER_WHERE_NO_INITIAL_RECORDS += [ - op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], - inexact=True), - ] + JAX_REDUCER_WHERE_NO_INITIAL_RECORDS += [ + op_record("nanmean", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("nanvar", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + op_record("nanstd", 1, inexact_dtypes, nonempty_shapes, jtu.rand_default, [], + inexact=True), + ] JAX_REDUCER_NO_DTYPE_RECORDS = [ op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []), @@ -477,5480 +487,5567 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("__rrshift__", 2, int_dtypes_no_uint64, all_shapes, partial(jtu.rand_int, high=8), []) ] + class _OverrideEverything(object): - pass + pass + for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS: - if 
rec.nargs == 2: - setattr(_OverrideEverything, rec.name, lambda self, other: self) + if rec.nargs == 2: + setattr(_OverrideEverything, rec.name, lambda self, other: self) + class _OverrideNothing(object): - pass + pass + for rec in JAX_OPERATOR_OVERLOADS + JAX_RIGHT_OPERATOR_OVERLOADS: - if rec.nargs == 2: - setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented) + if rec.nargs == 2: + setattr(_OverrideNothing, rec.name, lambda self, other: NotImplemented) def _dtypes_are_compatible_for_bitwise_ops(args): - if len(args) <= 1: - return True - is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger) - width = lambda dtype: jnp.iinfo(dtype).bits - x, y = args - if width(x) > width(y): - x, y = y, x - # The following condition seems a little ad hoc, but seems to capture what - # numpy actually implements. - return ( - is_signed(x) == is_signed(y) - or (width(x) == 32 and width(y) == 32) - or (width(x) == 32 and width(y) == 64 and is_signed(y))) + if len(args) <= 1: + return True + is_signed = lambda dtype: jnp.issubdtype(dtype, np.signedinteger) + width = lambda dtype: jnp.iinfo(dtype).bits + x, y = args + if width(x) > width(y): + x, y = y, x + # The following condition seems a little ad hoc, but seems to capture what + # numpy actually implements. 
+ return ( + is_signed(x) == is_signed(y) + or (width(x) == 32 and width(y) == 32) + or (width(x) == 32 and width(y) == 64 and is_signed(y))) + def _shapes_are_broadcast_compatible(shapes): - try: - lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) - except ValueError: - return False - else: - return True + try: + lax.broadcast_shapes(*(() if s in scalar_shapes else s for s in shapes)) + except ValueError: + return False + else: + return True + def _shapes_are_equal_length(shapes): - return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) + return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) def _promote_like_jnp(fun, inexact=False): - """Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`. + """Decorator that promotes the arguments of `fun` to `jnp.result_type(*args)`. jnp and np have different type promotion semantics; this decorator allows tests make an np reference implementation act more like an jnp implementation. 
""" - _promote = _promote_dtypes_inexact if inexact else _promote_dtypes - def wrapper(*args, **kw): - flat_args, tree = tree_util.tree_flatten(args) - args = tree_util.tree_unflatten(tree, _promote(*flat_args)) - return fun(*args, **kw) - return wrapper + _promote = _promote_dtypes_inexact if inexact else _promote_dtypes + + def wrapper(*args, **kw): + flat_args, tree = tree_util.tree_flatten(args) + args = tree_util.tree_unflatten(tree, _promote(*flat_args)) + return fun(*args, **kw) + + return wrapper def bm_func(fun): - def wrapper(*args, **kw): - res = fun(*args, **kw) - if isinstance(res, bm.Array): - return res.value - elif isinstance(res, tuple): - return tuple(r.value if isinstance(r, bm.Array) else r for r in res) - elif isinstance(res, list): - return list(r.value if isinstance(r, bm.Array) else r for r in res) - else: - return res + def wrapper(*args, **kw): + res = fun(*args, **kw) + if isinstance(res, bm.Array): + return res.value + elif isinstance(res, tuple): + return tuple(r.value if isinstance(r, bm.Array) else r for r in res) + elif isinstance(res, list): + return list(r.value if isinstance(r, bm.Array) else r for r in res) + else: + return res + + return wrapper - return wrapper @pytest.mark.skipif(True, reason="No longer need to test.") @jtu.with_config(jax_numpy_dtype_promotion='standard') class LaxBackedNumpyTests(jtu.JaxTestCase): - """Tests for LAX-backed Numpy implementation.""" - - def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): - def f(): - out = [rng(shape, dtype or jnp.float_) - for shape, dtype in zip(shapes, dtypes)] - if np_arrays: - return out - return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a - for a in out] - return f - - # todo: not tested - def testNotImplemented(self): - for name in jnp._NOT_IMPLEMENTED: - func = getattr(jnp, name) - with self.assertRaises(NotImplementedError): - func() - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": 
"_{}_allow_picke={}".format(dtype, allow_pickle), - "dtype": dtype, "allow_pickle": allow_pickle} - for dtype in float_dtypes + [object] - for allow_pickle in [True, False])) - def testLoad(self, dtype, allow_pickle): - if dtype == object and not allow_pickle: - self.skipTest("dtype=object requires allow_pickle=True") - rng = jtu.rand_default(self.rng()) - arr = rng((10), dtype) - with io.BytesIO() as f: - bm.save(f, arr) - f.seek(0) - arr_out = bm.load(f, allow_pickle=allow_pickle) - self.assertArraysEqual(arr, arr_out) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), - "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance, - "inexact": rec.inexact, "kwargs": rec.kwargs or {}} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, - JAX_COMPOUND_OP_RECORDS))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testOp(self, np_op, bm_op, rng_factory, shapes, dtypes, check_dtypes, - tolerance, inexact, kwargs): - np_op = partial(np_op, **kwargs) - bm_op = partial(bm_op, **kwargs) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="invalid value.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) - - rng = rng_factory(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) - tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) - tol = functools.reduce(jtu.join_tolerance, - [tolerance, tol, jtu.default_tolerance()]) - self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), bm_func(bm_op), - args_maker, check_dtypes=check_dtypes, tol=tol) - self._CompileAndCheck(bm_func(bm_op), args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, - "tol": rec.tolerance} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in JAX_OPERATOR_OVERLOADS)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): - rng = rng_factory(self.rng()) - # np and jnp arrays have different type promotion rules; force the use of - # jnp arrays. 
- args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) - fun = lambda *xs: getattr(operator, name.strip('_'))(*xs) - self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, - dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, - "op_tolerance": rec.tolerance} + """Tests for LAX-backed Numpy implementation.""" + + def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): + def f(): + out = [rng(shape, dtype or jnp.float_) + for shape, dtype in zip(shapes, dtypes)] + if np_arrays: + return out + return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a + for a in out] + + return f + + # todo: not tested + def testNotImplemented(self): + for name in jnp._NOT_IMPLEMENTED: + func = getattr(jnp, name) + with self.assertRaises(NotImplementedError): + func() + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_allow_picke={}".format(dtype, allow_pickle), + "dtype": dtype, "allow_pickle": allow_pickle} + for dtype in float_dtypes + [object] + for allow_pickle in [True, False])) + def testLoad(self, dtype, allow_pickle): + if dtype == object and not allow_pickle: + self.skipTest("dtype=object requires allow_pickle=True") + rng = jtu.rand_default(self.rng()) + arr = rng((10), dtype) + with io.BytesIO() as f: + bm.save(f, arr) + f.seek(0) + arr_out = bm.load(f, allow_pickle=allow_pickle) + self.assertArraysEqual(arr, arr_out) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, + dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), + "check_dtypes": rec.check_dtypes, "tolerance": rec.tolerance, + "inexact": 
rec.inexact, "kwargs": rec.kwargs or {}} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) + for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, + JAX_COMPOUND_OP_RECORDS))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testOp(self, np_op, bm_op, rng_factory, shapes, dtypes, check_dtypes, + tolerance, inexact, kwargs): + np_op = partial(np_op, **kwargs) + bm_op = partial(bm_op, **kwargs) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="invalid value.*")(np_op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="divide by zero.*")(np_op) + + rng = rng_factory(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) + tol = max(jtu.tolerance(dtype, tolerance) for dtype in dtypes) + tol = functools.reduce(jtu.join_tolerance, + [tolerance, tol, jtu.default_tolerance()]) + self._CheckAgainstNumpy(_promote_like_jnp(np_op, inexact), bm_func(bm_op), + args_maker, check_dtypes=check_dtypes, tol=tol) + self._CompileAndCheck(bm_func(bm_op), args_maker, check_dtypes=check_dtypes, + atol=tol, rtol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, + dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, + "tol": rec.tolerance} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) + for rec in JAX_OPERATOR_OVERLOADS)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testOperatorOverload(self, name, rng_factory, shapes, dtypes, tol): + rng = rng_factory(self.rng()) + # np and jnp arrays have different type promotion rules; force the use of + # jnp arrays. + args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) + fun = lambda *xs: getattr(operator, name.strip('_'))(*xs) + self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, + dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "name": rec.name, + "op_tolerance": rec.tolerance} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) + for rec in JAX_RIGHT_OPERATOR_OVERLOADS)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, + op_tolerance): + if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: + raise SkipTest("scalars not implemented") # TODO(mattjj): clean up + rng = rng_factory(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) + fun = lambda fst, snd: getattr(snd, name)(fst) + tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) + self._CompileAndCheck(fun, args_maker, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": rec.test_name + "_{}".format(dtype), + "rng_factory": rec.rng_factory, + "op_name": rec.name, "dtype": dtype} + for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2 + for dtype in rec.dtypes)) + def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): + rng = rng_factory(self.rng()) + arg = jax.device_put(rng((), dtype)) + op = getattr(operator, op_name) + + other = _OverrideEverything() + assert op(other, arg) is other + assert op(arg, other) is other + + other = _OverrideNothing() + if op_name == "__eq__": + assert op(other, arg) is False + assert op(arg, other) is False + elif op_name == "__ne__": + assert op(other, arg) is True + assert op(arg, other) is True + else: + with self.assertRaises(TypeError): + op(other, arg) + with self.assertRaises(TypeError): + op(arg, other) + + def testArrayEqualExamples(self): + # examples from the array_equal() docstring. 
+ self.assertTrue(bm.array_equal([1, 2], [1, 2])) + self.assertTrue(bm.array_equal(np.array([1, 2]), np.array([1, 2]))) + self.assertFalse(bm.array_equal([1, 2], [1, 2, 3])) + self.assertFalse(bm.array_equal([1, 2], [1, 4])) + + a = np.array([1, np.nan]) + self.assertFalse(bm.array_equal(a, a)) + self.assertTrue(bm.array_equal(a, a, equal_nan=True)) + + a = np.array([1 + 1j]) + b = a.copy() + a.real = np.nan + b.imag = np.nan + self.assertTrue(bm.array_equal(a, b, equal_nan=True)) + + def testArrayEquivExamples(self): + # examples from the array_equiv() docstring. + self.assertTrue(bm.array_equiv([1, 2], [1, 2])) + self.assertFalse(bm.array_equiv([1, 2], [1, 3])) + with jax.numpy_rank_promotion('allow'): + self.assertTrue(bm.array_equiv([1, 2], [[1, 2], [1, 2]])) + self.assertFalse(bm.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) + self.assertFalse(bm.array_equiv([1, 2], [[1, 2], [1, 3]])) + + def testArrayModule(self): + if numpy_dispatch is None: + raise SkipTest('requires https://github.com/seberg/numpy-dispatch') + + bm_array = bm.array(1.0) + np_array = np.array(1.0) + + module = numpy_dispatch.get_array_module(bm_array) + self.assertIs(module, jnp) + + module = numpy_dispatch.get_array_module(bm_array, np_array) + self.assertIs(module, jnp) + + def f(x): + module = numpy_dispatch.get_array_module(x) + self.assertIs(module, jnp) + return x + + jax.jit(f)(bm_array) + jax.grad(f)(bm_array) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix( + rec.test_name, shapes, dtypes), + "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name)} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + for dtypes in filter( + _dtypes_are_compatible_for_bitwise_ops, + itertools.combinations_with_replacement(rec.dtypes, rec.nargs))) + for rec 
in JAX_BITWISE_OP_RECORDS)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testBitwiseOp(self, np_op, bm_op, rng_factory, shapes, dtypes): + rng = rng_factory(self.rng()) + if not config.x64_enabled and any( + bm.iinfo(dtype).bits == 64 for dtype in dtypes): + self.skipTest("x64 types are disabled by jax_enable_x64") + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, + check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(op.__name__, shapes, dtypes), + "op": op, "dtypes": dtypes, "shapes": shapes} + for op in [bm.left_shift, bm.right_shift] for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) + _shapes_are_broadcast_compatible, + # TODO numpy always promotes to shift dtype for zero-dim shapes: + itertools.combinations_with_replacement(nonzerodim_shapes, 2)) for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, rec.dtypes) for s in shapes))) - for rec in JAX_RIGHT_OPERATOR_OVERLOADS)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testRightOperatorOverload(self, name, rng_factory, shapes, dtypes, - op_tolerance): - if shapes[1] is jtu.PYTHON_SCALAR_SHAPE: - raise SkipTest("scalars not implemented") # TODO(mattjj): clean up - rng = rng_factory(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes, np_arrays=False) - fun = lambda fst, snd: getattr(snd, name)(fst) - tol = max(jtu.tolerance(dtype, op_tolerance) for dtype in dtypes) - self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": rec.test_name + "_{}".format(dtype), - "rng_factory": rec.rng_factory, - "op_name": rec.name, "dtype": dtype} - for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2 - for dtype in rec.dtypes)) - def testBinaryOperatorDefers(self, op_name, rng_factory, dtype): - rng = rng_factory(self.rng()) - arg = jax.device_put(rng((), dtype)) - op = getattr(operator, op_name) - - other = _OverrideEverything() - assert op(other, arg) is other - assert op(arg, other) is other - - other = _OverrideNothing() - if op_name == "__eq__": - assert op(other, arg) is False - assert op(arg, other) is False - elif op_name == "__ne__": - assert op(other, arg) is True - assert op(arg, other) is True - else: - with self.assertRaises(TypeError): - op(other, arg) - with self.assertRaises(TypeError): - op(arg, other) - - def testArrayEqualExamples(self): - # examples from the array_equal() docstring. 
- self.assertTrue(bm.array_equal([1, 2], [1, 2])) - self.assertTrue(bm.array_equal(np.array([1, 2]), np.array([1, 2]))) - self.assertFalse(bm.array_equal([1, 2], [1, 2, 3])) - self.assertFalse(bm.array_equal([1, 2], [1, 4])) - - a = np.array([1, np.nan]) - self.assertFalse(bm.array_equal(a, a)) - self.assertTrue(bm.array_equal(a, a, equal_nan=True)) - - a = np.array([1 + 1j]) - b = a.copy() - a.real = np.nan - b.imag = np.nan - self.assertTrue(bm.array_equal(a, b, equal_nan=True)) - - def testArrayEquivExamples(self): - # examples from the array_equiv() docstring. - self.assertTrue(bm.array_equiv([1, 2], [1, 2])) - self.assertFalse(bm.array_equiv([1, 2], [1, 3])) - with jax.numpy_rank_promotion('allow'): - self.assertTrue(bm.array_equiv([1, 2], [[1, 2], [1, 2]])) - self.assertFalse(bm.array_equiv([1, 2], [[1, 2, 1, 2], [1, 2, 1, 2]])) - self.assertFalse(bm.array_equiv([1, 2], [[1, 2], [1, 3]])) - - def testArrayModule(self): - if numpy_dispatch is None: - raise SkipTest('requires https://github.com/seberg/numpy-dispatch') - - bm_array = bm.array(1.0) - np_array = np.array(1.0) - - module = numpy_dispatch.get_array_module(bm_array) - self.assertIs(module, jnp) - - module = numpy_dispatch.get_array_module(bm_array, np_array) - self.assertIs(module, jnp) - - def f(x): - module = numpy_dispatch.get_array_module(x) - self.assertIs(module, jnp) - return x - jax.jit(f)(bm_array) - jax.grad(f)(bm_array) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.test_name, shapes, dtypes), - "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name)} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(rec.shapes, rec.nargs)) - for dtypes in filter( - _dtypes_are_compatible_for_bitwise_ops, - itertools.combinations_with_replacement(rec.dtypes, rec.nargs))) - for rec in 
JAX_BITWISE_OP_RECORDS)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testBitwiseOp(self, np_op, bm_op, rng_factory, shapes, dtypes): - rng = rng_factory(self.rng()) - if not config.x64_enabled and any( - bm.iinfo(dtype).bits == 64 for dtype in dtypes): - self.skipTest("x64 types are disabled by jax_enable_x64") - args_maker = self._GetArgsMaker(rng, shapes, dtypes) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, - check_dtypes=jtu.PYTHON_SCALAR_SHAPE not in shapes) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(op.__name__, shapes, dtypes), - "op": op, "dtypes": dtypes, "shapes": shapes} - for op in [bm.left_shift, bm.right_shift] - for shapes in filter( - _shapes_are_broadcast_compatible, - # TODO numpy always promotes to shift dtype for zero-dim shapes: - itertools.combinations_with_replacement(nonzerodim_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes)))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testShiftOpAgainstNumpy(self, op, dtypes, shapes): - dtype, shift_dtype = dtypes - signed_mix = np.issubdtype(dtype, np.signedinteger) != \ - np.issubdtype(shift_dtype, np.signedinteger) - has_32 = any(np.iinfo(d).bits == 32 for d in dtypes) - promoting_to_64 = has_32 and signed_mix - if promoting_to_64 and not config.x64_enabled: - self.skipTest("np.right_shift/left_shift promoting to int64" - "differs from jnp in 32 bit mode.") - - info, shift_info = map(np.iinfo, dtypes) - x_rng = jtu.rand_int(self.rng(), low=info.min, high=info.max + 1) - # NumPy requires shifts to be non-negative and below the bit width: - shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits)) - args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype)) - self._CompileAndCheck(bm_func(op), args_maker) - np_op = getattr(np, op.__name__) - self._CheckAgainstNumpy(bm_func(np_op), op, args_maker) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, - "None" if out_dtype is None else np.dtype(out_dtype).name, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes - for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True]) - for rec in JAX_REDUCER_RECORDS)) - def testReducer(self, np_op, bm_op, rng_factory, shape, dtype, out_dtype, - axis, keepdims, inexact): - rng = rng_factory(self.rng()) - @jtu.ignore_warning(category=np.ComplexWarning) - @jtu.ignore_warning(category=RuntimeWarning, - message="mean of empty slice.*") - 
@jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered.*") - def np_fun(x): - x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32) - t = out_dtype if out_dtype != jnp.bfloat16 else np.float32 - return np_op(x_cast, axis, dtype=t, keepdims=keepdims) - np_fun = _promote_like_jnp(np_fun, inexact) - bm_fun = lambda x: bm_op(x, axis, dtype=out_dtype, keepdims=keepdims) - bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) - args_maker = lambda: [rng(shape, dtype)] - tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3, - np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5} - tol = jtu.tolerance(dtype, tol_spec) - tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=jnp.bfloat16 not in (dtype, out_dtype), - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, - rtol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( + *(_valid_dtypes_for_shape(s, int_dtypes_no_uint64) for s in shapes)))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testShiftOpAgainstNumpy(self, op, dtypes, shapes): + dtype, shift_dtype = dtypes + signed_mix = np.issubdtype(dtype, np.signedinteger) != \ + np.issubdtype(shift_dtype, np.signedinteger) + has_32 = any(np.iinfo(d).bits == 32 for d in dtypes) + promoting_to_64 = has_32 and signed_mix + if promoting_to_64 and not config.x64_enabled: + self.skipTest("np.right_shift/left_shift promoting to int64" + "differs from jnp in 32 bit mode.") + + info, shift_info = map(np.iinfo, dtypes) + x_rng = jtu.rand_int(self.rng(), low=info.min, high=info.max + 1) + # NumPy requires shifts to be non-negative and below the bit width: + shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits)) + args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype)) + self._CompileAndCheck(bm_func(op), args_maker) + np_op = getattr(np, op.__name__) + self._CheckAgainstNumpy(bm_func(np_op), op, args_maker) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, + "None" if out_dtype is None else np.dtype(out_dtype).name, keepdims), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), + "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for out_dtype in [None] + rec.dtypes if out_dtype not in unsigned_dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + for keepdims in [False, True]) + for rec in JAX_REDUCER_RECORDS)) + def testReducer(self, np_op, bm_op, rng_factory, shape, dtype, out_dtype, + axis, keepdims, inexact): + rng = rng_factory(self.rng()) + + @jtu.ignore_warning(category=np.ComplexWarning) + @jtu.ignore_warning(category=RuntimeWarning, + message="mean of empty slice.*") + 
@jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered.*") + def np_fun(x): + x_cast = x if dtype != jnp.bfloat16 else x.astype(np.float32) + t = out_dtype if out_dtype != jnp.bfloat16 else np.float32 + return np_op(x_cast, axis, dtype=t, keepdims=keepdims) + + np_fun = _promote_like_jnp(np_fun, inexact) + bm_fun = lambda x: bm_op(x, axis, dtype=out_dtype, keepdims=keepdims) + bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) + args_maker = lambda: [rng(shape, dtype)] + tol_spec = {np.float16: 1e-2, np.int32: 1E-3, np.float32: 1e-3, + np.complex64: 1e-3, np.float64: 1e-5, np.complex128: 1e-5} + tol = jtu.tolerance(dtype, tol_spec) + tol = max(tol, jtu.tolerance(out_dtype, tol_spec)) if out_dtype else tol + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=jnp.bfloat16 not in (dtype, out_dtype), + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, + rtol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), + "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + for keepdims in [False, True]) + for rec in JAX_REDUCER_NO_DTYPE_RECORDS)) + def testReducerNoDtype(self, np_op, bm_op, rng_factory, shape, dtype, axis, + keepdims, inexact): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=RuntimeWarning, + message="All-NaN (slice|axis) encountered.*") + 
def np_fun(x): + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + np_fun = _promote_like_jnp(np_fun, inexact) + bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims) + args_maker = lambda: [rng(shape, dtype)] + tol = {np.float16: 0.002} + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol, atol=tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), + "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + for initial in [0, 1] for keepdims in [False, True]) + for rec in JAX_REDUCER_INITIAL_RECORDS)) + def testReducerInitial(self, np_op, bm_op, rng_factory, shape, dtype, axis, + keepdims, initial, inexact): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + def np_fun(x): + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims, initial=initial) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + np_fun = _promote_like_jnp(np_fun, inexact) + np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims, initial=initial) + bm_fun = 
jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) + args_maker = lambda: [rng(shape, dtype)] + tol = {jnp.bfloat16: 3E-2} + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, rtol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, + jtu.format_shape_dtype_string(whereshape, bool)), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), "whereshape": whereshape, + "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for whereshape in _compatible_shapes(shape) + for axis in list(range(-len(shape), len(shape))) + [None] + for initial in [0, 1] for keepdims in [False, True]) + for rec in JAX_REDUCER_INITIAL_RECORDS)) + def testReducerWhere(self, np_op, bm_op, rng_factory, shape, dtype, axis, + keepdims, initial, inexact, whereshape): + if (shape in [()] + scalar_shapes and + dtype in [bm.int16, bm.uint16] and + bm_op in [bm.min, bm.max]): + self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.") + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' + # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. 
+ where = jtu.rand_bool(self.rng())(whereshape, np.bool_) + + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + def np_fun(x): + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + np_fun = _promote_like_jnp(np_fun, inexact) + np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims, initial=initial, where=where) + bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @unittest.skipIf(numpy_version < (1, 20), "where parameter not supported in older numpy") + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format( + rec.test_name.capitalize(), + jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, + jtu.format_shape_dtype_string(whereshape, bool)), + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), "whereshape": whereshape, + "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} + for shape in rec.shapes for dtype in rec.dtypes + for whereshape in _compatible_shapes(shape) + for axis in list(range(-len(shape), len(shape))) + [None] + for keepdims in [False, True]) + for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)) + def testReducerWhereNoInitial(self, np_op, bm_op, rng_factory, shape, dtype, axis, + keepdims, inexact, whereshape): + rng = rng_factory(self.rng()) + is_bf16_nan_test = dtype == jnp.bfloat16 + # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. 
+ where = jtu.rand_bool(self.rng())(whereshape, np.bool_) + + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.*") + @jtu.ignore_warning(category=RuntimeWarning, + message="Mean of empty slice.*") + @jtu.ignore_warning(category=RuntimeWarning, + message="invalid value encountered in true_divide*") + def np_fun(x): + x_cast = x if not is_bf16_nan_test else x.astype(np.float32) + res = np_op(x_cast, axis, keepdims=keepdims, where=where) + res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) + return res + + np_fun = _promote_like_jnp(np_fun, inexact) + np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims, where=where) + bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) + args_maker = lambda: [rng(shape, dtype)] + if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"): + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_axis={}_discont={}_period={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, discont, period), + "shape": shape, "dtype": dtype, "axis": axis, "discont": discont, "period": period} + for shape in all_shapes for dtype in default_dtypes + for discont in [None, "pi", 2] + for period in ["2pi", "pi"] + for axis in list(range(-len(shape), len(shape))))) + def testUnwrap(self, shape, dtype, axis, discont, period): + if numpy_version < (1, 21) and period != "2pi": + self.skipTest("numpy < 1.21 does not support the period argument to unwrap()") + special_vals = {"pi": np.pi, "2pi": 2 * np.pi} + period = special_vals.get(period, period) + discont = special_vals.get(discont, discont) + + rng = jtu.rand_default(self.rng()) + if numpy_version < (1, 21): + np_fun = partial(np.unwrap, axis=axis, discont=discont) + else: + np_fun = partial(np.unwrap, axis=axis, 
discont=discont, period=period) + bm_fun = partial(bm.unwrap, axis=axis, discont=discont, period=period) + args_maker = lambda: [rng(shape, dtype)] + if dtype != jnp.bfloat16: # numpy crashes on bfloat16 + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for shape in all_shapes for dtype in all_dtypes + for axis in list(range(-len(shape), len(shape))) + [None])) + def testCountNonzero(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + np_fun = lambda x: np.count_nonzero(x, axis) + bm_fun = lambda x: bm.count_nonzero(x, axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in all_shapes for dtype in all_dtypes)) + def testNonzero(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + np_fun = lambda x: np.nonzero(x) + np_fun = jtu.ignore_warning( + category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*")(np_fun) + bm_fun = lambda x: bm.nonzero(x) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_size={}_fill_value={}".format( + jtu.format_shape_dtype_string(shape, dtype), size, fill_value), + "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value} + for shape in nonempty_array_shapes + for dtype in all_dtypes + for fill_value in [None, -1, shape 
or (1,)] + for size in [1, 5, 10])) + def testNonzeroSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") + def np_fun(x): + result = np.nonzero(x) + if size <= len(result[0]): + return tuple(arg[:size] for arg in result) + else: + fillvals = fill_value if np.ndim(fill_value) else len(result) * [fill_value or 0] + return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) + for fval, arg in safe_zip(fillvals, result)) + + bm_fun = lambda x: bm.nonzero(x, size=size, fill_value=fill_value) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in all_shapes for dtype in all_dtypes)) + def testFlatNonzero(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + np_fun = jtu.ignore_warning( + category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*")(np.flatnonzero) + bm_fun = bm.flatnonzero + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + + # JIT compilation requires specifying the size statically: + bm_fun = lambda x: bm.flatnonzero(x, size=np.size(x) // 2) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_size={}_fill_value={}".format( + jtu.format_shape_dtype_string(shape, dtype), size, fill_value), + "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value} + for shape in nonempty_array_shapes + for dtype in all_dtypes + for fill_value in [None, -1, 10, (-1,), (10,)] + for size in [1, 5, 10])) + def 
testFlatNonzeroSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") + def np_fun(x): + result = np.flatnonzero(x) + if size <= len(result): + return result[:size] + else: + fill_val = fill_value or 0 + return np.concatenate([result, np.full(size - len(result), fill_val, result.dtype)]) + + bm_fun = lambda x: bm.flatnonzero(x, size=size, fill_value=fill_value) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in all_shapes for dtype in all_dtypes)) + def testArgWhere(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + np_fun = jtu.ignore_warning( + category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*")(np.argwhere) + bm_fun = bm.argwhere + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + + # JIT compilation requires specifying a size statically. Full test of this + # behavior is in testNonzeroSize(). 
+ bm_fun = lambda x: bm.argwhere(x, size=np.size(x) // 2) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_size={}_fill_value={}".format( + jtu.format_shape_dtype_string(shape, dtype), size, fill_value), + "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value} + for shape in nonempty_array_shapes + for dtype in all_dtypes + for fill_value in [None, -1, shape or (1,)] + for size in [1, 5, 10])) + def testArgWhereSize(self, shape, dtype, size, fill_value): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") + def np_fun(x): + result = np.argwhere(x) + if size <= len(result): + return result[:size] + else: + fillvals = fill_value if np.ndim(fill_value) else result.shape[-1] * [fill_value or 0] + return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack( + [np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) + for fval, arg in safe_zip(fillvals, result.T)]).T + + bm_fun = lambda x: bm.argwhere(x, size=size, fill_value=fill_value) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True]) - for rec in JAX_REDUCER_NO_DTYPE_RECORDS)) - def testReducerNoDtype(self, np_op, bm_op, rng_factory, shape, dtype, 
axis, - keepdims, inexact): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=RuntimeWarning, - message="All-NaN (slice|axis) encountered.*") - def np_fun(x): - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - np_fun = _promote_like_jnp(np_fun, inexact) - bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims) - args_maker = lambda: [rng(shape, dtype)] - tol = {np.float16: 0.002} - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol, atol=tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True]) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerInitial(self, np_op, bm_op, rng_factory, shape, dtype, axis, - keepdims, initial, inexact): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - def np_fun(x): - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, 
keepdims=keepdims, initial=initial) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - np_fun = _promote_like_jnp(np_fun, inexact) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) - bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims, initial=initial) - bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) - args_maker = lambda: [rng(shape, dtype)] - tol = {jnp.bfloat16: 3E-2} - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, rtol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_whereshape={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, - jtu.format_shape_dtype_string(whereshape, bool)), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), "whereshape": whereshape, - "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for whereshape in _compatible_shapes(shape) - for axis in list(range(-len(shape), len(shape))) + [None] - for initial in [0, 1] for keepdims in [False, True]) - for rec in JAX_REDUCER_INITIAL_RECORDS)) - def testReducerWhere(self, np_op, bm_op, rng_factory, shape, dtype, axis, - keepdims, initial, inexact, whereshape): - if (shape in [()] + scalar_shapes and - dtype in [bm.int16, bm.uint16] and - bm_op in [bm.min, bm.max]): - self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.") - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' - # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. 
- where = jtu.rand_bool(self.rng())(whereshape, np.bool_) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - def np_fun(x): - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims, initial=initial, where=where) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - np_fun = _promote_like_jnp(np_fun, inexact) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) - bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims, initial=initial, where=where) - bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @unittest.skipIf(numpy_version < (1, 20), "where parameter not supported in older numpy") - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, - jtu.format_shape_dtype_string(whereshape, bool)), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), "whereshape": whereshape, - "axis": axis, "keepdims": keepdims, "inexact": rec.inexact} - for shape in rec.shapes for dtype in rec.dtypes - for whereshape in _compatible_shapes(shape) - for axis in list(range(-len(shape), len(shape))) + [None] - for keepdims in [False, True]) - for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)) - def testReducerWhereNoInitial(self, np_op, bm_op, rng_factory, shape, dtype, axis, - keepdims, inexact, whereshape): - rng = rng_factory(self.rng()) - is_bf16_nan_test = dtype == jnp.bfloat16 - # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. 
- where = jtu.rand_bool(self.rng())(whereshape, np.bool_) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.*") - @jtu.ignore_warning(category=RuntimeWarning, - message="Mean of empty slice.*") - @jtu.ignore_warning(category=RuntimeWarning, - message="invalid value encountered in true_divide*") - def np_fun(x): - x_cast = x if not is_bf16_nan_test else x.astype(np.float32) - res = np_op(x_cast, axis, keepdims=keepdims, where=where) - res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16) - return res - - np_fun = _promote_like_jnp(np_fun, inexact) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) - bm_fun = lambda x: bm_op(x, axis, keepdims=keepdims, where=where) - bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) - args_maker = lambda: [rng(shape, dtype)] - if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"): - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}_discont={}_period={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, discont, period), - "shape": shape, "dtype": dtype, "axis": axis, "discont": discont, "period": period} - for shape in all_shapes for dtype in default_dtypes - for discont in [None, "pi", 2] - for period in ["2pi", "pi"] - for axis in list(range(-len(shape), len(shape))))) - def testUnwrap(self, shape, dtype, axis, discont, period): - if numpy_version < (1, 21) and period != "2pi": - self.skipTest("numpy < 1.21 does not support the period argument to unwrap()") - special_vals = {"pi": np.pi, "2pi": 2 * np.pi} - period = special_vals.get(period, period) - discont = special_vals.get(discont, discont) - - rng = jtu.rand_default(self.rng()) - if numpy_version < (1, 21): - np_fun = partial(np.unwrap, axis=axis, discont=discont) - else: - np_fun = partial(np.unwrap, axis=axis, 
discont=discont, period=period) - bm_fun = partial(bm.unwrap, axis=axis, discont=discont, period=period) - args_maker = lambda: [rng(shape, dtype)] - if dtype != jnp.bfloat16: # numpy crashes on bfloat16 - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for shape in all_shapes for dtype in all_dtypes - for axis in list(range(-len(shape), len(shape))) + [None])) - def testCountNonzero(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - np_fun = lambda x: np.count_nonzero(x, axis) - bm_fun = lambda x: bm.count_nonzero(x, axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes for dtype in all_dtypes)) - def testNonzero(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - np_fun = lambda x: np.nonzero(x) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np_fun) - bm_fun = lambda x: bm.nonzero(x) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_size={}_fill_value={}".format( - jtu.format_shape_dtype_string(shape, dtype), size, fill_value), - "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value} - for shape in nonempty_array_shapes - for dtype in all_dtypes - for fill_value in [None, -1, shape 
or (1,)] - for size in [1, 5, 10])) - def testNonzeroSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") - def np_fun(x): - result = np.nonzero(x) - if size <= len(result[0]): - return tuple(arg[:size] for arg in result) - else: - fillvals = fill_value if np.ndim(fill_value) else len(result) * [fill_value or 0] - return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) - for fval, arg in safe_zip(fillvals, result)) - bm_fun = lambda x: bm.nonzero(x, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes for dtype in all_dtypes)) - def testFlatNonzero(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np.flatnonzero) - bm_fun = bm.flatnonzero - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - - # JIT compilation requires specifying the size statically: - bm_fun = lambda x: bm.flatnonzero(x, size=np.size(x) // 2) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_size={}_fill_value={}".format( - jtu.format_shape_dtype_string(shape, dtype), size, fill_value), - "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value} - for shape in nonempty_array_shapes - for dtype in all_dtypes - for fill_value in [None, -1, 10, (-1,), (10,)] - for size in [1, 5, 10])) - def 
testFlatNonzeroSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") - def np_fun(x): - result = np.flatnonzero(x) - if size <= len(result): - return result[:size] - else: - fill_val = fill_value or 0 - return np.concatenate([result, np.full(size - len(result), fill_val, result.dtype)]) - bm_fun = lambda x: bm.flatnonzero(x, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes for dtype in all_dtypes)) - def testArgWhere(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np.argwhere) - bm_fun = bm.argwhere - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - - # JIT compilation requires specifying a size statically. Full test of this - # behavior is in testNonzeroSize(). 
- bm_fun = lambda x: bm.argwhere(x, size=np.size(x) // 2) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_size={}_fill_value={}".format( - jtu.format_shape_dtype_string(shape, dtype), size, fill_value), - "shape": shape, "dtype": dtype, "size": size, "fill_value": fill_value} - for shape in nonempty_array_shapes - for dtype in all_dtypes - for fill_value in [None, -1, shape or (1,)] - for size in [1, 5, 10])) - def testArgWhereSize(self, shape, dtype, size, fill_value): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") - def np_fun(x): - result = np.argwhere(x) - if size <= len(result): - return result[:size] - else: - fillvals = fill_value if np.ndim(fill_value) else result.shape[-1] * [fill_value or 0] - return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) - for fval, arg in safe_zip(fillvals, result.T)]).T - bm_fun = lambda x: bm.argwhere(x, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( - rec.test_name.capitalize(), - jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), - "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), - "axis": axis, "keepdims": keepdims} - for rec in JAX_ARGMINMAX_RECORDS - for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) - for axis in range(-len(shape), len(shape)) - for keepdims in [True, False])) - def testArgMinMax(self, np_op, bm_op, rng_factory, shape, dtype, axis, keepdims): - rng = 
rng_factory(self.rng()) - if dtype == np.complex128 and jtu.device_under_test() == "gpu": - raise unittest.SkipTest("complex128 reductions not supported on GPU") - if "nan" in np_op.__name__ and dtype == jnp.bfloat16: - raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays") - if numpy_version < (1, 22) and keepdims: - raise unittest.SkipTest("NumPy < 1.22 does not support keepdims argument to argmin/argmax") - kwds = {"keepdims": True} if keepdims else {} - - np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds)) - bm_fun = partial(bm_op, axis=axis, **kwds) - - args_maker = lambda: [rng(shape, dtype)] - try: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - except ValueError as e: - if str(e) == "All-NaN slice encountered": - self.skipTest("JAX doesn't support checking for all-NaN slices") - else: - raise - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": rec.test_name.capitalize(), "name": rec.name, - "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name)} - for rec in JAX_ARGMINMAX_RECORDS)) - def testArgMinMaxEmpty(self, name, np_op, bm_op): - name = name[3:] if name.startswith("nan") else name - msg = "attempt to get {} of an empty sequence".format(name) - with self.assertRaises(ValueError, msg=msg): - bm_op(np.array([])) - with self.assertRaises(ValueError, msg=msg): - bm_op(np.zeros((2, 0)), axis=1) - np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0)) - bm_fun = partial(bm_op, axis=0) - args_maker = lambda: [np.zeros((2, 0))] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), - axes), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - 
"rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "axes": axes} - for lhs_shape, rhs_shape, axes in [ - [(2,), (2,), (-1, -1, -1, None)], # scalar output - [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors - [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors - [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting - [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes - [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting - [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors - [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting - [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing - [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before - ] - for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - axisa, axisb, axisc, axis = axes - bm_fun = lambda a, b: bm.cross(a, b, axisa, axisb, axisc, axis) - def np_fun(a, b): - a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a - b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b - out = np.cross(a, b, axisa, axisb, axisc, axis) - return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, - rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - name, - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": 
lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype} - for name, lhs_shape, rhs_shape in [ - ("matrix-scalar", (3, 3), ()), - ("scalar-matrix", (), (3, 3)), - ("matrix-vector", (4, 5), (5,)), - ("vector-matrix", (6,), (6, 4)), - ("matrix-matrix", (3, 4), (4, 5)), - ("tensor-vector", (4, 3, 2), (2,)), - ("vector-tensor", (2,), (3, 2, 4)), - ("tensor-matrix", (4, 3, 2), (2, 5)), - ("matrix-tensor", (5, 2), (3, 2, 4)), - ("tensor-tensor", (2, 3, 4), (5, 4, 1))] - for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2))) - def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-14, - np.complex128: 1e-14} - if jtu.device_under_test() == "tpu": - tol[np.float16] = tol[np.float32] = tol[np.complex64] = 2e-1 - def np_dot(x, y): - x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x - y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y - return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) - self._CheckAgainstNumpy(np_dot, bm_func(bm.dot), args_maker, - tol=tol) - self._CompileAndCheck(bm_func(bm.dot), args_maker, atol=tol, - rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - name, - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype} - for name, lhs_shape, rhs_shape in [ - ("vector-vector", (3,), (3,)), - ("matrix-vector", (3, 3), (3,)), - ("vector-matrix", (3,), (3, 3)), - ("matrix-matrix", (3, 3), (3, 3)), - ("vector-tensor", (3,), (5, 3, 2)), - ("tensor-vector", (5, 3, 2), (2,)), - ("matrix-tensor", (5, 2), (3, 2, 4)), - ("tensor-matrix", (5, 2, 3), (3, 2)), - ("tensor-tensor", (5, 3, 4), (5, 4, 1)), - 
("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))] - for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2))) - def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - def np_fun(x, y): - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.matmul(x, y).astype(dtype) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, - np.complex128: 1e-12} - if jtu.device_under_test() == "tpu": - tol[np.float16] = tol[np.float32] = tol[np.complex64] = 4e-2 - self._CheckAgainstNumpy(np_fun, bm_func(bm.matmul), args_maker, tol=tol) - self._CompileAndCheck(bm_func(bm.matmul), args_maker, atol=tol, rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), - axes), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, - "axes": axes} - for lhs_shape, rhs_shape, axes in [ - [(3,), (), 0], - [(2, 3, 4), (5, 6, 7), 0], # from issue #740 - [(2, 3, 4), (3, 4, 5, 6), 2], - [(2, 3, 4), (5, 4, 3, 6), [1, 2]], - [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], - [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], - ] - for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2))) - def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - bm_fun = lambda a, b: bm.tensordot(a, b, axes) - def np_fun(a, b): - a = a if lhs_dtype != jnp.bfloat16 else a.astype(np.float32) - b = b if rhs_dtype != jnp.bfloat16 else b.astype(np.float32) - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.tensordot(a, b, axes).astype(dtype) - tol = {np.float16: 1e-1, np.float32: 
1e-3, np.float64: 1e-12, - np.complex64: 1e-3, np.complex128: 1e-12} - if jtu.device_under_test() == "tpu": - tol[np.float16] = tol[np.float32] = tol[np.complex64] = 2e-1 - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testTensordotErrors(self): - a = self.rng().random((3, 2, 2)) - b = self.rng().random((2,)) - self.assertRaisesRegex( - TypeError, "Number of tensordot axes.*exceeds input ranks.*", - lambda: bm.tensordot(a, b, axes=2)) - - self.assertRaisesRegex( - TypeError, "tensordot requires axes lists to have equal length.*", - lambda: bm.tensordot(a, b, axes=([0], [0, 1]))) - - self.assertRaisesRegex( - TypeError, "tensordot requires both axes lists to be either ints, tuples or lists.*", - lambda: bm.tensordot(a, b, axes=('bad', 'axes'))) - - self.assertRaisesRegex( - TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*", - lambda: bm.tensordot(a, b, axes='badaxes')) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_invert={}".format( - jtu.format_shape_dtype_string(element_shape, dtype), - jtu.format_shape_dtype_string(test_shape, dtype), invert), - "element_shape": element_shape, "test_shape": test_shape, - "dtype": dtype, "invert": invert} - for element_shape in all_shapes - for test_shape in all_shapes - for dtype in default_dtypes - for invert in [True, False])) - def testIsin(self, element_shape, test_shape, dtype, invert): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] - bm_fun = lambda e, t: bm.isin(e, t, invert=invert) - np_fun = lambda e, t: np.isin(e, t, invert=invert) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_invert={}".format( - 
jtu.format_shape_dtype_string(element_shape, dtype), - jtu.format_shape_dtype_string(test_shape, dtype), invert), - "element_shape": element_shape, "test_shape": test_shape, - "dtype": dtype, "invert": invert} - for element_shape in all_shapes - for test_shape in all_shapes - for dtype in default_dtypes - for invert in [True, False])) - def testIn1d(self, element_shape, test_shape, dtype, invert): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] - bm_fun = lambda e, t: bm.in1d(e, t, invert=invert) - np_fun = lambda e, t: np.in1d(e, t, invert=invert) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(shape1, dtype1), - jtu.format_shape_dtype_string(shape2, dtype2)), - "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2} - for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] - for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] - for shape1 in all_shapes - for shape2 in all_shapes)) - def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - self._CheckAgainstNumpy(np.setdiff1d, bm_func(bm.setdiff1d), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_size={}_fill_value={}".format( - jtu.format_shape_dtype_string(shape1, dtype1), - jtu.format_shape_dtype_string(shape2, dtype2), - size, fill_value), - "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2, - "size": size, "fill_value": fill_value} - for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] - for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] - for shape1 in all_shapes - for shape2 in all_shapes - for size in [1, 5, 10] - for 
fill_value in [None, -1])) - def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - result = np.setdiff1d(arg1, arg2) - if size <= len(result): - return result[:size] - else: - return np.pad(result, (0, size-len(result)), constant_values=fill_value or 0) - def bm_fun(arg1, arg2): - return bm.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(shape1, dtype1), - jtu.format_shape_dtype_string(shape2, dtype2)), - "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2} - for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] - for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] - for shape1 in nonempty_nonscalar_array_shapes - for shape2 in nonempty_nonscalar_array_shapes)) - def testUnion1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - return np.union1d(arg1, arg2).astype(dtype) - self._CheckAgainstNumpy(np_fun, bm_func(bm.union1d), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_size={}_fill_value={}".format( - jtu.format_shape_dtype_string(shape1, dtype1), - jtu.format_shape_dtype_string(shape2, dtype2), size, fill_value), - "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2, - "size": size, "fill_value": fill_value} - for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] - for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] - for shape1 in nonempty_nonscalar_array_shapes - for 
shape2 in nonempty_nonscalar_array_shapes - for size in [1, 5, 10] - for fill_value in [None, -1])) - def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - def np_fun(arg1, arg2): - dtype = jnp.promote_types(arg1.dtype, arg2.dtype) - result = np.union1d(arg1, arg2).astype(dtype) - fv = result.min() if fill_value is None else fill_value - if size <= len(result): - return result[:size] - else: - return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) - def bm_fun(arg1, arg2): - return bm.union1d(arg1, arg2, size=size, fill_value=fill_value) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_assume_unique={}".format( - jtu.format_shape_dtype_string(shape1, dtype1), - jtu.format_shape_dtype_string(shape2, dtype2), - assume_unique), - "shape1": shape1, "dtype1": dtype1, "shape2": shape2, "dtype2": dtype2, - "assume_unique": assume_unique} - for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] - for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] - for shape1 in all_shapes - for shape2 in all_shapes - for assume_unique in [False, True])) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - bm_fun = lambda ar1, ar2: bm.setxor1d(ar1, ar2, assume_unique=assume_unique) - def np_fun(ar1, ar2): - if assume_unique: - # pre-flatten the arrays to match with jax implementation - ar1 = np.ravel(ar1) - ar2 = np.ravel(ar2) - return np.setxor1d(ar1, ar2, assume_unique) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": 
"_{}_{}_assume_unique={}_return_indices={}".format( - jtu.format_shape_dtype_string(shape1, dtype1), - jtu.format_shape_dtype_string(shape2, dtype2), - assume_unique, - return_indices), - "shape1": shape1, "dtype1": dtype1, "shape2": shape2, "dtype2": dtype2, - "assume_unique": assume_unique, "return_indices": return_indices} - for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] - for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] - for shape1 in all_shapes - for shape2 in all_shapes - for assume_unique in [False, True] - for return_indices in [False, True])) - def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, return_indices): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - bm_fun = lambda ar1, ar2: bm.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}".format( - jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), - jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), - "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, - "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype} - # TODO(phawkins): support integer dtypes too. 
- for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) - if len(jtu._dims_of_shape(lhs_shape)) == 0 - or len(jtu._dims_of_shape(rhs_shape)) == 0 - or lhs_shape[-1] == rhs_shape[-1])) - def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] - def np_fun(lhs, rhs): - lhs = lhs if lhs_dtype != jnp.bfloat16 else lhs.astype(np.float32) - rhs = rhs if rhs_dtype != jnp.bfloat16 else rhs.astype(np.float32) - dtype = jnp.promote_types(lhs_dtype, rhs_dtype) - return np.inner(lhs, rhs).astype(dtype) - bm_fun = lambda lhs, rhs: bm.inner(lhs, rhs) - tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13, - np.complex64: 1e-5} - if jtu.device_under_test() == "tpu": - tol_spec[np.float32] = tol_spec[np.complex64] = 2e-1 - tol = max(jtu.tolerance(lhs_dtype, tol_spec), - jtu.tolerance(rhs_dtype, tol_spec)) - # TODO(phawkins): there are float32/float64 disagreements for some inputs. 
- self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=False, atol=tol, rtol=tol) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_deg={}_rcond={}_full={}_w={}_cov={}".format( - jtu.format_shape_dtype_string(shape, dtype), - deg, - rcond, - full, - w, - cov), - "shape": shape, "dtype": dtype, "deg": deg, - "rcond": rcond, "full": full, "w":w, "cov":cov} - for dtype in [dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]] - for shape in [shape for shape in one_dim_array_shapes if shape != (1,)] - for deg in [1, 2, 3] - for rcond in [None, -1, 10e-3, 10e-5, 10e-10] - for full in [False, True] - for w in [False, True] - for cov in [False, True, "unscaled"])) - def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): - rng = jtu.rand_default(self.rng()) - tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} - if jtu.device_under_test() == "tpu": - tol_spec[np.float32] = tol_spec[np.complex64] = 2e-1 - tol = jtu.tolerance(dtype, tol_spec) - _w = lambda a: abs(a) if w else None - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] - bm_fun = lambda x, y, a: bm.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) - np_fun = jtu.ignore_warning( - message="Polyfit may be poorly conditioned*")(lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=False, atol=tol, rtol=tol) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_amin={}_amax={}".format( - jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), - "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max} - for shape in all_shapes for dtype in number_dtypes - for a_min, a_max in 
[(-1, None), (None, 1), (-0.9, 1), - (-np.ones(1), None), - (None, np.ones(1)), - (np.full(1, -0.9), np.ones(1))])) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testClipStaticBounds(self, shape, dtype, a_min, a_max): - if np.issubdtype(dtype, np.unsignedinteger): - a_min = None if a_min is None else abs(a_min) - a_max = None if a_max is None else abs(a_max) - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) - bm_fun = lambda x: bm.clip(x, a_min=a_min, a_max=a_max) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testClipError(self): - with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"): - bm.clip(jnp.zeros((3,))) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_decimals={}".format( - jtu.format_shape_dtype_string(shape, dtype), decimals), - "shape": shape, "dtype": dtype, "decimals": decimals} - for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes) - for decimals in [0, 1, -2])) - def testRoundStaticDecimals(self, shape, dtype, decimals): - rng = jtu.rand_default(self.rng()) - if jnp.issubdtype(dtype, np.integer) and decimals < 0: - self.skipTest("Integer rounding with decimals < 0 not implemented") - np_fun = lambda x: np.round(x, decimals=decimals) - bm_fun = lambda x: bm.round(x, decimals=decimals) - args_maker = lambda: [rng(shape, dtype)] - tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2} - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=check_dtypes, tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=check_dtypes, - atol=tol, rtol=tol) - - def testOperatorRound(self): - self.assertAllClose(round(np.float32(7.532), 1), - round(bm.float32(7.5), 1)) - 
self.assertAllClose(round(np.float32(1.234), 2), - round(bm.float32(1.234), 2)) - self.assertAllClose(round(np.float32(1.234)), - round(bm.float32(1.234)), check_dtypes=False) - self.assertAllClose(round(np.float32(7.532), 1), - round(bm.array(7.5, bm.float32), 1)) - self.assertAllClose(round(np.float32(1.234), 2), - round(bm.array(1.234, bm.float32), 2)) - self.assertAllClose(round(np.float32(1.234)), - round(bm.array(1.234, bm.float32)), - check_dtypes=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_mode={}_padwidth={}_constantvalues={}".format( - jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, - constant_values), - "shape": shape, "dtype": dtype, "mode": mode, - "pad_width": pad_width, "constant_values": constant_values} - for mode, shapes in [ - ('constant', all_shapes), - ('wrap', nonempty_shapes), - ('edge', nonempty_shapes), - ] - for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) - for constant_values in [ - # None is used for modes other than 'constant' - None, - # constant - 0, 1, - # (constant,) - (0,), (2.718,), - # ((before_const, after_const),) - ((0, 2),), ((-1, 3.14),), - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i / 2, -3.14 * i) for i in range(len(shape))), - ] - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - if (pad_width != () and constant_values != () and - ((mode == 'constant' and constant_values is not None) or - (mode != 'constant' and constant_values is None))))) - def testPad(self, shape, dtype, mode, pad_width, constant_values): - if np.issubdtype(dtype, np.unsignedinteger): - constant_values = tree_util.tree_map(abs, constant_values) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: 
[rng(shape, dtype)] - if constant_values is None: - np_fun = partial(np.pad, pad_width=pad_width, mode=mode) - bm_fun = partial(bm.pad, pad_width=pad_width, mode=mode) - else: - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, - constant_values=constant_values) - bm_fun = partial(bm.pad, pad_width=pad_width, mode=mode, - constant_values=constant_values) - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_mode={}_pad_width={}_stat_length={}".format( - jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, stat_length), - "shape": shape, "dtype": dtype, "mode": mode, "pad_width": pad_width, - "stat_length": stat_length} - for mode in ['maximum', 'minimum', 'mean', 'median'] - for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - for stat_length in [ - None, - # ((before_1, after_1), ..., (before_N, after_N)) - tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 2),), - # (before, after) (not in the docstring but works in numpy) - (1, 1), (3, 4), - # (pad,) - (1,), (2,), - # pad - 1, 2 - ] - if (pad_width != () and stat_length != () and - not (dtype in bool_dtypes and mode == 'mean')))) - def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length): - if mode == 'median' and np.issubdtype(dtype, np.complexfloating): - self.skipTest("median statistic is not supported for dtype=complex.") - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - 
np_fun = partial(np.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) - bm_fun = partial(bm.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_mode={}_pad_width={}_reflect_type={}".format( - jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, reflect_type), - "shape": shape, "dtype": dtype, "mode": mode, "pad_width": pad_width, - "reflect_type": reflect_type} - for mode in ['symmetric', 'reflect'] - for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 3),), - # (before, after) (not in the docstring but works in numpy) - (2, 1), (1, 2), - # (pad,) - (1,), (2,), (3,), - # pad - 0, 5, 7, 10 - ] - for reflect_type in ['even', 'odd'] - if (pad_width != () and - # following types lack precision when calculating odd values - (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16])))) - def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) - bm_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, - tol={np.float32: 1e-3, np.complex64: 1e-3}) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format( - 
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values), - "shape": shape, "dtype": dtype, "pad_width": pad_width, - "end_values": end_values} - for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes) - for pad_width in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2, 0),), - # (before, after) (not in the docstring but works in numpy) - (2, 0), (0, 0), - # (pad,) - (1,), (2,), - # pad - 0, 1, - ] - for end_values in [ - # ((before_1, after_1), ..., (before_N, after_N)) - tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), - # ((before, after),) - ((1, 2),), ((2.0, 3.14),), - # (before, after) (not in the docstring but works in numpy) - (0, 0), (-8.0, 2.0), - # (end_values,) - (1,), (2,), - # end_values - 0, 1, 100, 10.0, 3.5, 4.2, -5, -3 - ] - if (pad_width != () and end_values != () and - # following types lack precision - dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16]))) - def testPadLinearRamp(self, shape, dtype, pad_width, end_values): - if numpy_version < (1, 20) and np.issubdtype(dtype, np.integer): - raise unittest.SkipTest("NumPy 1.20 changed the semantics of np.linspace") - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - np_fun = partial(np.pad, pad_width=pad_width, mode="linear_ramp", - end_values=end_values) - bm_fun = partial(jnp.pad, pad_width=pad_width, mode="linear_ramp", - end_values=end_values) - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testPadEmpty(self): - arr = np.arange(6).reshape(2, 3) - - pad_width = ((2, 3), (3, 1)) - np_res = np.pad(arr, pad_width=pad_width, mode="empty") - bm_res = bm.pad(arr, pad_width=pad_width, mode="empty").value - - np.testing.assert_equal(np_res.shape, bm_res.shape) - 
np.testing.assert_equal(arr, np_res[2:-3, 3:-1]) - np.testing.assert_equal(arr, bm_res[2:-3, 3:-1]) - np.testing.assert_equal(np_res[2:-3, 3:-1], bm_res[2:-3, 3:-1]) - - def testPadKwargs(self): - modes = { - 'constant': {'constant_values': 0}, - 'edge': {}, - 'linear_ramp': {'end_values': 0}, - 'maximum': {'stat_length': None}, - 'mean': {'stat_length': None}, - 'median': {'stat_length': None}, - 'minimum': {'stat_length': None}, - 'reflect': {'reflect_type': 'even'}, - 'symmetric': {'reflect_type': 'even'}, - 'wrap': {}, - 'empty': {} - } - arr = bm.array([1, 2, 3]) - pad_width = 1 - - for mode in modes.keys(): - allowed = modes[mode] - not_allowed = {} - for kwargs in modes.values(): - if kwargs != allowed: - not_allowed.update(kwargs) - - # Test if allowed keyword arguments pass - bm.pad(arr, pad_width, mode, **allowed) - # Test if prohibited keyword arguments of other modes raise an error - match = "unsupported keyword arguments for mode '{}'".format(mode) - for key, value in not_allowed.items(): - with self.assertRaisesRegex(ValueError, match): - bm.pad(arr, pad_width, mode, **{key: value}) - - # Test if unsupported mode raise error. 
- unsupported_modes = [1, None, "foo"] - for mode in unsupported_modes: - match = "Unimplemented padding mode '{}' for np.pad.".format(mode) - with self.assertRaisesRegex(NotImplementedError, match): - bm.pad(arr, pad_width, mode) - - def testPadFunction(self): - def np_pad_with(vector, pad_width, iaxis, kwargs): - pad_value = kwargs.get('padder', 10) - vector[:pad_width[0]] = pad_value - vector[-pad_width[1]:] = pad_value - - def bm_pad_with(vector, pad_width, iaxis, kwargs): - pad_value = kwargs.get('padder', 10) - vector = vector.at[:pad_width[0]].set(pad_value) - vector = vector.at[-pad_width[1]:].set(pad_value) - return vector - - arr = np.arange(6).reshape(2, 3) - np_res = np.pad(arr, 2, np_pad_with) - bm_res = bm.pad(arr, 2, bm_pad_with) - np.testing.assert_equal(np_res, bm_res) - - arr = np.arange(24).reshape(2, 3, 4) - np_res = np.pad(arr, 1, np_pad_with, padder=100) - bm_res = bm.pad(arr, 1, bm_pad_with, padder=100) - np.testing.assert_equal(np_res, bm_res) - - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(arr.shape, arr.dtype)] - bm_fun = partial(bm.pad, pad_width=1, mode=bm_pad_with) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testPadWithNumpyPadWidth(self): - a = bm.array([1, 2, 3, 4, 5]) - f = jax.jit( - partial( - bm.pad, - pad_width=np.asarray((2, 3)), - mode="constant", - constant_values=(4, 6))) - - np.testing.assert_array_equal( - f(a), - np.pad( - a, - pad_width=np.asarray((2, 3)), - mode="constant", - constant_values=(4, 6))) - - def testPadWeakType(self): - x = bm.array(1.0)[None] - for mode in ['constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', - 'minimum', 'reflect', 'symmetric', 'wrap', 'empty']: - y = bm.pad(x, 0, mode=mode).value - self.assertTrue(dtypes.is_weakly_typed(y)) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape=[{}]_reps={}".format( - jtu.format_shape_dtype_string(shape, dtype), reps), - "shape": shape, "dtype": dtype, "reps": reps} - for reps in 
[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)] - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) - )) - def testTile(self, shape, dtype, reps): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.tile(arg, reps) - bm_fun = lambda arg: bm.tile(arg, reps) - - args_maker = lambda: [rng(shape, dtype)] - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes - for dtype in all_dtypes)) - def testExtract(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] - self._CheckAgainstNumpy(np.extract, bm_func(bm.extract), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_ncond={}_nfunc={}".format( - jtu.format_shape_dtype_string(shape, dtype), ncond, nfunc), - "shape": shape, "dtype": dtype, "ncond": ncond, "nfunc": nfunc} - for ncond in [1, 2, 3] - for nfunc in [ncond, ncond + 1] - for shape in all_shapes - for dtype in all_dtypes)) - def testPiecewise(self, shape, dtype, ncond, nfunc): - rng = jtu.rand_default(self.rng()) - rng_bool = jtu.rand_int(self.rng(), 0, 2) - funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc] - args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)]) - np_fun = partial(np.piecewise, funclist=funclist) - bm_fun = partial(bm.piecewise, funclist=funclist) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) - # This is a higher-order function, so the cache miss check will fail. 
- self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True, check_cache_misses=False) - - def testPiecewiseRecompile(self): - def g(x): - g.num_traces += 1 - return x - g.num_traces = 0 - x = bm.arange(10.0) - for i in range(5): - bm.piecewise(x, [x < 0], [g, 0.]) - self.assertEqual(g.num_traces, 1) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "{}_perm={}_{}".format( - jtu.format_shape_dtype_string(shape, dtype), perm, arg_type), - "dtype": dtype, "shape": shape, "perm": perm, "arg_type": arg_type} - for dtype in default_dtypes - for shape in array_shapes - for arg_type in ["splat", "value"] - for perm in [None, tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim))])) - def testTransposeTuple(self, shape, dtype, perm, arg_type): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - if arg_type == "value": - np_fun = lambda x: x.transpose(perm) - bm_fun = lambda x: bm.array(x).transpose(perm) - else: - np_fun = lambda x: x.transpose(*(perm or ())) - bm_fun = lambda x: bm.array(x).transpose(*(perm or ())) - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "{}_trim={}".format( - jtu.format_shape_dtype_string(a_shape, dtype), trim), - "dtype": dtype, "a_shape": a_shape, "trim": trim} - for dtype in default_dtypes - for a_shape in one_dim_array_shapes - for trim in ["f", "b", "fb"])) - def testTrimZeros(self, a_shape, dtype, trim): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(a_shape, dtype)] - np_fun = lambda arg1: np.trim_zeros(arg1, trim) - bm_fun = lambda arg1: bm.trim_zeros(arg1, trim) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_rank{}".format( - 
jtu.format_shape_dtype_string(a_shape, dtype), rank), - "dtype": dtype, "a_shape": a_shape, "rank": rank} - for rank in (1, 2) - for dtype in default_dtypes - for a_shape in one_dim_array_shapes)) - def testPoly(self, a_shape, dtype, rank): - if dtype in (np.float16, jnp.bfloat16, np.int16): - self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") - elif rank == 2 and jtu.device_under_test() in ("tpu", "gpu"): - self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") - rng = jtu.rand_default(self.rng()) - tol = { np.int8: 1e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6 } - if jtu.device_under_test() == "tpu": - tol[np.int32] = tol[np.float32] = 1e-1 - tol = jtu.tolerance(dtype, tol) - args_maker = lambda: [rng(a_shape * rank, dtype)] - self._CheckAgainstNumpy(np.poly, bm_func(bm.poly), args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(bm_func(bm.poly), args_maker, check_dtypes=True, rtol=tol, atol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "a_shape={} , b_shape={}".format( - jtu.format_shape_dtype_string(a_shape, dtype), - jtu.format_shape_dtype_string(b_shape, dtype)), - "dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape} - for dtype in default_dtypes - for a_shape in one_dim_array_shapes - for b_shape in one_dim_array_shapes)) - def testPolyAdd(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polyadd(arg1, arg2) - bm_fun = lambda arg1, arg2: bm.polyadd(arg1, arg2) - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "a_shape={} , b_shape={}".format( - jtu.format_shape_dtype_string(a_shape, dtype), - jtu.format_shape_dtype_string(b_shape, 
dtype)), - "dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape} - for dtype in default_dtypes - for a_shape in one_dim_array_shapes - for b_shape in one_dim_array_shapes)) - def testPolySub(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polysub(arg1, arg2) - bm_fun = lambda arg1, arg2: bm.polysub(arg1, arg2) - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_order={}_k={}".format( - jtu.format_shape_dtype_string(a_shape, dtype), - order, k), - "dtype": dtype, "a_shape": a_shape, "order" : order, "k": k} - for dtype in default_dtypes - for a_shape in one_dim_array_shapes - for order in range(5) - for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None])) - def testPolyInt(self, a_shape, order, k, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1: np.polyint(arg1, m=order, k=k) - bm_fun = lambda arg1: bm.polyint(arg1, m=order, k=k) - args_maker = lambda: [rng(a_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_order={}".format( - jtu.format_shape_dtype_string(a_shape, dtype), - order), - "dtype": dtype, "a_shape": a_shape, "order" : order} - for dtype in default_dtypes - for a_shape in one_dim_array_shapes - for order in range(5))) - def testPolyDer(self, a_shape, order, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1: np.polyder(arg1, m=order) - bm_fun = lambda arg1: bm.polyder(arg1, m=order) - args_maker = lambda: [rng(a_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, 
check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_ptype={}".format(ptype), "ptype": ptype} - for ptype in ['int', 'np.int', 'bm.int'])) - def testIntegerPower(self, ptype): - p = {'int': 2, 'np.int': np.int32(2), 'bm.int': bm.int32(2)}[ptype] - jaxpr = jax.make_jaxpr(partial(bm_func(bm.power), x2=p))(1) - eqns = jaxpr.jaxpr.eqns - self.assertLen(eqns, 1) - self.assertEqual(eqns[0].primitive, lax.integer_pow_p) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_x={}_y={}".format(x, y), "x": x, "y": y} - for x in [-1, 0, 1] - for y in [0, 32, 64, 128])) - def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/google/jax/issues/5987 - args_maker = lambda: [x, y] - self._CheckAgainstNumpy(np.power, bm_func(bm.power), args_maker) - self._CompileAndCheck(bm_func(bm.power), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for shape in all_shapes - for dtype in all_dtypes - for axis in [None] + list(range(len(shape))))) - def testCompress(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - if shape in scalar_shapes or len(shape) == 0: - cond_shape = (0,) - elif axis is None: - cond_shape = (prod(shape),) - else: - cond_shape = (shape[axis],) - - args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] - - np_fun = partial(np.compress, axis=axis) - bm_fun = partial(bm.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_condition=array[{}]_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), len(condition), axis), - "shape": shape, "dtype": dtype, "condition": condition, 
"axis": axis} - for shape in [(2, 3)] - for dtype in int_dtypes - # condition entries beyond axis size must be zero. - for condition in [[1], [1, 0, 0, 0, 0, 0, 0]] - for axis in [None, 0, 1])) - def testCompressMismatchedShapes(self, shape, dtype, condition, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.array(condition), rng(shape, dtype)] - np_fun = partial(np.compress, axis=axis) - bm_fun = partial(bm.compress, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for shape in array_shapes - for dtype in all_dtypes - for axis in [None] + list(range(len(shape))))) - def testCompressMethod(self, shape, dtype, axis): - rng = jtu.rand_some_zero(self.rng()) - if shape in scalar_shapes or len(shape) == 0: - cond_shape = (0,) - elif axis is None: - cond_shape = (prod(shape),) - else: - cond_shape = (shape[axis],) - - args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] - - np_fun = lambda condition, x: np.compress(condition, x, axis=axis) - bm_fun = lambda condition, x: bm.compress(condition, x, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( - axis, ",".join(str(d) for d in base_shape), - ",".join(np.dtype(dtype).name for dtype in arg_dtypes)), - "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes} - for num_arrs in [3] - for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(-len(base_shape)+1, len(base_shape)))) - def testConcatenate(self, axis, base_shape, arg_dtypes): - rng = jtu.rand_default(self.rng()) - wrapped_axis = axis % len(base_shape) - 
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - def np_fun(*args): - args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32) - for x in args] - dtype = functools.reduce(jnp.promote_types, arg_dtypes) - return np.concatenate(args, axis=axis).astype(dtype) - bm_fun = lambda *args: bm.concatenate(args, axis=axis) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for shape in [(4, 1), (4, 3), (4, 5, 6)] - for dtype in all_dtypes - for axis in [None] + list(range(1 - len(shape), len(shape) - 1)))) - def testConcatenateArray(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda x: np.concatenate(x, axis=axis) - bm_fun = lambda x: bm.concatenate(x, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testConcatenateAxisNone(self): - # https://github.com/google/jax/issues/3419 - a = bm.array([[1, 2], [3, 4]]) - b = bm.array([[5]]) - bm.concatenate((a, b), axis=None) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( - axis, ",".join(str(d) for d in base_shape), - ",".join(np.dtype(dtype).name for dtype in arg_dtypes)), - "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes} - for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, 2) - for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(-len(base_shape)+1, len(base_shape)))) - def testAppend(self, axis, 
base_shape, arg_dtypes): - rng = jtu.rand_default(self.rng()) - wrapped_axis = axis % len(base_shape) - shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] - for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] - def np_fun(arr, values): - arr = arr.astype(np.float32) if arr.dtype == jnp.bfloat16 else arr - values = (values.astype(np.float32) if values.dtype == jnp.bfloat16 - else values) - out = np.append(arr, values, axis=axis) - return out.astype(jnp.promote_types(*arg_dtypes)) - bm_fun = lambda arr, values: bm.append(arr, values, axis=axis) - - def args_maker(): - return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_idx={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, idx), - "dtype": dtype, "shape": shape, "axis": axis, "idx": idx} - for shape in nonempty_nonscalar_array_shapes - for dtype in all_dtypes - for axis in [None] + list(range(-len(shape), len(shape))) - for idx in (range(-prod(shape), prod(shape)) - if axis is None else - range(-shape[axis], shape[axis])))) - def testDeleteInteger(self, shape, dtype, idx, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, idx, axis=axis) - bm_fun = lambda arg: bm.delete(arg, idx, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_slc={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, slc), - "dtype": dtype, "shape": shape, "axis": axis, "slc": slc} - for shape in nonempty_nonscalar_array_shapes - for dtype in all_dtypes - for axis in [None] + list(range(-len(shape), len(shape))) - for slc in 
[slice(None), slice(1, 3), slice(1, 5, 2)])) - def testDeleteSlice(self, shape, dtype, axis, slc): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, slc, axis=axis) - bm_fun = lambda arg: bm.delete(arg, slc, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_idx={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, - jtu.format_shape_dtype_string(idx_shape, int)), - "dtype": dtype, "shape": shape, "axis": axis, "idx_shape": idx_shape} - for shape in nonempty_nonscalar_array_shapes - for dtype in all_dtypes - for axis in [None] + list(range(-len(shape), len(shape))) - for idx_shape in all_shapes)) - def testDeleteIndexArray(self, shape, dtype, axis, idx_shape): - rng = jtu.rand_default(self.rng()) - max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - # Previous to numpy 1.19, negative indices were ignored so we don't test this. 
- low = 0 if numpy_version < (1, 19, 0) else -max_idx - idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, idx, axis=axis) - bm_fun = lambda arg: bm.delete(arg, idx, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "dtype": dtype, "shape": shape, "axis": axis} - for shape in nonempty_nonscalar_array_shapes - for dtype in all_dtypes - for axis in [None] + list(range(-len(shape), len(shape))))) - def testDeleteMaskArray(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] - mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arg: np.delete(arg, mask, axis=axis) - bm_fun = lambda arg: bm.delete(arg, mask, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "dtype": dtype, "shape": shape, "axis": axis} - for shape in nonempty_nonscalar_array_shapes - for dtype in all_dtypes - for axis in [None] + list(range(-len(shape), len(shape))))) - def testInsertInteger(self, shape, dtype, axis): - x = jnp.empty(shape) - max_ind = x.size if axis is None else x.shape[axis] - rng = jtu.rand_default(self.rng()) - i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) - args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)] - np_fun = lambda *args: 
np.insert(*args, axis=axis) - bm_fun = lambda *args: bm.insert(*args, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "dtype": dtype, "shape": shape, "axis": axis} - for shape in nonempty_nonscalar_array_shapes - for dtype in all_dtypes - for axis in [None] + list(range(-len(shape), len(shape))))) - def testInsertSlice(self, shape, dtype, axis): - x = jnp.empty(shape) - max_ind = x.size if axis is None else x.shape[axis] - rng = jtu.rand_default(self.rng()) - i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) - slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item()) - args_maker = lambda: [rng(shape, dtype), rng((), dtype)] - np_fun = lambda x, val: np.insert(x, slc, val, axis=axis) - bm_fun = lambda x, val: bm.insert(x, slc, val, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.parameters([ - [[[1, 1], [2, 2], [3, 3]], 1, 5, None], - [[[1, 1], [2, 2], [3, 3]], 1, 5, 1], - [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1], - [[[1, 1], [2, 2], [3, 3]], [1], [[1],[2],[3]], 1], - [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None], - [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None], - [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None], - [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1] - ]) - def testInsertExamples(self, arr, index, values, axis): - # Test examples from the np.insert docstring - args_maker = lambda: ( - np.asarray(arr), index if isinstance(index, slice) else np.array(index), - np.asarray(values), axis) - self._CheckAgainstNumpy(np.insert, bm_func(bm.insert), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_out_dims={}".format( - jtu.format_shape_dtype_string(shape, dtype), - 
axis, out_dims), - "shape": shape, "dtype": dtype, "axis": axis, "out_dims": out_dims} - for shape in nonempty_array_shapes - for dtype in default_dtypes - for axis in range(-len(shape), len(shape)) - for out_dims in [0, 1, 2])) - def testApplyAlongAxis(self, shape, dtype, axis, out_dims): - def func(x, out_dims): - if out_dims == 0: - return x.sum() - elif out_dims == 1: - return x * x[0] - elif out_dims == 2: - return x[:, None] + x[None, :] - else: - raise NotImplementedError(f"out_dims={out_dims}") - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - np_fun = lambda arr: np.apply_along_axis(func, axis, arr, out_dims=out_dims) - bm_fun = lambda arr: bm.apply_along_axis(func, axis, arr, out_dims=out_dims) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_func={}_keepdims={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), - func, keepdims, axes), - "shape": shape, "dtype": dtype, "func": func, "keepdims": keepdims, "axes": axes} - for shape in nonempty_shapes - for func in ["sum"] - for keepdims in [True, False] - for axes in itertools.combinations(range(len(shape)), 2) - # Avoid low-precision types in sum() - for dtype in default_dtypes if dtype not in [np.float16, jnp.bfloat16])) - def testApplyOverAxes(self, shape, dtype, func, keepdims, axes): - f = lambda x, axis: getattr(x, func)(axis=axis, keepdims=keepdims) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: (rng(shape, dtype),) - np_fun = lambda a: np.apply_over_axes(f, a, axes) - bm_fun = lambda a: bm.apply_over_axes(f, a, axes) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape=[{}]_axis={}_repeats={}_fixed_size={}".format( - 
jtu.format_shape_dtype_string(shape, dtype), - axis, repeats, fixed_size), - "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats, - 'fixed_size': fixed_size} - for repeats in [0, 1, 2] - for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) - for axis in [None] + list(range(-len(shape), max(1, len(shape)))) - for fixed_size in [True, False])) - def testRepeat(self, axis, shape, dtype, repeats, fixed_size): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis) - np_fun = _promote_like_jnp(np_fun) - if fixed_size: - total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0] - bm_fun = lambda arg, rep: bm.repeat(arg, repeats=rep, axis=axis, - total_repeat_length=total_repeat_length) - jnp_args_maker = lambda: [rng(shape, dtype), repeats] - clo_fun = lambda arg: bm.repeat(arg, repeats=repeats, axis=axis, - total_repeat_length=total_repeat_length) - clo_fun_args_maker = lambda: [rng(shape, dtype)] - self._CompileAndCheck(bm_func(bm_fun), jnp_args_maker) - self._CheckAgainstNumpy(np_fun, bm_func(clo_fun), clo_fun_args_maker) - else: - # Now repeats is in a closure, so a constant. 
- jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) - self._CompileAndCheck(jnp_fun, args_maker) - - def testRepeatScalarFastPath(self): - a = jnp.array([1,2,3,4]) - f = lambda a: bm.repeat(a, repeats=2) - jaxpr = jax.make_jaxpr(bm_func(f))(a) - self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_ind={}_inv={}_count={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, - return_index, return_inverse, return_counts), - "shape": shape, "dtype": dtype, "axis": axis, - "return_index": return_index, "return_inverse": return_inverse, - "return_counts": return_counts} - for dtype in number_dtypes - for shape in all_shapes - for axis in [None] + list(range(len(shape))) - for return_index in [False, True] - for return_inverse in [False, True] - for return_counts in [False, True])) - def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): - if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0: - self.skipTest("zero-sized axis in unique leads to error in older numpy.") - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - extra_args = (return_index, return_inverse, return_counts) - use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False - np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults) - bm_fun = lambda x: bm.unique(x, *extra_args, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_size={}_fill_value={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, size, fill_value), - "shape": shape, "dtype": dtype, "axis": axis, - "size": size, "fill_value": fill_value} - for dtype in 
number_dtypes - for size in [1, 5, 10] - for fill_value in [None, -1.0, "slice"] - for shape in nonempty_array_shapes - for axis in [None] + list(range(len(shape))))) - def testUniqueSize(self, shape, dtype, axis, size, fill_value): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) - - if fill_value == "slice": - if axis is None: - fill_value = rng((), dtype) - else: - fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) - - @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) - def np_fun(x, fill_value=fill_value): - u, ind, inv, counts = np.unique(x, **kwds) - axis = kwds['axis'] - if axis is None: - x = x.ravel() - axis = 0 - - n_unique = u.shape[axis] - if size <= u.shape[axis]: - slc = (slice(None),) * axis + (slice(size),) - u, ind, counts = u[slc], ind[:size], counts[:size] - else: - extra = (0, size - n_unique) - pads = [(0, 0)] * u.ndim - pads[axis] = extra - u = np.pad(u, pads, constant_values=0) - slices = [slice(None)] * u.ndim - slices[axis] = slice(1) - if fill_value is None: - fill_value = u[tuple(slices)] - elif np.ndim(fill_value): - fill_value = lax.expand_dims(fill_value, (axis,)) - slices[axis] = slice(n_unique, None) - u[tuple(slices)] = fill_value - ind = np.pad(ind, extra, constant_values=ind[0]) - counts = np.pad(counts, extra, constant_values=0) - return u, ind, inv, counts - - bm_fun = lambda x: bm.unique(x, size=size, fill_value=fill_value, **kwds) - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @unittest.skipIf(numpy_version < (1, 21), "Numpy < 1.21 does not properly handle NaN values in unique.") - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_{dtype.__name__}", "dtype": dtype} - for dtype in inexact_dtypes)) - def testUniqueNans(self, dtype): - def args_maker(): - x = [-0.0, 0.0, 1.0, 1.0, 
np.nan, -np.nan] - if np.issubdtype(dtype, np.complexfloating): - x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] - return [np.array(x, dtype=dtype)] - - kwds = dict(return_index=True, return_inverse=True, return_counts=True) - bm_fun = partial(bm.unique, **kwds) - def np_fun(x): - dtype = x.dtype - # numpy unique fails for bfloat16 NaNs, so we cast to float64 - if x.dtype == jnp.bfloat16: - x = x.astype('float64') - u, *rest = np.unique(x, **kwds) - return (u.astype(dtype), *rest) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_fixed_size={}".format(fixed_size), - "fixed_size": fixed_size} - for fixed_size in [True, False])) - def testNonScalarRepeats(self, fixed_size): - ''' + "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name), + "axis": axis, "keepdims": keepdims} + for rec in JAX_ARGMINMAX_RECORDS + for shape, dtype in _shape_and_dtypes(rec.shapes, rec.dtypes) + for axis in range(-len(shape), len(shape)) + for keepdims in [True, False])) + def testArgMinMax(self, np_op, bm_op, rng_factory, shape, dtype, axis, keepdims): + rng = rng_factory(self.rng()) + if dtype == np.complex128 and jtu.device_under_test() == "gpu": + raise unittest.SkipTest("complex128 reductions not supported on GPU") + if "nan" in np_op.__name__ and dtype == jnp.bfloat16: + raise unittest.SkipTest("NumPy doesn't correctly handle bfloat16 arrays") + if numpy_version < (1, 22) and keepdims: + raise unittest.SkipTest("NumPy < 1.22 does not support keepdims argument to argmin/argmax") + kwds = {"keepdims": True} if keepdims else {} + + np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=axis, **kwds)) + bm_fun = partial(bm_op, axis=axis, **kwds) + + args_maker = lambda: [rng(shape, dtype)] + try: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + except ValueError as e: + if str(e) == "All-NaN 
slice encountered": + self.skipTest("JAX doesn't support checking for all-NaN slices") + else: + raise + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": rec.test_name.capitalize(), "name": rec.name, + "np_op": getattr(np, rec.name), "bm_op": getattr(bm, rec.name)} + for rec in JAX_ARGMINMAX_RECORDS)) + def testArgMinMaxEmpty(self, name, np_op, bm_op): + name = name[3:] if name.startswith("nan") else name + msg = "attempt to get {} of an empty sequence".format(name) + with self.assertRaises(ValueError, msg=msg): + bm_op(np.array([])) + with self.assertRaises(ValueError, msg=msg): + bm_op(np.zeros((2, 0)), axis=1) + np_fun = jtu.with_jax_dtype_defaults(partial(np_op, axis=0)) + bm_fun = partial(bm_op, axis=0) + args_maker = lambda: [np.zeros((2, 0))] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + axes), + "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, + "axes": axes} + for lhs_shape, rhs_shape, axes in [ + [(2,), (2,), (-1, -1, -1, None)], # scalar output + [(2, 4), (2, 4), (-1, -1, -1, 0)], # 2D vectors + [(3, 4), (3, 4), (-1, -1, -1, 0)], # 3D vectors + [(3, 4), (3, 6, 5, 4), (-1, -1, -1, 0)], # broadcasting + [(4, 3), (3, 6, 5, 4), (1, 0, -1, None)], # different axes + [(6, 1, 3), (5, 3), (-1, -1, -1, None)], # more broadcasting + [(6, 1, 2), (5, 3), (-1, -1, -1, None)], # mixed 2D and 3D vectors + [(10, 5, 2, 8), (1, 5, 1, 3), (-2, -1, -3, None)], # axes/broadcasting + [(4, 5, 2), (4, 5, 2), (-1, -1, 0, None)], # axisc should do nothing + [(4, 5, 2), (4, 5, 2), (-1, -1, -1, None)] # same as before + ] + for lhs_dtype, rhs_dtype in 
itertools.combinations_with_replacement(number_dtypes, 2))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testCross(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + axisa, axisb, axisc, axis = axes + bm_fun = lambda a, b: bm.cross(a, b, axisa, axisb, axisc, axis) + + def np_fun(a, b): + a = a.astype(np.float32) if lhs_dtype == jnp.bfloat16 else a + b = b.astype(np.float32) if rhs_dtype == jnp.bfloat16 else b + out = np.cross(a, b, axisa, axisb, axisc, axis) + return out.astype(jnp.promote_types(lhs_dtype, rhs_dtype)) + + tol_spec = {dtypes.bfloat16: 3e-1, np.float16: 0.15} + tol = max(jtu.tolerance(lhs_dtype, tol_spec), + jtu.tolerance(rhs_dtype, tol_spec)) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, + rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), + "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype} + for name, lhs_shape, rhs_shape in [ + ("matrix-scalar", (3, 3), ()), + ("scalar-matrix", (), (3, 3)), + ("matrix-vector", (4, 5), (5,)), + ("vector-matrix", (6,), (6, 4)), + ("matrix-matrix", (3, 4), (4, 5)), + ("tensor-vector", (4, 3, 2), (2,)), + ("vector-tensor", (2,), (3, 2, 4)), + ("tensor-matrix", (4, 3, 2), (2, 5)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-tensor", (2, 3, 4), (5, 4, 1))] + for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2))) + def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = 
{np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-14, + np.complex128: 1e-14} + if jtu.device_under_test() == "tpu": + tol[np.float16] = tol[np.float32] = tol[np.complex64] = 2e-1 + + def np_dot(x, y): + x = x.astype(np.float32) if lhs_dtype == jnp.bfloat16 else x + y = y.astype(np.float32) if rhs_dtype == jnp.bfloat16 else y + return np.dot(x, y).astype(jnp.promote_types(lhs_dtype, rhs_dtype)) + + self._CheckAgainstNumpy(np_dot, bm_func(bm.dot), args_maker, + tol=tol) + self._CompileAndCheck(bm_func(bm.dot), args_maker, atol=tol, + rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_{}".format( + name, + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), + "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype} + for name, lhs_shape, rhs_shape in [ + ("vector-vector", (3,), (3,)), + ("matrix-vector", (3, 3), (3,)), + ("vector-matrix", (3,), (3, 3)), + ("matrix-matrix", (3, 3), (3, 3)), + ("vector-tensor", (3,), (5, 3, 2)), + ("tensor-vector", (5, 3, 2), (2,)), + ("matrix-tensor", (5, 2), (3, 2, 4)), + ("tensor-matrix", (5, 2, 3), (3, 2)), + ("tensor-tensor", (5, 3, 4), (5, 4, 1)), + ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))] + for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2))) + def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + + def np_fun(x, y): + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.matmul(x, y).astype(dtype) + + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12, + np.complex128: 1e-12} + if jtu.device_under_test() == "tpu": + tol[np.float16] = tol[np.float32] = tol[np.complex64] = 4e-2 + self._CheckAgainstNumpy(np_fun, bm_func(bm.matmul), args_maker, tol=tol) + self._CompileAndCheck(bm_func(bm.matmul), 
args_maker, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype), + axes), + "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, + "axes": axes} + for lhs_shape, rhs_shape, axes in [ + [(3,), (), 0], + [(2, 3, 4), (5, 6, 7), 0], # from issue #740 + [(2, 3, 4), (3, 4, 5, 6), 2], + [(2, 3, 4), (5, 4, 3, 6), [1, 2]], + [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]], + [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]], + ] + for lhs_dtype, rhs_dtype in itertools.combinations_with_replacement(number_dtypes, 2))) + def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + bm_fun = lambda a, b: bm.tensordot(a, b, axes) + + def np_fun(a, b): + a = a if lhs_dtype != jnp.bfloat16 else a.astype(np.float32) + b = b if rhs_dtype != jnp.bfloat16 else b.astype(np.float32) + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.tensordot(a, b, axes).astype(dtype) + + tol = {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-12, + np.complex64: 1e-3, np.complex128: 1e-12} + if jtu.device_under_test() == "tpu": + tol[np.float16] = tol[np.float32] = tol[np.complex64] = 2e-1 + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testTensordotErrors(self): + a = self.rng().random((3, 2, 2)) + b = self.rng().random((2,)) + self.assertRaisesRegex( + TypeError, "Number of tensordot axes.*exceeds input ranks.*", + lambda: bm.tensordot(a, b, axes=2)) + + self.assertRaisesRegex( + TypeError, "tensordot requires axes lists to have equal length.*", + lambda: bm.tensordot(a, b, axes=([0], [0, 1]))) + + self.assertRaisesRegex( + TypeError, "tensordot requires both axes 
lists to be either ints, tuples or lists.*", + lambda: bm.tensordot(a, b, axes=('bad', 'axes'))) + + self.assertRaisesRegex( + TypeError, "tensordot axes argument must be an int, a pair of ints, or a pair of lists.*", + lambda: bm.tensordot(a, b, axes='badaxes')) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_invert={}".format( + jtu.format_shape_dtype_string(element_shape, dtype), + jtu.format_shape_dtype_string(test_shape, dtype), invert), + "element_shape": element_shape, "test_shape": test_shape, + "dtype": dtype, "invert": invert} + for element_shape in all_shapes + for test_shape in all_shapes + for dtype in default_dtypes + for invert in [True, False])) + def testIsin(self, element_shape, test_shape, dtype, invert): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] + bm_fun = lambda e, t: bm.isin(e, t, invert=invert) + np_fun = lambda e, t: np.isin(e, t, invert=invert) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_invert={}".format( + jtu.format_shape_dtype_string(element_shape, dtype), + jtu.format_shape_dtype_string(test_shape, dtype), invert), + "element_shape": element_shape, "test_shape": test_shape, + "dtype": dtype, "invert": invert} + for element_shape in all_shapes + for test_shape in all_shapes + for dtype in default_dtypes + for invert in [True, False])) + def testIn1d(self, element_shape, test_shape, dtype, invert): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] + bm_fun = lambda e, t: bm.in1d(e, t, invert=invert) + np_fun = lambda e, t: np.in1d(e, t, invert=invert) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + 
@parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}".format( + jtu.format_shape_dtype_string(shape1, dtype1), + jtu.format_shape_dtype_string(shape2, dtype2)), + "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2} + for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] + for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] + for shape1 in all_shapes + for shape2 in all_shapes)) + def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + self._CheckAgainstNumpy(np.setdiff1d, bm_func(bm.setdiff1d), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_size={}_fill_value={}".format( + jtu.format_shape_dtype_string(shape1, dtype1), + jtu.format_shape_dtype_string(shape2, dtype2), + size, fill_value), + "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2, + "size": size, "fill_value": fill_value} + for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] + for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] + for shape1 in all_shapes + for shape2 in all_shapes + for size in [1, 5, 10] + for fill_value in [None, -1])) + def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + + def np_fun(arg1, arg2): + result = np.setdiff1d(arg1, arg2) + if size <= len(result): + return result[:size] + else: + return np.pad(result, (0, size - len(result)), constant_values=fill_value or 0) + + def bm_fun(arg1, arg2): + return bm.setdiff1d(arg1, arg2, size=size, fill_value=fill_value) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}".format( + 
jtu.format_shape_dtype_string(shape1, dtype1), + jtu.format_shape_dtype_string(shape2, dtype2)), + "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2} + for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] + for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] + for shape1 in nonempty_nonscalar_array_shapes + for shape2 in nonempty_nonscalar_array_shapes)) + def testUnion1d(self, shape1, shape2, dtype1, dtype2): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + return np.union1d(arg1, arg2).astype(dtype) + + self._CheckAgainstNumpy(np_fun, bm_func(bm.union1d), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_size={}_fill_value={}".format( + jtu.format_shape_dtype_string(shape1, dtype1), + jtu.format_shape_dtype_string(shape2, dtype2), size, fill_value), + "shape1": shape1, "shape2": shape2, "dtype1": dtype1, "dtype2": dtype2, + "size": size, "fill_value": fill_value} + for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] + for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] + for shape1 in nonempty_nonscalar_array_shapes + for shape2 in nonempty_nonscalar_array_shapes + for size in [1, 5, 10] + for fill_value in [None, -1])) + def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + + def np_fun(arg1, arg2): + dtype = jnp.promote_types(arg1.dtype, arg2.dtype) + result = np.union1d(arg1, arg2).astype(dtype) + fv = result.min() if fill_value is None else fill_value + if size <= len(result): + return result[:size] + else: + return np.concatenate([result, np.full(size - len(result), fv, result.dtype)]) + + def bm_fun(arg1, arg2): + return bm.union1d(arg1, arg2, size=size, fill_value=fill_value) + + 
self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_assume_unique={}".format( + jtu.format_shape_dtype_string(shape1, dtype1), + jtu.format_shape_dtype_string(shape2, dtype2), + assume_unique), + "shape1": shape1, "dtype1": dtype1, "shape2": shape2, "dtype2": dtype2, + "assume_unique": assume_unique} + for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] + for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] + for shape1 in all_shapes + for shape2 in all_shapes + for assume_unique in [False, True])) + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + bm_fun = lambda ar1, ar2: bm.setxor1d(ar1, ar2, assume_unique=assume_unique) + + def np_fun(ar1, ar2): + if assume_unique: + # pre-flatten the arrays to match with jax implementation + ar1 = np.ravel(ar1) + ar2 = np.ravel(ar2) + return np.setxor1d(ar1, ar2, assume_unique) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_assume_unique={}_return_indices={}".format( + jtu.format_shape_dtype_string(shape1, dtype1), + jtu.format_shape_dtype_string(shape2, dtype2), + assume_unique, + return_indices), + "shape1": shape1, "dtype1": dtype1, "shape2": shape2, "dtype2": dtype2, + "assume_unique": assume_unique, "return_indices": return_indices} + for dtype1 in [s for s in default_dtypes if s != jnp.bfloat16] + for dtype2 in [s for s in default_dtypes if s != jnp.bfloat16] + for shape1 in all_shapes + for shape2 in all_shapes + for assume_unique in [False, True] + for return_indices in [False, True])) + def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, return_indices): + rng = 
jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + bm_fun = lambda ar1, ar2: bm.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}".format( + jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), + jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), + "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, + "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype} + # TODO(phawkins): support integer dtypes too. + for lhs_shape, lhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + for rhs_shape, rhs_dtype in _shape_and_dtypes(all_shapes, inexact_dtypes) + if len(jtu._dims_of_shape(lhs_shape)) == 0 + or len(jtu._dims_of_shape(rhs_shape)) == 0 + or lhs_shape[-1] == rhs_shape[-1])) + def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] + + def np_fun(lhs, rhs): + lhs = lhs if lhs_dtype != jnp.bfloat16 else lhs.astype(np.float32) + rhs = rhs if rhs_dtype != jnp.bfloat16 else rhs.astype(np.float32) + dtype = jnp.promote_types(lhs_dtype, rhs_dtype) + return np.inner(lhs, rhs).astype(dtype) + + bm_fun = lambda lhs, rhs: bm.inner(lhs, rhs) + tol_spec = {np.float16: 1e-2, np.float32: 1e-5, np.float64: 1e-13, + np.complex64: 1e-5} + if jtu.device_under_test() == "tpu": + tol_spec[np.float32] = tol_spec[np.complex64] = 2e-1 + tol = max(jtu.tolerance(lhs_dtype, tol_spec), + jtu.tolerance(rhs_dtype, tol_spec)) + # TODO(phawkins): there are float32/float64 disagreements for some inputs. 
+ self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=False, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_deg={}_rcond={}_full={}_w={}_cov={}".format( + jtu.format_shape_dtype_string(shape, dtype), + deg, + rcond, + full, + w, + cov), + "shape": shape, "dtype": dtype, "deg": deg, + "rcond": rcond, "full": full, "w": w, "cov": cov} + for dtype in [dt for dt in float_dtypes if dt not in [jnp.float16, jnp.bfloat16]] + for shape in [shape for shape in one_dim_array_shapes if shape != (1,)] + for deg in [1, 2, 3] + for rcond in [None, -1, 10e-3, 10e-5, 10e-10] + for full in [False, True] + for w in [False, True] + for cov in [False, True, "unscaled"])) + def testPolyfit(self, shape, dtype, deg, rcond, full, w, cov): + rng = jtu.rand_default(self.rng()) + tol_spec = {np.float32: 1e-3, np.float64: 1e-13, np.complex64: 1e-5} + if jtu.device_under_test() == "tpu": + tol_spec[np.float32] = tol_spec[np.complex64] = 2e-1 + tol = jtu.tolerance(dtype, tol_spec) + _w = lambda a: abs(a) if w else None + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + bm_fun = lambda x, y, a: bm.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov) + np_fun = jtu.ignore_warning( + message="Polyfit may be poorly conditioned*")( + lambda x, y, a: np.polyfit(x, y, deg=deg, rcond=rcond, full=full, w=_w(a), cov=cov)) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=False, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_amin={}_amax={}".format( + jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), + "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max} + for shape in all_shapes for dtype in number_dtypes + for a_min, a_max in 
[(-1, None), (None, 1), (-0.9, 1), + (-np.ones(1), None), + (None, np.ones(1)), + (np.full(1, -0.9), np.ones(1))])) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testClipStaticBounds(self, shape, dtype, a_min, a_max): + if np.issubdtype(dtype, np.unsignedinteger): + a_min = None if a_min is None else abs(a_min) + a_max = None if a_max is None else abs(a_max) + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max) + bm_fun = lambda x: bm.clip(x, a_min=a_min, a_max=a_max) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testClipError(self): + with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"): + bm.clip(jnp.zeros((3,))) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_decimals={}".format( + jtu.format_shape_dtype_string(shape, dtype), decimals), + "shape": shape, "dtype": dtype, "decimals": decimals} + for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes) + for decimals in [0, 1, -2])) + def testRoundStaticDecimals(self, shape, dtype, decimals): + rng = jtu.rand_default(self.rng()) + if jnp.issubdtype(dtype, np.integer) and decimals < 0: + self.skipTest("Integer rounding with decimals < 0 not implemented") + np_fun = lambda x: np.round(x, decimals=decimals) + bm_fun = lambda x: bm.round(x, decimals=decimals) + args_maker = lambda: [rng(shape, dtype)] + tol = {jnp.bfloat16: 5e-2, np.float16: 1e-2} + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=check_dtypes, tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=check_dtypes, + atol=tol, rtol=tol) + + def testOperatorRound(self): + self.assertAllClose(round(np.float32(7.532), 1), + round(bm.float32(7.5), 1)) + 
self.assertAllClose(round(np.float32(1.234), 2), + round(bm.float32(1.234), 2)) + self.assertAllClose(round(np.float32(1.234)), + round(bm.float32(1.234)), check_dtypes=False) + self.assertAllClose(round(np.float32(7.532), 1), + round(bm.array(7.5, bm.float32), 1)) + self.assertAllClose(round(np.float32(1.234), 2), + round(bm.array(1.234, bm.float32), 2)) + self.assertAllClose(round(np.float32(1.234)), + round(bm.array(1.234, bm.float32)), + check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_mode={}_padwidth={}_constantvalues={}".format( + jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, + constant_values), + "shape": shape, "dtype": dtype, "mode": mode, + "pad_width": pad_width, "constant_values": constant_values} + for mode, shapes in [ + ('constant', all_shapes), + ('wrap', nonempty_shapes), + ('edge', nonempty_shapes), + ] + for shape, dtype in _shape_and_dtypes(shapes, all_dtypes) + for constant_values in [ + # None is used for modes other than 'constant' + None, + # constant + 0, 1, + # (constant,) + (0,), (2.718,), + # ((before_const, after_const),) + ((0, 2),), ((-1, 3.14),), + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i / 2, -3.14 * i) for i in range(len(shape))), + ] + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + if (pad_width != () and constant_values != () and + ((mode == 'constant' and constant_values is not None) or + (mode != 'constant' and constant_values is None))))) + def testPad(self, shape, dtype, mode, pad_width, constant_values): + if np.issubdtype(dtype, np.unsignedinteger): + constant_values = tree_util.tree_map(abs, constant_values) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: 
[rng(shape, dtype)] + if constant_values is None: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode) + bm_fun = partial(bm.pad, pad_width=pad_width, mode=mode) + else: + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) + bm_fun = partial(bm.pad, pad_width=pad_width, mode=mode, + constant_values=constant_values) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_mode={}_pad_width={}_stat_length={}".format( + jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, stat_length), + "shape": shape, "dtype": dtype, "mode": mode, "pad_width": pad_width, + "stat_length": stat_length} + for mode in ['maximum', 'minimum', 'mean', 'median'] + for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + for stat_length in [ + None, + # ((before_1, after_1), ..., (before_N, after_N)) + tuple(((i % 3 + 1), ((i + 1) % 3) + 1) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 2),), + # (before, after) (not in the docstring but works in numpy) + (1, 1), (3, 4), + # (pad,) + (1,), (2,), + # pad + 1, 2 + ] + if (pad_width != () and stat_length != () and + not (dtype in bool_dtypes and mode == 'mean')))) + def testPadStatValues(self, shape, dtype, mode, pad_width, stat_length): + if mode == 'median' and np.issubdtype(dtype, np.complexfloating): + self.skipTest("median statistic is not supported for dtype=complex.") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + 
np_fun = partial(np.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) + bm_fun = partial(bm.pad, pad_width=pad_width, mode=mode, stat_length=stat_length) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_mode={}_pad_width={}_reflect_type={}".format( + jtu.format_shape_dtype_string(shape, dtype), mode, pad_width, reflect_type), + "shape": shape, "dtype": dtype, "mode": mode, "pad_width": pad_width, + "reflect_type": reflect_type} + for mode in ['symmetric', 'reflect'] + for shape, dtype in _shape_and_dtypes(nonempty_shapes, all_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 3),), + # (before, after) (not in the docstring but works in numpy) + (2, 1), (1, 2), + # (pad,) + (1,), (2,), (3,), + # pad + 0, 5, 7, 10 + ] + for reflect_type in ['even', 'odd'] + if (pad_width != () and + # following types lack precision when calculating odd values + (reflect_type != 'odd' or dtype not in [np.bool_, np.float16, jnp.bfloat16])))) + def testPadSymmetricAndReflect(self, shape, dtype, mode, pad_width, reflect_type): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) + bm_fun = partial(jnp.pad, pad_width=pad_width, mode=mode, reflect_type=reflect_type) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE, + tol={np.float32: 1e-3, np.complex64: 1e-3}) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_mode={}_pad_width={}_end_values={}".format( + 
jtu.format_shape_dtype_string(shape, dtype), "linear_ramp", pad_width, end_values), + "shape": shape, "dtype": dtype, "pad_width": pad_width, + "end_values": end_values} + for shape, dtype in _shape_and_dtypes(nonempty_shapes, default_dtypes + complex_dtypes) + for pad_width in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2, 0),), + # (before, after) (not in the docstring but works in numpy) + (2, 0), (0, 0), + # (pad,) + (1,), (2,), + # pad + 0, 1, + ] + for end_values in [ + # ((before_1, after_1), ..., (before_N, after_N)) + tuple((i % 3, (i + 1) % 3) for i in range(len(shape))), + # ((before, after),) + ((1, 2),), ((2.0, 3.14),), + # (before, after) (not in the docstring but works in numpy) + (0, 0), (-8.0, 2.0), + # (end_values,) + (1,), (2,), + # end_values + 0, 1, 100, 10.0, 3.5, 4.2, -5, -3 + ] + if (pad_width != () and end_values != () and + # following types lack precision + dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16]))) + def testPadLinearRamp(self, shape, dtype, pad_width, end_values): + if numpy_version < (1, 20) and np.issubdtype(dtype, np.integer): + raise unittest.SkipTest("NumPy 1.20 changed the semantics of np.linspace") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + np_fun = partial(np.pad, pad_width=pad_width, mode="linear_ramp", + end_values=end_values) + bm_fun = partial(jnp.pad, pad_width=pad_width, mode="linear_ramp", + end_values=end_values) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testPadEmpty(self): + arr = np.arange(6).reshape(2, 3) + + pad_width = ((2, 3), (3, 1)) + np_res = np.pad(arr, pad_width=pad_width, mode="empty") + bm_res = bm.pad(arr, pad_width=pad_width, mode="empty").value + + np.testing.assert_equal(np_res.shape, bm_res.shape) + 
np.testing.assert_equal(arr, np_res[2:-3, 3:-1]) + np.testing.assert_equal(arr, bm_res[2:-3, 3:-1]) + np.testing.assert_equal(np_res[2:-3, 3:-1], bm_res[2:-3, 3:-1]) + + def testPadKwargs(self): + modes = { + 'constant': {'constant_values': 0}, + 'edge': {}, + 'linear_ramp': {'end_values': 0}, + 'maximum': {'stat_length': None}, + 'mean': {'stat_length': None}, + 'median': {'stat_length': None}, + 'minimum': {'stat_length': None}, + 'reflect': {'reflect_type': 'even'}, + 'symmetric': {'reflect_type': 'even'}, + 'wrap': {}, + 'empty': {} + } + arr = bm.array([1, 2, 3]) + pad_width = 1 + + for mode in modes.keys(): + allowed = modes[mode] + not_allowed = {} + for kwargs in modes.values(): + if kwargs != allowed: + not_allowed.update(kwargs) + + # Test if allowed keyword arguments pass + bm.pad(arr, pad_width, mode, **allowed) + # Test if prohibited keyword arguments of other modes raise an error + match = "unsupported keyword arguments for mode '{}'".format(mode) + for key, value in not_allowed.items(): + with self.assertRaisesRegex(ValueError, match): + bm.pad(arr, pad_width, mode, **{key: value}) + + # Test if unsupported mode raise error. 
+ unsupported_modes = [1, None, "foo"] + for mode in unsupported_modes: + match = "Unimplemented padding mode '{}' for np.pad.".format(mode) + with self.assertRaisesRegex(NotImplementedError, match): + bm.pad(arr, pad_width, mode) + + def testPadFunction(self): + def np_pad_with(vector, pad_width, iaxis, kwargs): + pad_value = kwargs.get('padder', 10) + vector[:pad_width[0]] = pad_value + vector[-pad_width[1]:] = pad_value + + def bm_pad_with(vector, pad_width, iaxis, kwargs): + pad_value = kwargs.get('padder', 10) + vector = vector.at[:pad_width[0]].set(pad_value) + vector = vector.at[-pad_width[1]:].set(pad_value) + return vector + + arr = np.arange(6).reshape(2, 3) + np_res = np.pad(arr, 2, np_pad_with) + bm_res = bm.pad(arr, 2, bm_pad_with) + np.testing.assert_equal(np_res, bm_res) + + arr = np.arange(24).reshape(2, 3, 4) + np_res = np.pad(arr, 1, np_pad_with, padder=100) + bm_res = bm.pad(arr, 1, bm_pad_with, padder=100) + np.testing.assert_equal(np_res, bm_res) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(arr.shape, arr.dtype)] + bm_fun = partial(bm.pad, pad_width=1, mode=bm_pad_with) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testPadWithNumpyPadWidth(self): + a = bm.array([1, 2, 3, 4, 5]) + f = jax.jit( + partial( + bm.pad, + pad_width=np.asarray((2, 3)), + mode="constant", + constant_values=(4, 6))) + + np.testing.assert_array_equal( + f(a), + np.pad( + a, + pad_width=np.asarray((2, 3)), + mode="constant", + constant_values=(4, 6))) + + def testPadWeakType(self): + x = bm.array(1.0)[None] + for mode in ['constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', + 'minimum', 'reflect', 'symmetric', 'wrap', 'empty']: + y = bm.pad(x, 0, mode=mode).value + self.assertTrue(dtypes.is_weakly_typed(y)) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape=[{}]_reps={}".format( + jtu.format_shape_dtype_string(shape, dtype), reps), + "shape": shape, "dtype": dtype, "reps": reps} + for reps in 
[(), (2,), (3, 4), (2, 3, 4), (1, 0, 2)] + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + )) + def testTile(self, shape, dtype, reps): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.tile(arg, reps) + bm_fun = lambda arg: bm.tile(arg, reps) + + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=shape is not jtu.PYTHON_SCALAR_SHAPE) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in all_shapes + for dtype in all_dtypes)) + def testExtract(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, jnp.float32), rng(shape, dtype)] + self._CheckAgainstNumpy(np.extract, bm_func(bm.extract), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_ncond={}_nfunc={}".format( + jtu.format_shape_dtype_string(shape, dtype), ncond, nfunc), + "shape": shape, "dtype": dtype, "ncond": ncond, "nfunc": nfunc} + for ncond in [1, 2, 3] + for nfunc in [ncond, ncond + 1] + for shape in all_shapes + for dtype in all_dtypes)) + def testPiecewise(self, shape, dtype, ncond, nfunc): + rng = jtu.rand_default(self.rng()) + rng_bool = jtu.rand_int(self.rng(), 0, 2) + funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc] + args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)]) + np_fun = partial(np.piecewise, funclist=funclist) + bm_fun = partial(bm.piecewise, funclist=funclist) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) + # This is a higher-order function, so the cache miss check will fail. 
+ self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True, check_cache_misses=False) + + def testPiecewiseRecompile(self): + def g(x): + g.num_traces += 1 + return x + + g.num_traces = 0 + x = bm.arange(10.0) + for i in range(5): + bm.piecewise(x, [x < 0], [g, 0.]) + self.assertEqual(g.num_traces, 1) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "{}_perm={}_{}".format( + jtu.format_shape_dtype_string(shape, dtype), perm, arg_type), + "dtype": dtype, "shape": shape, "perm": perm, "arg_type": arg_type} + for dtype in default_dtypes + for shape in array_shapes + for arg_type in ["splat", "value"] + for perm in [None, tuple(np.random.RandomState(0).permutation(np.zeros(shape).ndim))])) + def testTransposeTuple(self, shape, dtype, perm, arg_type): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + if arg_type == "value": + np_fun = lambda x: x.transpose(perm) + bm_fun = lambda x: bm.array(x).transpose(perm) + else: + np_fun = lambda x: x.transpose(*(perm or ())) + bm_fun = lambda x: bm.array(x).transpose(*(perm or ())) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "{}_trim={}".format( + jtu.format_shape_dtype_string(a_shape, dtype), trim), + "dtype": dtype, "a_shape": a_shape, "trim": trim} + for dtype in default_dtypes + for a_shape in one_dim_array_shapes + for trim in ["f", "b", "fb"])) + def testTrimZeros(self, a_shape, dtype, trim): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(a_shape, dtype)] + np_fun = lambda arg1: np.trim_zeros(arg1, trim) + bm_fun = lambda arg1: bm.trim_zeros(arg1, trim) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_rank{}".format( + 
jtu.format_shape_dtype_string(a_shape, dtype), rank), + "dtype": dtype, "a_shape": a_shape, "rank": rank} + for rank in (1, 2) + for dtype in default_dtypes + for a_shape in one_dim_array_shapes)) + def testPoly(self, a_shape, dtype, rank): + if dtype in (np.float16, jnp.bfloat16, np.int16): + self.skipTest(f"{dtype} gets promoted to {np.float16}, which is not supported.") + elif rank == 2 and jtu.device_under_test() in ("tpu", "gpu"): + self.skipTest("Nonsymmetric eigendecomposition is only implemented on the CPU backend.") + rng = jtu.rand_default(self.rng()) + tol = {np.int8: 1e-3, np.int32: 1e-3, np.float32: 1e-3, np.float64: 1e-6} + if jtu.device_under_test() == "tpu": + tol[np.int32] = tol[np.float32] = 1e-1 + tol = jtu.tolerance(dtype, tol) + args_maker = lambda: [rng(a_shape * rank, dtype)] + self._CheckAgainstNumpy(np.poly, bm_func(bm.poly), args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm.poly), args_maker, check_dtypes=True, rtol=tol, atol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "a_shape={} , b_shape={}".format( + jtu.format_shape_dtype_string(a_shape, dtype), + jtu.format_shape_dtype_string(b_shape, dtype)), + "dtype": dtype, "a_shape": a_shape, "b_shape": b_shape} + for dtype in default_dtypes + for a_shape in one_dim_array_shapes + for b_shape in one_dim_array_shapes)) + def testPolyAdd(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polyadd(arg1, arg2) + bm_fun = lambda arg1, arg2: bm.polyadd(arg1, arg2) + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "a_shape={} , b_shape={}".format( + jtu.format_shape_dtype_string(a_shape, dtype), + jtu.format_shape_dtype_string(b_shape, dtype)), + 
"dtype": dtype, "a_shape": a_shape, "b_shape": b_shape} + for dtype in default_dtypes + for a_shape in one_dim_array_shapes + for b_shape in one_dim_array_shapes)) + def testPolySub(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polysub(arg1, arg2) + bm_fun = lambda arg1, arg2: bm.polysub(arg1, arg2) + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_order={}_k={}".format( + jtu.format_shape_dtype_string(a_shape, dtype), + order, k), + "dtype": dtype, "a_shape": a_shape, "order": order, "k": k} + for dtype in default_dtypes + for a_shape in one_dim_array_shapes + for order in range(5) + for k in [np.arange(order, dtype=dtype), np.ones(1, dtype), None])) + def testPolyInt(self, a_shape, order, k, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1: np.polyint(arg1, m=order, k=k) + bm_fun = lambda arg1: bm.polyint(arg1, m=order, k=k) + args_maker = lambda: [rng(a_shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_order={}".format( + jtu.format_shape_dtype_string(a_shape, dtype), + order), + "dtype": dtype, "a_shape": a_shape, "order": order} + for dtype in default_dtypes + for a_shape in one_dim_array_shapes + for order in range(5))) + def testPolyDer(self, a_shape, order, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1: np.polyder(arg1, m=order) + bm_fun = lambda arg1: bm.polyder(arg1, m=order) + args_maker = lambda: [rng(a_shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + 
self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_ptype={}".format(ptype), "ptype": ptype} + for ptype in ['int', 'np.int', 'bm.int'])) + def testIntegerPower(self, ptype): + p = {'int': 2, 'np.int': np.int32(2), 'bm.int': bm.int32(2)}[ptype] + jaxpr = jax.make_jaxpr(partial(bm_func(bm.power), x2=p))(1) + eqns = jaxpr.jaxpr.eqns + self.assertLen(eqns, 1) + self.assertEqual(eqns[0].primitive, lax.integer_pow_p) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_x={}_y={}".format(x, y), "x": x, "y": y} + for x in [-1, 0, 1] + for y in [0, 32, 64, 128])) + def testIntegerPowerOverflow(self, x, y): + # Regression test for https://github.com/google/jax/issues/5987 + args_maker = lambda: [x, y] + self._CheckAgainstNumpy(np.power, bm_func(bm.power), args_maker) + self._CompileAndCheck(bm_func(bm.power), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for shape in all_shapes + for dtype in all_dtypes + for axis in [None] + list(range(len(shape))))) + def testCompress(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + if shape in scalar_shapes or len(shape) == 0: + cond_shape = (0,) + elif axis is None: + cond_shape = (prod(shape),) + else: + cond_shape = (shape[axis],) + + args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] + + np_fun = partial(np.compress, axis=axis) + bm_fun = partial(bm.compress, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_condition=array[{}]_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), len(condition), axis), + "shape": shape, "dtype": dtype, "condition": condition, "axis": axis} + for 
shape in [(2, 3)] + for dtype in int_dtypes + # condition entries beyond axis size must be zero. + for condition in [[1], [1, 0, 0, 0, 0, 0, 0]] + for axis in [None, 0, 1])) + def testCompressMismatchedShapes(self, shape, dtype, condition, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [np.array(condition), rng(shape, dtype)] + np_fun = partial(np.compress, axis=axis) + bm_fun = partial(bm.compress, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for shape in array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(len(shape))))) + def testCompressMethod(self, shape, dtype, axis): + rng = jtu.rand_some_zero(self.rng()) + if shape in scalar_shapes or len(shape) == 0: + cond_shape = (0,) + elif axis is None: + cond_shape = (prod(shape),) + else: + cond_shape = (shape[axis],) + + args_maker = lambda: [rng(cond_shape, jnp.float32), rng(shape, dtype)] + + np_fun = lambda condition, x: np.compress(condition, x, axis=axis) + bm_fun = lambda condition, x: bm.compress(condition, x, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( + axis, ",".join(str(d) for d in base_shape), + ",".join(np.dtype(dtype).name for dtype in arg_dtypes)), + "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes} + for num_arrs in [3] + for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, num_arrs) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape) + 1, len(base_shape)))) + def testConcatenate(self, axis, base_shape, arg_dtypes): + rng = jtu.rand_default(self.rng()) + wrapped_axis = axis % len(base_shape) + shapes = 
[base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis + 1:] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + + def np_fun(*args): + args = [x if x.dtype != jnp.bfloat16 else x.astype(np.float32) + for x in args] + dtype = functools.reduce(jnp.promote_types, arg_dtypes) + return np.concatenate(args, axis=axis).astype(dtype) + + bm_fun = lambda *args: bm.concatenate(args, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for shape in [(4, 1), (4, 3), (4, 5, 6)] + for dtype in all_dtypes + for axis in [None] + list(range(1 - len(shape), len(shape) - 1)))) + def testConcatenateArray(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda x: np.concatenate(x, axis=axis) + bm_fun = lambda x: bm.concatenate(x, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testConcatenateAxisNone(self): + # https://github.com/google/jax/issues/3419 + a = bm.array([[1, 2], [3, 4]]) + b = bm.array([[5]]) + bm.concatenate((a, b), axis=None) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( + axis, ",".join(str(d) for d in base_shape), + ",".join(np.dtype(dtype).name for dtype in arg_dtypes)), + "axis": axis, "base_shape": base_shape, "arg_dtypes": arg_dtypes} + for arg_dtypes in itertools.combinations_with_replacement(default_dtypes, 2) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(-len(base_shape) + 1, len(base_shape)))) + def testAppend(self, axis, 
base_shape, arg_dtypes): + rng = jtu.rand_default(self.rng()) + wrapped_axis = axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis + 1:] + for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] + + def np_fun(arr, values): + arr = arr.astype(np.float32) if arr.dtype == jnp.bfloat16 else arr + values = (values.astype(np.float32) if values.dtype == jnp.bfloat16 + else values) + out = np.append(arr, values, axis=axis) + return out.astype(jnp.promote_types(*arg_dtypes)) + + bm_fun = lambda arr, values: bm.append(arr, values, axis=axis) + + def args_maker(): + return [rng(shape, dtype) for shape, dtype in zip(shapes, arg_dtypes)] + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_idx={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, idx), + "dtype": dtype, "shape": shape, "axis": axis, "idx": idx} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))) + for idx in (range(-prod(shape), prod(shape)) + if axis is None else + range(-shape[axis], shape[axis])))) + def testDeleteInteger(self, shape, dtype, idx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + bm_fun = lambda arg: bm.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_slc={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, slc), + "dtype": dtype, "shape": shape, "axis": axis, "slc": slc} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))) + for slc in 
[slice(None), slice(1, 3), slice(1, 5, 2)])) + def testDeleteSlice(self, shape, dtype, axis, slc): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, slc, axis=axis) + bm_fun = lambda arg: bm.delete(arg, slc, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_idx={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, + jtu.format_shape_dtype_string(idx_shape, int)), + "dtype": dtype, "shape": shape, "axis": axis, "idx_shape": idx_shape} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))) + for idx_shape in all_shapes)) + def testDeleteIndexArray(self, shape, dtype, axis, idx_shape): + rng = jtu.rand_default(self.rng()) + max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + # Previous to numpy 1.19, negative indices were ignored so we don't test this. 
+ low = 0 if numpy_version < (1, 19, 0) else -max_idx + idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + bm_fun = lambda arg: bm.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0") + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "dtype": dtype, "shape": shape, "axis": axis} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))))) + def testDeleteMaskArray(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, mask, axis=axis) + bm_fun = lambda arg: bm.delete(arg, mask, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "dtype": dtype, "shape": shape, "axis": axis} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))))) + def testInsertInteger(self, shape, dtype, axis): + x = jnp.empty(shape) + max_ind = x.size if axis is None else x.shape[axis] + rng = jtu.rand_default(self.rng()) + i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) + args_maker = lambda: [rng(shape, dtype), i_rng((), np.int32), rng((), dtype)] + np_fun = lambda *args: 
np.insert(*args, axis=axis) + bm_fun = lambda *args: bm.insert(*args, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "dtype": dtype, "shape": shape, "axis": axis} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))))) + def testInsertSlice(self, shape, dtype, axis): + x = jnp.empty(shape) + max_ind = x.size if axis is None else x.shape[axis] + rng = jtu.rand_default(self.rng()) + i_rng = jtu.rand_int(self.rng(), -max_ind, max_ind) + slc = slice(i_rng((), jnp.int32).item(), i_rng((), jnp.int32).item()) + args_maker = lambda: [rng(shape, dtype), rng((), dtype)] + np_fun = lambda x, val: np.insert(x, slc, val, axis=axis) + bm_fun = lambda x, val: bm.insert(x, slc, val, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.parameters([ + [[[1, 1], [2, 2], [3, 3]], 1, 5, None], + [[[1, 1], [2, 2], [3, 3]], 1, 5, 1], + [[[1, 1], [2, 2], [3, 3]], 1, [1, 2, 3], 1], + [[[1, 1], [2, 2], [3, 3]], [1], [[1], [2], [3]], 1], + [[1, 1, 2, 2, 3, 3], [2, 2], [5, 6], None], + [[1, 1, 2, 2, 3, 3], slice(2, 4), [5, 6], None], + [[1, 1, 2, 2, 3, 3], [2, 2], [7.13, False], None], + [[[0, 1, 2, 3], [4, 5, 6, 7]], (1, 3), 999, 1] + ]) + def testInsertExamples(self, arr, index, values, axis): + # Test examples from the np.insert docstring + args_maker = lambda: ( + np.asarray(arr), index if isinstance(index, slice) else np.array(index), + np.asarray(values), axis) + self._CheckAgainstNumpy(np.insert, bm_func(bm.insert), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_out_dims={}".format( + jtu.format_shape_dtype_string(shape, dtype), + 
axis, out_dims), + "shape": shape, "dtype": dtype, "axis": axis, "out_dims": out_dims} + for shape in nonempty_array_shapes + for dtype in default_dtypes + for axis in range(-len(shape), len(shape)) + for out_dims in [0, 1, 2])) + def testApplyAlongAxis(self, shape, dtype, axis, out_dims): + def func(x, out_dims): + if out_dims == 0: + return x.sum() + elif out_dims == 1: + return x * x[0] + elif out_dims == 2: + return x[:, None] + x[None, :] + else: + raise NotImplementedError(f"out_dims={out_dims}") + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arr: np.apply_along_axis(func, axis, arr, out_dims=out_dims) + bm_fun = lambda arr: bm.apply_along_axis(func, axis, arr, out_dims=out_dims) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_func={}_keepdims={}_axes={}".format( + jtu.format_shape_dtype_string(shape, dtype), + func, keepdims, axes), + "shape": shape, "dtype": dtype, "func": func, "keepdims": keepdims, "axes": axes} + for shape in nonempty_shapes + for func in ["sum"] + for keepdims in [True, False] + for axes in itertools.combinations(range(len(shape)), 2) + # Avoid low-precision types in sum() + for dtype in default_dtypes if dtype not in [np.float16, jnp.bfloat16])) + def testApplyOverAxes(self, shape, dtype, func, keepdims, axes): + f = lambda x, axis: getattr(x, func)(axis=axis, keepdims=keepdims) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype),) + np_fun = lambda a: np.apply_over_axes(f, a, axes) + bm_fun = lambda a: bm.apply_over_axes(f, a, axes) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape=[{}]_axis={}_repeats={}_fixed_size={}".format( + 
jtu.format_shape_dtype_string(shape, dtype), + axis, repeats, fixed_size), + "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats, + 'fixed_size': fixed_size} + for repeats in [0, 1, 2] + for shape, dtype in _shape_and_dtypes(all_shapes, default_dtypes) + for axis in [None] + list(range(-len(shape), max(1, len(shape)))) + for fixed_size in [True, False])) + def testRepeat(self, axis, shape, dtype, repeats, fixed_size): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.repeat(arg, repeats=repeats, axis=axis) + np_fun = _promote_like_jnp(np_fun) + if fixed_size: + total_repeat_length = np.repeat(np.zeros(shape), repeats, axis).shape[axis or 0] + bm_fun = lambda arg, rep: bm.repeat(arg, repeats=rep, axis=axis, + total_repeat_length=total_repeat_length) + jnp_args_maker = lambda: [rng(shape, dtype), repeats] + clo_fun = lambda arg: bm.repeat(arg, repeats=repeats, axis=axis, + total_repeat_length=total_repeat_length) + clo_fun_args_maker = lambda: [rng(shape, dtype)] + self._CompileAndCheck(bm_func(bm_fun), jnp_args_maker) + self._CheckAgainstNumpy(np_fun, bm_func(clo_fun), clo_fun_args_maker) + else: + # Now repeats is in a closure, so a constant. 
+ jnp_fun = lambda arg: jnp.repeat(arg, repeats=repeats, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + def testRepeatScalarFastPath(self): + a = jnp.array([1, 2, 3, 4]) + f = lambda a: bm.repeat(a, repeats=2) + jaxpr = jax.make_jaxpr(bm_func(f))(a) + self.assertLessEqual(len(jaxpr.jaxpr.eqns), 6) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_ind={}_inv={}_count={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, + return_index, return_inverse, return_counts), + "shape": shape, "dtype": dtype, "axis": axis, + "return_index": return_index, "return_inverse": return_inverse, + "return_counts": return_counts} + for dtype in number_dtypes + for shape in all_shapes + for axis in [None] + list(range(len(shape))) + for return_index in [False, True] + for return_inverse in [False, True] + for return_counts in [False, True])) + def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_counts): + if axis is not None and numpy_version < (1, 19) and np.empty(shape).size == 0: + self.skipTest("zero-sized axis in unique leads to error in older numpy.") + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + extra_args = (return_index, return_inverse, return_counts) + use_defaults = (False, *(True for arg in extra_args if arg)) if any(extra_args) else False + np_fun = jtu.with_jax_dtype_defaults(lambda x: np.unique(x, *extra_args, axis=axis), use_defaults) + bm_fun = lambda x: bm.unique(x, *extra_args, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_size={}_fill_value={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, size, fill_value), + "shape": shape, "dtype": dtype, "axis": axis, + "size": size, "fill_value": fill_value} + for dtype in 
number_dtypes + for size in [1, 5, 10] + for fill_value in [None, -1.0, "slice"] + for shape in nonempty_array_shapes + for axis in [None] + list(range(len(shape))))) + def testUniqueSize(self, shape, dtype, axis, size, fill_value): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + kwds = dict(axis=axis, return_index=True, return_inverse=True, return_counts=True) + + if fill_value == "slice": + if axis is None: + fill_value = rng((), dtype) + else: + fill_value = rng(shape[:axis] + shape[axis + 1:], dtype) + + @partial(jtu.with_jax_dtype_defaults, use_defaults=(False, True, True, True)) + def np_fun(x, fill_value=fill_value): + u, ind, inv, counts = np.unique(x, **kwds) + axis = kwds['axis'] + if axis is None: + x = x.ravel() + axis = 0 + + n_unique = u.shape[axis] + if size <= u.shape[axis]: + slc = (slice(None),) * axis + (slice(size),) + u, ind, counts = u[slc], ind[:size], counts[:size] + else: + extra = (0, size - n_unique) + pads = [(0, 0)] * u.ndim + pads[axis] = extra + u = np.pad(u, pads, constant_values=0) + slices = [slice(None)] * u.ndim + slices[axis] = slice(1) + if fill_value is None: + fill_value = u[tuple(slices)] + elif np.ndim(fill_value): + fill_value = lax.expand_dims(fill_value, (axis,)) + slices[axis] = slice(n_unique, None) + u[tuple(slices)] = fill_value + ind = np.pad(ind, extra, constant_values=ind[0]) + counts = np.pad(counts, extra, constant_values=0) + return u, ind, inv, counts + + bm_fun = lambda x: bm.unique(x, size=size, fill_value=fill_value, **kwds) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @unittest.skipIf(numpy_version < (1, 21), "Numpy < 1.21 does not properly handle NaN values in unique.") + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": f"_{dtype.__name__}", "dtype": dtype} + for dtype in inexact_dtypes)) + def testUniqueNans(self, dtype): + def args_maker(): + x = [-0.0, 0.0, 1.0, 1.0, 
np.nan, -np.nan] + if np.issubdtype(dtype, np.complexfloating): + x = [complex(i, j) for i, j in itertools.product(x, repeat=2)] + return [np.array(x, dtype=dtype)] + + kwds = dict(return_index=True, return_inverse=True, return_counts=True) + bm_fun = partial(bm.unique, **kwds) + + def np_fun(x): + dtype = x.dtype + # numpy unique fails for bfloat16 NaNs, so we cast to float64 + if x.dtype == jnp.bfloat16: + x = x.astype('float64') + u, *rest = np.unique(x, **kwds) + return (u.astype(dtype), *rest) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_fixed_size={}".format(fixed_size), + "fixed_size": fixed_size} + for fixed_size in [True, False])) + def testNonScalarRepeats(self, fixed_size): + ''' Following numpy test suite from `test_repeat` at https://github.com/numpy/numpy/blob/main/numpy/core/tests/test_multiarray.py ''' - tol = 1e-5 - - def test_single(m, args_maker, repeats, axis): - bm_ans = bm.repeat(m, repeats, axis).value - numpy_ans = np.repeat(m, repeats, axis) - - self.assertAllClose(bm_ans, numpy_ans, rtol=tol, atol=tol) - if fixed_size: - - # Calculate expected size of the repeated axis. - rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0] - bm_fun = lambda arg, rep: bm.repeat( - arg, repeats=rep, axis=axis, total_repeat_length=rep_length) - else: - bm_fun = lambda arg: bm.repeat(arg, repeats = repeats, axis=axis) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - m = jnp.array([1,2,3,4,5,6]) - if fixed_size: - args_maker = lambda: [m, repeats] - else: - args_maker = lambda: [m] + tol = 1e-5 + + def test_single(m, args_maker, repeats, axis): + bm_ans = bm.repeat(m, repeats, axis).value + numpy_ans = np.repeat(m, repeats, axis) + + self.assertAllClose(bm_ans, numpy_ans, rtol=tol, atol=tol) + if fixed_size: + + # Calculate expected size of the repeated axis. 
+ rep_length = np.repeat(np.zeros_like(m), repeats, axis).shape[axis or 0] + bm_fun = lambda arg, rep: bm.repeat( + arg, repeats=rep, axis=axis, total_repeat_length=rep_length) + else: + bm_fun = lambda arg: bm.repeat(arg, repeats=repeats, axis=axis) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + m = jnp.array([1, 2, 3, 4, 5, 6]) + if fixed_size: + args_maker = lambda: [m, repeats] + else: + args_maker = lambda: [m] - for repeats in [2, jnp.array([1,3,0,1,1,2]), jnp.array([1,3,2,1,1,2]), jnp.array([2])]: - test_single(m, args_maker, repeats, axis=None) - test_single(m, args_maker, repeats, axis=0) + for repeats in [2, jnp.array([1, 3, 0, 1, 1, 2]), jnp.array([1, 3, 2, 1, 1, 2]), jnp.array([2])]: + test_single(m, args_maker, repeats, axis=None) + test_single(m, args_maker, repeats, axis=0) - m_rect = m.reshape((2,3)) - if fixed_size: - args_maker = lambda: [m_rect, repeats] - else: - args_maker = lambda: [m_rect] + m_rect = m.reshape((2, 3)) + if fixed_size: + args_maker = lambda: [m_rect, repeats] + else: + args_maker = lambda: [m_rect] - for repeats in [2, jnp.array([2,1]), jnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=0) + for repeats in [2, jnp.array([2, 1]), jnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=0) - for repeats in [2, jnp.array([1,3,2]), jnp.array([2])]: - test_single(m_rect, args_maker, repeats, axis=1) + for repeats in [2, jnp.array([1, 3, 2]), jnp.array([2])]: + test_single(m_rect, args_maker, repeats, axis=1) - def testIssue2330(self): - ''' + def testIssue2330(self): + ''' Make sure return value of jnp.concatenate is a jax.ndarray and is side-effect save ''' - def attempt_sideeffect(x): - x = [x] - x = bm.concatenate(x).value - x -= 1. 
- return x - - np_input = np.ones((1)) - bm_input = bm.ones((1)).value - expected_np_input_after_call = np.ones((1)) - expected_bm_input_after_call = bm.ones((1)).value - - self.assertTrue(device_array.type_is_device_array(bm.concatenate([np_input]).value)) - - attempt_sideeffect(np_input) - attempt_sideeffect(bm_input) - - self.assertAllClose(np_input, expected_np_input_after_call) - self.assertAllClose(bm_input, expected_bm_input_after_call) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format( - op, - jtu.format_shape_dtype_string(xshape, dtype), - jtu.format_shape_dtype_string(yshape, dtype), - mode), - "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode, - "bm_op": getattr(bm, op), - "np_op": getattr(np, op)} - for mode in ['full', 'same', 'valid'] - for op in ['convolve', 'correlate'] - for dtype in number_dtypes - for xshape in one_dim_array_shapes - for yshape in one_dim_array_shapes)) - def testConvolutions(self, xshape, yshape, dtype, mode, bm_op, np_op): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] - precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None - np_fun = partial(np_op, mode=mode) - bm_fun = partial(bm_op, mode=mode, precision=precision) - tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, - np.complex128: 1e-14} - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( - op, jtu.format_shape_dtype_string(shape, dtype), axis, - out_dtype.__name__), - "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "bm_op": getattr(bm, op), "np_op": getattr(np, op)} - for op in ["cumsum", "cumprod"] - for dtype in all_dtypes - for out_dtype in 
default_dtypes - for shape in all_shapes - for axis in [None] + list(range(-len(shape), len(shape))))) - def testCumSumProd(self, axis, shape, dtype, out_dtype, np_op, bm_op): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) - bm_fun = lambda arg: bm_op(arg, axis=axis, dtype=out_dtype) - bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) - - args_maker = lambda: [rng(shape, dtype)] - - tol_thresholds = {dtypes.bfloat16: 4e-2} - tol = max(jtu.tolerance(dtype, tol_thresholds), - jtu.tolerance(out_dtype, tol_thresholds)) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( - op, jtu.format_shape_dtype_string(shape, dtype), axis, - out_dtype.__name__), - "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, - "bm_op": getattr(bm, op), "np_op": getattr(np, op)} - for op in ["nancumsum", "nancumprod"] - for dtype in all_dtypes - for out_dtype in default_dtypes - for shape in all_shapes - for axis in [None] + list(range(-len(shape), len(shape))))) - def testNanCumSumProd(self, axis, shape, dtype, out_dtype, np_op, bm_op): - rng = jtu.rand_some_nan(self.rng()) - np_fun = partial(np_op, axis=axis, dtype=out_dtype) - np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) - bm_fun = partial(bm_op, axis=axis, dtype=out_dtype) - bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) - - args_maker = lambda: [rng(shape, dtype)] - - tol_thresholds = {dtypes.bfloat16: 4e-2} - tol = max(jtu.tolerance(dtype, tol_thresholds), - jtu.tolerance(out_dtype, tol_thresholds)) - if dtype != jnp.bfloat16: - # numpy functions do not properly handle bfloat16 - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, 
check_dtypes=True, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_yshape={}_xshape={}_dx={}_axis={}".format( - jtu.format_shape_dtype_string(yshape, dtype), - jtu.format_shape_dtype_string(xshape, dtype) if xshape is not None else None, - dx, axis), - "yshape": yshape, "xshape": xshape, "dtype": dtype, "dx": dx, "axis": axis} + + def attempt_sideeffect(x): + x = [x] + x = bm.concatenate(x).value + x -= 1. + return x + + np_input = np.ones((1)) + bm_input = bm.ones((1)).value + expected_np_input_after_call = np.ones((1)) + expected_bm_input_after_call = bm.ones((1)).value + + self.assertTrue(device_array.type_is_device_array(bm.concatenate([np_input]).value)) + + attempt_sideeffect(np_input) + attempt_sideeffect(bm_input) + + self.assertAllClose(np_input, expected_np_input_after_call) + self.assertAllClose(bm_input, expected_bm_input_after_call) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format( + op, + jtu.format_shape_dtype_string(xshape, dtype), + jtu.format_shape_dtype_string(yshape, dtype), + mode), + "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode, + "bm_op": getattr(bm, op), + "np_op": getattr(np, op)} + for mode in ['full', 'same', 'valid'] + for op in ['convolve', 'correlate'] + for dtype in number_dtypes + for xshape in one_dim_array_shapes + for yshape in one_dim_array_shapes)) + def testConvolutions(self, xshape, yshape, dtype, mode, bm_op, np_op): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)] + precision = lax.Precision.HIGHEST if jtu.device_under_test() == "tpu" else None + np_fun = partial(np_op, mode=mode) + bm_fun = partial(bm_op, mode=mode, precision=precision) + tol = {np.float16: 2e-1, np.float32: 1e-2, np.float64: 1e-14, + np.complex128: 1e-14} + self._CheckAgainstNumpy(np_fun, 
bm_func(bm_fun), args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( + op, jtu.format_shape_dtype_string(shape, dtype), axis, + out_dtype.__name__), + "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, + "bm_op": getattr(bm, op), "np_op": getattr(np, op)} + for op in ["cumsum", "cumprod"] + for dtype in all_dtypes + for out_dtype in default_dtypes + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape))))) + def testCumSumProd(self, axis, shape, dtype, out_dtype, np_op, bm_op): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np_op(arg, axis=axis, dtype=out_dtype) + np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + bm_fun = lambda arg: bm_op(arg, axis=axis, dtype=out_dtype) + bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) + + args_maker = lambda: [rng(shape, dtype)] + + tol_thresholds = {dtypes.bfloat16: 4e-2} + tol = max(jtu.tolerance(dtype, tol_thresholds), + jtu.tolerance(out_dtype, tol_thresholds)) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "op={}_shape=[{}]_axis={}_out_dtype={}".format( + op, jtu.format_shape_dtype_string(shape, dtype), axis, + out_dtype.__name__), + "axis": axis, "shape": shape, "dtype": dtype, "out_dtype": out_dtype, + "bm_op": getattr(bm, op), "np_op": getattr(np, op)} + for op in ["nancumsum", "nancumprod"] + for dtype in all_dtypes + for out_dtype in default_dtypes + for shape in all_shapes + for axis in [None] + list(range(-len(shape), len(shape))))) + def testNanCumSumProd(self, axis, shape, dtype, out_dtype, np_op, bm_op): + rng = jtu.rand_some_nan(self.rng()) + np_fun = partial(np_op, axis=axis, dtype=out_dtype) + np_fun 
= jtu.ignore_warning(category=np.ComplexWarning)(np_fun) + bm_fun = partial(bm_op, axis=axis, dtype=out_dtype) + bm_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(bm_fun) + + args_maker = lambda: [rng(shape, dtype)] + + tol_thresholds = {dtypes.bfloat16: 4e-2} + tol = max(jtu.tolerance(dtype, tol_thresholds), + jtu.tolerance(out_dtype, tol_thresholds)) + if dtype != jnp.bfloat16: + # numpy functions do not properly handle bfloat16 + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_yshape={}_xshape={}_dx={}_axis={}".format( + jtu.format_shape_dtype_string(yshape, dtype), + jtu.format_shape_dtype_string(xshape, dtype) if xshape is not None else None, + dx, axis), + "yshape": yshape, "xshape": xshape, "dtype": dtype, "dx": dx, "axis": axis} for dtype in default_dtypes for yshape, xshape, dx, axis in [ - ((10,), None, 1.0, -1), - ((3, 10), None, 2.0, -1), - ((3, 10), None, 3.0, -0), - ((10, 3), (10,), 1.0, -2), - ((3, 10), (10,), 1.0, -1), - ((3, 10), (3, 10), 1.0, -1), - ((2, 3, 10), (3, 10), 1.0, -2), + ((10,), None, 1.0, -1), + ((3, 10), None, 2.0, -1), + ((3, 10), None, 3.0, -0), + ((10, 3), (10,), 1.0, -2), + ((3, 10), (10,), 1.0, -1), + ((3, 10), (3, 10), 1.0, -1), + ((2, 3, 10), (3, 10), 1.0, -2), ])) - @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testTrapz(self, yshape, xshape, dtype, dx, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None] - np_fun = partial(np.trapz, dx=dx, axis=axis) - bm_fun = partial(bm.trapz, dx=dx, axis=axis) - tol = jtu.tolerance(dtype, {np.float64: 1e-12, - dtypes.bfloat16: 4e-2}) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, tol=tol, - check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, rtol=tol, - check_dtypes=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_dtype={}_m={}_n={}_k={}".format( - np.dtype(dtype).name, m, n, k), - "m": m, "n": n, "k": k, "dtype": dtype} - for dtype in default_dtypes - for n in [0, 4] - for m in [None, 0, 1, 3, 4] - for k in list(range(-4, 4)))) - def testTri(self, m, n, k, dtype): - np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype) - bm_fun = lambda: bm.tri(n, M=m, k=k, dtype=dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_shape={}_k={}".format( - op, jtu.format_shape_dtype_string(shape, dtype), k), - "dtype": dtype, "shape": shape, "op": op, "k": k} - for dtype in default_dtypes - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for op in ["tril", "triu"] - for k in list(range(-3, 3)))) - def testTriLU(self, dtype, shape, op, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: getattr(np, op)(arg, k=k) - bm_fun = lambda arg: getattr(bm, op)(arg, k=k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "n={}_k={}_m={}".format(n, k, m), - "n": n, "k": k, "m": m} - for n in range(1, 5) - 
for k in [-1, 0, 1] - for m in range(1, 5))) - def testTrilIndices(self, n, k, m): - np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m) - bm_fun = lambda n, k, m: bm.tril_indices(n, k=k, m=m) - args_maker = lambda: [n, k, m] - self._CheckAgainstNumpy(np_fun, bm_fun, args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "n={}_k={}_m={}".format(n, k, m), - "n": n, "k": k, "m": m} - for n in range(1, 5) - for k in [-1, 0, 1] - for m in range(1, 5))) - def testTriuIndices(self, n, k, m): - np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m) - bm_fun = lambda n, k, m: bm.triu_indices(n, k=k, m=m) - args_maker = lambda: [n, k, m] - self._CheckAgainstNumpy(np_fun, bm_fun, args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_k={}".format( - jtu.format_shape_dtype_string(shape, dtype), k), - "dtype": dtype, "shape": shape, "k": k} - for dtype in default_dtypes - for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)] - for k in [-1, 0, 1])) - def testTriuIndicesFrom(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arr, k: np.triu_indices_from(arr, k=k) - bm_fun = lambda arr, k: bm.triu_indices_from(arr, k=k) - args_maker = lambda: [rng(shape, dtype), k] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_k={}".format( - jtu.format_shape_dtype_string(shape, dtype), k), - "dtype": dtype, "shape": shape, "k": k} - for dtype in default_dtypes - for shape in [(1,1), (1,2), (2,2), (2,3), (3,2), (3,3), (4,4)] - for k in [-1, 0, 1])) - def testTrilIndicesFrom(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arr, k: np.tril_indices_from(arr, k=k) - bm_fun = lambda arr, k: bm.tril_indices_from(arr, k=k) - args_maker = lambda: [rng(shape, dtype), k] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - 
@parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_ndim={}_n={}".format(ndim, n), - "ndim": ndim, "n": n} - for ndim in [0, 1, 4] - for n in [0, 1, 7])) - def testDiagIndices(self, ndim, n): - np.testing.assert_equal(jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim), - bm_func(bm.diag_indices)(n, ndim)) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "arr_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype) - ), - "dtype": dtype, "shape": shape} - for dtype in default_dtypes - for shape in [(1,1), (2,2), (3,3), (4,4), (5,5)])) - def testDiagIndicesFrom(self, dtype, shape): - rng = jtu.rand_default(self.rng()) - np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from) - bm_fun = bm.diag_indices_from - args_maker = lambda : [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_k={}".format( - jtu.format_shape_dtype_string(shape, dtype), k), - "dtype": dtype, "shape": shape, "k": k} - for dtype in default_dtypes - for shape in [shape for shape in all_shapes if len(shape) in (1, 2)] - for k in list(range(-4, 4)))) - def testDiag(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.diag(arg, k) - bm_fun = lambda arg: bm.diag(arg, k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_k={}".format( - jtu.format_shape_dtype_string(shape, dtype), k), - "dtype": dtype, "shape": shape, "k": k} - for dtype in default_dtypes - for shape in all_shapes - for k in range(-4, 4))) - def testDiagFlat(self, shape, dtype, k): - rng = jtu.rand_default(self.rng()) - # numpy has inconsistencies for scalar values - # 
https://github.com/numpy/numpy/issues/16477 - # jax differs in that it treats scalars values as length-1 arrays - np_fun = lambda arg: np.diagflat(np.atleast_1d(arg), k) - bm_fun = lambda arg: bm.diagflat(arg, k) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_a1_shape={}_a2_shape2={}".format( - jtu.format_shape_dtype_string(a1_shape, dtype), - jtu.format_shape_dtype_string(a2_shape, dtype)), - "dtype": dtype, "a1_shape": a1_shape, "a2_shape": a2_shape} - for dtype in default_dtypes - for a1_shape in one_dim_array_shapes - for a2_shape in one_dim_array_shapes)) - def testPolyMul(self, a1_shape, a2_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg1, arg2: np.polymul(arg1, arg2) - bm_fun_np = lambda arg1, arg2: bm.polymul(arg1, arg2, trim_leading_zeros=True) - bm_fun_co = lambda arg1, arg2: bm.polymul(arg1, arg2) - args_maker = lambda: [rng(a1_shape, dtype), rng(a2_shape, dtype)] - tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun_np), args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(bm_func(bm_fun_co), args_maker, check_dtypes=False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "a_shape={} , b_shape={}".format( - jtu.format_shape_dtype_string(a_shape, dtype), - jtu.format_shape_dtype_string(b_shape, dtype)), - "dtype": dtype, "a_shape": a_shape, "b_shape" : b_shape} - for dtype in default_dtypes - for a_shape in one_dim_array_shapes - for b_shape in one_dim_array_shapes)) - def testPolyDiv(self, a_shape, b_shape, dtype): - rng = jtu.rand_default(self.rng()) - - @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero.*") - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - def 
np_fun(arg1, arg2): - q, r = np.polydiv(arg1, arg2) - while r.size < max(arg1.size, arg2.size): # Pad residual to same size - r = np.pad(r, (1, 0), 'constant') - return q, r - - def bm_fun(arg1, arg2): - q, r = bm.polydiv(arg1, arg2, trim_leading_zeros=True) - while r.size < max(arg1.size, arg2.size): # Pad residual to same size - r = bm.pad(r, (1, 0), 'constant') - return q, r - - args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] - tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} - - bm_compile = bm.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(bm_func(bm_compile), args_maker, check_dtypes=True, atol=tol, rtol=tol) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( - jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2), - "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1, - "axis2": axis2} - for dtype in default_dtypes - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in [a for a in range(-len(shape), len(shape)) - if a % len(shape) != axis1 % len(shape)] - for offset in list(range(-4, 4)))) - def testDiagonal(self, shape, dtype, offset, axis1, axis2): - rng = jtu.rand_default(self.rng()) - np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2) - bm_fun = lambda arg: bm.diagonal(arg, offset, axis1, axis2) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_n={}".format(np.dtype(dtype).name, n), - "dtype": dtype, "n": n} - for dtype in default_dtypes - for n in list(range(4)))) - def testIdentity(self, n, dtype): 
- np_fun = lambda: np.identity(n, dtype) - bm_fun = lambda: bm.identity(n, dtype) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_period={}_left={}_right={}".format( - jtu.format_shape_dtype_string(shape, dtype), period, left, right), - "shape": shape, "dtype": dtype, - "period": period, "left": left, "right": right} - for shape in nonempty_shapes - for period in [None, 0.59] - for left in [None, 0] - for right in [None, 1] - for dtype in default_dtypes - # following types lack precision for meaningful tests - if dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16] - )) - def testInterp(self, shape, dtype, period, left, right): - rng = jtu.rand_default(self.rng(), scale=10) - kwds = dict(period=period, left=left, right=right) - np_fun = partial(np.interp, **kwds) - bm_fun = partial(bm.interp, **kwds) - args_maker = lambda: [rng(shape, dtype), np.sort(rng((20,), dtype)), np.linspace(0, 1, 20)] - - # skip numpy comparison for integer types with period specified, because numpy - # uses an unstable sort and so results differ for duplicate values. - if not (period and np.issubdtype(dtype, np.integer)): - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, tol={np.float32: 2E-4}) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_x1={}_x2={}_x1_rng={}".format( - jtu.format_shape_dtype_string(x1_shape, x1_dtype), - jtu.format_shape_dtype_string(x2_shape, np.int32), - x1_rng_factory_id), - "x1_shape": x1_shape, "x1_dtype": x1_dtype, - "x2_shape": x2_shape, "x1_rng_factory": x1_rng_factory, - "x2_rng_factory": x2_rng_factory} - for x1_rng_factory_id, x1_rng_factory in + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. 
+ @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testTrapz(self, yshape, xshape, dtype, dx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None] + np_fun = partial(np.trapz, dx=dx, axis=axis) + bm_fun = partial(bm.trapz, dx=dx, axis=axis) + tol = jtu.tolerance(dtype, {np.float64: 1e-12, + dtypes.bfloat16: 4e-2}) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, tol=tol, + check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, rtol=tol, + check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_dtype={}_m={}_n={}_k={}".format( + np.dtype(dtype).name, m, n, k), + "m": m, "n": n, "k": k, "dtype": dtype} + for dtype in default_dtypes + for n in [0, 4] + for m in [None, 0, 1, 3, 4] + for k in list(range(-4, 4)))) + def testTri(self, m, n, k, dtype): + np_fun = lambda: np.tri(n, M=m, k=k, dtype=dtype) + bm_fun = lambda: bm.tri(n, M=m, k=k, dtype=dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_shape={}_k={}".format( + op, jtu.format_shape_dtype_string(shape, dtype), k), + "dtype": dtype, "shape": shape, "op": op, "k": k} + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for op in ["tril", "triu"] + for k in list(range(-3, 3)))) + def testTriLU(self, dtype, shape, op, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: getattr(np, op)(arg, k=k) + bm_fun = lambda arg: getattr(bm, op)(arg, k=k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + 
{"testcase_name": "n={}_k={}_m={}".format(n, k, m), + "n": n, "k": k, "m": m} + for n in range(1, 5) + for k in [-1, 0, 1] + for m in range(1, 5))) + def testTrilIndices(self, n, k, m): + np_fun = lambda n, k, m: np.tril_indices(n, k=k, m=m) + bm_fun = lambda n, k, m: bm.tril_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, bm_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "n={}_k={}_m={}".format(n, k, m), + "n": n, "k": k, "m": m} + for n in range(1, 5) + for k in [-1, 0, 1] + for m in range(1, 5))) + def testTriuIndices(self, n, k, m): + np_fun = lambda n, k, m: np.triu_indices(n, k=k, m=m) + bm_fun = lambda n, k, m: bm.triu_indices(n, k=k, m=m) + args_maker = lambda: [n, k, m] + self._CheckAgainstNumpy(np_fun, bm_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k), + "dtype": dtype, "shape": shape, "k": k} + for dtype in default_dtypes + for shape in [(1, 1), (1, 2), (2, 2), (2, 3), (3, 2), (3, 3), (4, 4)] + for k in [-1, 0, 1])) + def testTriuIndicesFrom(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arr, k: np.triu_indices_from(arr, k=k) + bm_fun = lambda arr, k: bm.triu_indices_from(arr, k=k) + args_maker = lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k), + "dtype": dtype, "shape": shape, "k": k} + for dtype in default_dtypes + for shape in [(1, 1), (1, 2), (2, 2), (2, 3), (3, 2), (3, 3), (4, 4)] + for k in [-1, 0, 1])) + def testTrilIndicesFrom(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arr, k: np.tril_indices_from(arr, k=k) + bm_fun = lambda arr, k: bm.tril_indices_from(arr, k=k) + args_maker = 
lambda: [rng(shape, dtype), k] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_ndim={}_n={}".format(ndim, n), + "ndim": ndim, "n": n} + for ndim in [0, 1, 4] + for n in [0, 1, 7])) + def testDiagIndices(self, ndim, n): + np.testing.assert_equal(jtu.with_jax_dtype_defaults(np.diag_indices)(n, ndim), + bm_func(bm.diag_indices)(n, ndim)) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "arr_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype) + ), + "dtype": dtype, "shape": shape} + for dtype in default_dtypes + for shape in [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)])) + def testDiagIndicesFrom(self, dtype, shape): + rng = jtu.rand_default(self.rng()) + np_fun = jtu.with_jax_dtype_defaults(np.diag_indices_from) + bm_fun = bm.diag_indices_from + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k), + "dtype": dtype, "shape": shape, "k": k} + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) in (1, 2)] + for k in list(range(-4, 4)))) + def testDiag(self, shape, dtype, k): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.diag(arg, k) + bm_fun = lambda arg: bm.diag(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_k={}".format( + jtu.format_shape_dtype_string(shape, dtype), k), + "dtype": dtype, "shape": shape, "k": k} + for dtype in default_dtypes + for shape in all_shapes + for k in range(-4, 4))) + def testDiagFlat(self, shape, dtype, 
k): + rng = jtu.rand_default(self.rng()) + # numpy has inconsistencies for scalar values + # https://github.com/numpy/numpy/issues/16477 + # jax differs in that it treats scalars values as length-1 arrays + np_fun = lambda arg: np.diagflat(np.atleast_1d(arg), k) + bm_fun = lambda arg: bm.diagflat(arg, k) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=True) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=True) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_a1_shape={}_a2_shape2={}".format( + jtu.format_shape_dtype_string(a1_shape, dtype), + jtu.format_shape_dtype_string(a2_shape, dtype)), + "dtype": dtype, "a1_shape": a1_shape, "a2_shape": a2_shape} + for dtype in default_dtypes + for a1_shape in one_dim_array_shapes + for a2_shape in one_dim_array_shapes)) + def testPolyMul(self, a1_shape, a2_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg1, arg2: np.polymul(arg1, arg2) + bm_fun_np = lambda arg1, arg2: bm.polymul(arg1, arg2, trim_leading_zeros=True) + bm_fun_co = lambda arg1, arg2: bm.polymul(arg1, arg2) + args_maker = lambda: [rng(a1_shape, dtype), rng(a2_shape, dtype)] + tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun_np), args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm_fun_co), args_maker, check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "a_shape={} , b_shape={}".format( + jtu.format_shape_dtype_string(a_shape, dtype), + jtu.format_shape_dtype_string(b_shape, dtype)), + "dtype": dtype, "a_shape": a_shape, "b_shape": b_shape} + for dtype in default_dtypes + for a_shape in one_dim_array_shapes + for b_shape in one_dim_array_shapes)) + def testPolyDiv(self, a_shape, b_shape, dtype): + rng = jtu.rand_default(self.rng()) + + @jtu.ignore_warning(category=RuntimeWarning, message="divide 
by zero.*") + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + def np_fun(arg1, arg2): + q, r = np.polydiv(arg1, arg2) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = np.pad(r, (1, 0), 'constant') + return q, r + + def bm_fun(arg1, arg2): + q, r = bm.polydiv(arg1, arg2, trim_leading_zeros=True) + while r.size < max(arg1.size, arg2.size): # Pad residual to same size + r = bm.pad(r, (1, 0), 'constant') + return q, r + + args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)] + tol = {np.float16: 2e-1, np.float32: 5e-2, np.float64: 1e-13} + + bm_compile = bm.polydiv # Without trim_leading_zeros (trim_zeros make it unable to be compiled by XLA) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm_compile), args_maker, check_dtypes=True, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format( + jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2), + "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1, + "axis2": axis2} + for dtype in default_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in [a for a in range(-len(shape), len(shape)) + if a % len(shape) != axis1 % len(shape)] + for offset in list(range(-4, 4)))) + def testDiagonal(self, shape, dtype, offset, axis1, axis2): + rng = jtu.rand_default(self.rng()) + np_fun = lambda arg: np.diagonal(arg, offset, axis1, axis2) + bm_fun = lambda arg: bm.diagonal(arg, offset, axis1, axis2) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_n={}".format(np.dtype(dtype).name, n), + "dtype": dtype, "n": n} + 
for dtype in default_dtypes + for n in list(range(4)))) + def testIdentity(self, n, dtype): + np_fun = lambda: np.identity(n, dtype) + bm_fun = lambda: bm.identity(n, dtype) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_period={}_left={}_right={}".format( + jtu.format_shape_dtype_string(shape, dtype), period, left, right), + "shape": shape, "dtype": dtype, + "period": period, "left": left, "right": right} + for shape in nonempty_shapes + for period in [None, 0.59] + for left in [None, 0] + for right in [None, 1] + for dtype in default_dtypes + # following types lack precision for meaningful tests + if dtype not in [np.int8, np.int16, np.float16, jnp.bfloat16] + )) + def testInterp(self, shape, dtype, period, left, right): + rng = jtu.rand_default(self.rng(), scale=10) + kwds = dict(period=period, left=left, right=right) + np_fun = partial(np.interp, **kwds) + bm_fun = partial(bm.interp, **kwds) + args_maker = lambda: [rng(shape, dtype), np.sort(rng((20,), dtype)), np.linspace(0, 1, 20)] + + # skip numpy comparison for integer types with period specified, because numpy + # uses an unstable sort and so results differ for duplicate values. 
+ if not (period and np.issubdtype(dtype, np.integer)): + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, tol={np.float32: 2E-4}) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_x1={}_x2={}_x1_rng={}".format( + jtu.format_shape_dtype_string(x1_shape, x1_dtype), + jtu.format_shape_dtype_string(x2_shape, np.int32), + x1_rng_factory_id), + "x1_shape": x1_shape, "x1_dtype": x1_dtype, + "x2_shape": x2_shape, "x1_rng_factory": x1_rng_factory, + "x2_rng_factory": x2_rng_factory} + for x1_rng_factory_id, x1_rng_factory in enumerate([jtu.rand_some_inf_and_nan, jtu.rand_some_zero]) - for x2_rng_factory in [partial(jtu.rand_int, low=-1075, high=1024)] - for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(array_shapes, 2)) - for x1_dtype in default_dtypes)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory): - # integer types are converted to float64 in numpy's implementation - if (x1_dtype not in [jnp.bfloat16, np.float16, np.float32] - and not config.x64_enabled): - self.skipTest("Only run float64 testcase when float64 is enabled.") - x1_rng = x1_rng_factory(self.rng()) - x2_rng = x2_rng_factory(self.rng()) - np_fun = lambda x1, x2: np.ldexp(x1, x2) - np_fun = jtu.ignore_warning(category=RuntimeWarning, - message="overflow.*")(np_fun) - bm_fun = lambda x1, x2: bm.ldexp(x1, x2) - args_maker = lambda: [x1_rng(x1_shape, x1_dtype), - x2_rng(x2_shape, np.int32)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_x={}_rng_factory={}".format( - jtu.format_shape_dtype_string(shape, dtype), rng_factory_id), - "shape": shape, "dtype": dtype, "rng_factory": rng_factory} - for rng_factory_id, rng_factory in enumerate([ - jtu.rand_some_inf_and_nan, - jtu.rand_some_zero, - partial(jtu.rand_not_small, offset=1e8), - ]) - for shape in all_shapes - for dtype in default_dtypes)) - def testFrexp(self, shape, dtype, rng_factory): - # integer types are converted to float64 in numpy's implementation - if (dtype not in [jnp.bfloat16, np.float16, np.float32] - and not config.x64_enabled): - self.skipTest("Only run float64 testcase when float64 is enabled.") - rng = rng_factory(self.rng()) - np_fun = lambda x: np.frexp(x) - bm_fun = lambda x: bm.frexp(x) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=np.issubdtype(dtype, np.inexact)) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format( - jtu.format_shape_dtype_string(shape, dtype), - out_dtype, offset, axis1, 
axis2), - "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset, - "axis1": axis1, "axis2": axis2} - for dtype in default_dtypes - for out_dtype in [None] + number_dtypes - for shape in [shape for shape in all_shapes if len(shape) >= 2] - for axis1 in range(-len(shape), len(shape)) - for axis2 in range(-len(shape), len(shape)) - if (axis1 % len(shape)) != (axis2 % len(shape)) - for offset in list(range(-4, 4)))) - def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2): - rng = jtu.rand_default(self.rng()) - def np_fun(arg): - if out_dtype == jnp.bfloat16: - return np.trace(arg, offset, axis1, axis2, np.float32).astype(jnp.bfloat16) - else: - return np.trace(arg, offset, axis1, axis2, out_dtype) - bm_fun = lambda arg: bm.trace(arg, offset, axis1, axis2, out_dtype) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_a={}_v={}_side={}".format( - jtu.format_shape_dtype_string(ashape, dtype), - jtu.format_shape_dtype_string(vshape, dtype), - side), "ashape": ashape, "vshape": vshape, "side": side, - "dtype": dtype} - for ashape in [(15,), (16,), (17,)] - for vshape in [(), (5,), (5, 5)] - for side in ['left', 'right'] - for dtype in number_dtypes - )) - def testSearchsorted(self, ashape, vshape, side, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [np.sort(rng(ashape, dtype)), rng(vshape, dtype)] - np_fun = lambda a, v: np.searchsorted(a, v, side=side) - bm_fun = lambda a, v: bm.searchsorted(a, v, side=side) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={dtype.__name__}_side={side}", "dtype": dtype, "side": side} - for dtype in inexact_dtypes - for side in ['left', 'right'])) - def 
testSearchsortedNans(self, dtype, side): - if np.issubdtype(dtype, np.complexfloating): - raise SkipTest("Known failure for complex inputs; see #9107") - x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype) - # The sign bit should not matter for 0.0 or NaN, so argsorting the above should be - # equivalent to argsorting the following: - x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5]) - - if jnp.issubdtype(dtype, jnp.complexfloating): - x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)]) - x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)]) - - bm_fun = partial(bm.searchsorted, side=side) - self.assertArraysEqual(bm_func(bm_fun)(x, x), bm_func(bm_fun)(x_equiv, x_equiv)) - self.assertArraysEqual(jax.jit(bm_func(bm_fun))(x, x), bm_func(bm_fun)(x_equiv, x_equiv)) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_x={}_bins={}_right={}_reverse={}".format( - jtu.format_shape_dtype_string(xshape, dtype), - jtu.format_shape_dtype_string(binshape, dtype), - right, reverse), "xshape": xshape, "binshape": binshape, - "right": right, "reverse": reverse, "dtype": dtype} - for xshape in [(20,), (5, 4)] - for binshape in [(1,), (5,)] - for right in [True, False] - for reverse in [True, False] - for dtype in default_dtypes - )) - def testDigitize(self, xshape, binshape, right, reverse, dtype): - order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(xshape, dtype), bm.sort(rng(binshape, dtype))[order]] - np_fun = lambda x, bins: np.digitize(x, bins, right=right) - bm_fun = lambda x, bins: bm.digitize(x, bins, right=right) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_array={}".format( - jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), 
array_input), - "shape": shape, "dtypes": dtypes, "array_input": array_input} - for dtypes in [ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ] - for shape in [(), (2,), (3, 4), (1, 5)] - for array_input in [True, False])) - def testColumnStack(self, shape, dtypes, array_input): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - np_fun = _promote_like_jnp(np.column_stack) - bm_fun = bm.column_stack - self._CheckAgainstNumpy(bm_func(bm_fun), np_fun, args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_array={}".format( - jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input), - "shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input} - for dtypes in [ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ] - for shape in [(), (2,), (3, 4), (1, 100)] - for axis in range(-len(shape), len(shape) + 1) - for array_input in [True, False])) - def testStack(self, shape, axis, dtypes, array_input): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - np_fun = _promote_like_jnp(partial(np.stack, axis=axis)) - bm_fun = partial(bm.stack, axis=axis) - self._CheckAgainstNumpy(bm_func(bm_fun), np_fun, args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_{}_array={}".format( - op, jtu.format_test_name_suffix("", [shape] * 
len(dtypes), dtypes), array_input), - "shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input} - for op in ["hstack", "vstack", "dstack"] - for dtypes in [ - [np.float32], - [np.float32, np.float32], - [np.float32, np.int32, np.float32], - [np.float32, np.int64, np.float32], - [np.float32, np.int32, np.float64], - ] - for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)] - for array_input in [True, False])) - def testHVDStack(self, shape, op, dtypes, array_input): - rng = jtu.rand_default(self.rng()) - if array_input: - args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] - else: - args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] - np_fun = _promote_like_jnp(getattr(np, op)) - bm_fun = getattr(bm, op) - self._CheckAgainstNumpy(bm_func(bm_fun), np_fun, args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outdtype={}_fillshape={}".format( - jtu.format_shape_dtype_string(shape, fill_value_dtype), - np.dtype(out_dtype).name if out_dtype else "None", - fill_value_shape), - "fill_value_dtype": fill_value_dtype, "fill_value_shape": fill_value_shape, - "shape": shape, "out_dtype": out_dtype} - for shape in array_shapes + [3, np.array(7, dtype=np.int32)] - for fill_value_dtype in default_dtypes - for fill_value_shape in _compatible_shapes(shape) - for out_dtype in [None] + default_dtypes)) - def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype) - bm_fun = lambda fill_value: bm.full(shape, fill_value, dtype=out_dtype) - args_maker = lambda: [rng(fill_value_shape, fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": 
"_shape={}_n={}_axis={}_prepend={}_append={}".format( - jtu.format_shape_dtype_string(shape, dtype), - n, axis, prepend, append), + for x2_rng_factory in [partial(jtu.rand_int, low=-1075, high=1024)] + for x1_shape, x2_shape in filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(array_shapes, 2)) + for x1_dtype in default_dtypes)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLdexp(self, x1_shape, x1_dtype, x2_shape, x1_rng_factory, x2_rng_factory): + # integer types are converted to float64 in numpy's implementation + if (x1_dtype not in [jnp.bfloat16, np.float16, np.float32] + and not config.x64_enabled): + self.skipTest("Only run float64 testcase when float64 is enabled.") + x1_rng = x1_rng_factory(self.rng()) + x2_rng = x2_rng_factory(self.rng()) + np_fun = lambda x1, x2: np.ldexp(x1, x2) + np_fun = jtu.ignore_warning(category=RuntimeWarning, + message="overflow.*")(np_fun) + bm_fun = lambda x1, x2: bm.ldexp(x1, x2) + args_maker = lambda: [x1_rng(x1_shape, x1_dtype), + x2_rng(x2_shape, np.int32)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_x={}_rng_factory={}".format( + jtu.format_shape_dtype_string(shape, dtype), rng_factory_id), + "shape": shape, "dtype": dtype, "rng_factory": rng_factory} + for rng_factory_id, rng_factory in enumerate([ + jtu.rand_some_inf_and_nan, + jtu.rand_some_zero, + partial(jtu.rand_not_small, offset=1e8), + ]) + for shape in all_shapes + for dtype in default_dtypes)) + def testFrexp(self, shape, dtype, rng_factory): + # integer types are converted to float64 in numpy's implementation + if (dtype not in [jnp.bfloat16, np.float16, np.float32] + and not config.x64_enabled): + self.skipTest("Only run float64 testcase when float64 is enabled.") + rng = rng_factory(self.rng()) + np_fun = lambda x: 
np.frexp(x) + bm_fun = lambda x: bm.frexp(x) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=np.issubdtype(dtype, np.inexact)) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format( + jtu.format_shape_dtype_string(shape, dtype), + out_dtype, offset, axis1, axis2), + "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset, + "axis1": axis1, "axis2": axis2} + for dtype in default_dtypes + for out_dtype in [None] + number_dtypes + for shape in [shape for shape in all_shapes if len(shape) >= 2] + for axis1 in range(-len(shape), len(shape)) + for axis2 in range(-len(shape), len(shape)) + if (axis1 % len(shape)) != (axis2 % len(shape)) + for offset in list(range(-4, 4)))) + def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2): + rng = jtu.rand_default(self.rng()) + + def np_fun(arg): + if out_dtype == jnp.bfloat16: + return np.trace(arg, offset, axis1, axis2, np.float32).astype(jnp.bfloat16) + else: + return np.trace(arg, offset, axis1, axis2, out_dtype) + + bm_fun = lambda arg: bm.trace(arg, offset, axis1, axis2, out_dtype) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_a={}_v={}_side={}".format( + jtu.format_shape_dtype_string(ashape, dtype), + jtu.format_shape_dtype_string(vshape, dtype), + side), "ashape": ashape, "vshape": vshape, "side": side, + "dtype": dtype} + for ashape in [(15,), (16,), (17,)] + for vshape in [(), (5,), (5, 5)] + for side in ['left', 'right'] + for dtype in number_dtypes + )) + def testSearchsorted(self, ashape, vshape, side, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [np.sort(rng(ashape, dtype)), 
rng(vshape, dtype)] + np_fun = lambda a, v: np.searchsorted(a, v, side=side) + bm_fun = lambda a, v: bm.searchsorted(a, v, side=side) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": f"_dtype={dtype.__name__}_side={side}", "dtype": dtype, "side": side} + for dtype in inexact_dtypes + for side in ['left', 'right'])) + def testSearchsortedNans(self, dtype, side): + if np.issubdtype(dtype, np.complexfloating): + raise SkipTest("Known failure for complex inputs; see #9107") + x = np.array([-np.inf, -1.0, 0.0, -0.0, 1.0, np.inf, np.nan, -np.nan], dtype=dtype) + # The sign bit should not matter for 0.0 or NaN, so argsorting the above should be + # equivalent to argsorting the following: + x_equiv = np.array([0, 1, 2, 2, 3, 4, 5, 5]) + + if jnp.issubdtype(dtype, jnp.complexfloating): + x = np.array([complex(r, c) for r, c in itertools.product(x, repeat=2)]) + x_equiv = np.array([complex(r, c) for r, c in itertools.product(x_equiv, repeat=2)]) + + bm_fun = partial(bm.searchsorted, side=side) + self.assertArraysEqual(bm_func(bm_fun)(x, x), bm_func(bm_fun)(x_equiv, x_equiv)) + self.assertArraysEqual(jax.jit(bm_func(bm_fun))(x, x), bm_func(bm_fun)(x_equiv, x_equiv)) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_x={}_bins={}_right={}_reverse={}".format( + jtu.format_shape_dtype_string(xshape, dtype), + jtu.format_shape_dtype_string(binshape, dtype), + right, reverse), "xshape": xshape, "binshape": binshape, + "right": right, "reverse": reverse, "dtype": dtype} + for xshape in [(20,), (5, 4)] + for binshape in [(1,), (5,)] + for right in [True, False] + for reverse in [True, False] + for dtype in default_dtypes + )) + def testDigitize(self, xshape, binshape, right, reverse, dtype): + order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] + rng = jtu.rand_default(self.rng()) + args_maker = lambda: 
[rng(xshape, dtype), bm.sort(rng(binshape, dtype))[order]] + np_fun = lambda x, bins: np.digitize(x, bins, right=right) + bm_fun = lambda x, bins: bm.digitize(x, bins, right=right) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_array={}".format( + jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input), + "shape": shape, "dtypes": dtypes, "array_input": array_input} + for dtypes in [ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ] + for shape in [(), (2,), (3, 4), (1, 5)] + for array_input in [True, False])) + def testColumnStack(self, shape, dtypes, array_input): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + np_fun = _promote_like_jnp(np.column_stack) + bm_fun = bm.column_stack + self._CheckAgainstNumpy(bm_func(bm_fun), np_fun, args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_array={}".format( + jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), axis, array_input), + "shape": shape, "axis": axis, "dtypes": dtypes, "array_input": array_input} + for dtypes in [ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ] + for shape in [(), (2,), (3, 4), (1, 100)] + for axis in range(-len(shape), len(shape) + 1) + for array_input in [True, False])) + def testStack(self, shape, axis, dtypes, array_input): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for 
dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + np_fun = _promote_like_jnp(partial(np.stack, axis=axis)) + bm_fun = partial(bm.stack, axis=axis) + self._CheckAgainstNumpy(bm_func(bm_fun), np_fun, args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_{}_array={}".format( + op, jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes), array_input), + "shape": shape, "op": op, "dtypes": dtypes, "array_input": array_input} + for op in ["hstack", "vstack", "dstack"] + for dtypes in [ + [np.float32], + [np.float32, np.float32], + [np.float32, np.int32, np.float32], + [np.float32, np.int64, np.float32], + [np.float32, np.int32, np.float64], + ] + for shape in [(), (2,), (3, 4), (1, 100), (2, 3, 4)] + for array_input in [True, False])) + def testHVDStack(self, shape, op, dtypes, array_input): + rng = jtu.rand_default(self.rng()) + if array_input: + args_maker = lambda: [np.array([rng(shape, dtype) for dtype in dtypes])] + else: + args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] + np_fun = _promote_like_jnp(getattr(np, op)) + bm_fun = getattr(bm, op) + self._CheckAgainstNumpy(bm_func(bm_fun), np_fun, args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outdtype={}_fillshape={}".format( + jtu.format_shape_dtype_string(shape, fill_value_dtype), + np.dtype(out_dtype).name if out_dtype else "None", + fill_value_shape), + "fill_value_dtype": fill_value_dtype, "fill_value_shape": fill_value_shape, + "shape": shape, "out_dtype": out_dtype} + for shape in array_shapes + [3, np.array(7, dtype=np.int32)] + for fill_value_dtype in default_dtypes + for fill_value_shape in _compatible_shapes(shape) + for out_dtype in [None] + default_dtypes)) + def testFull(self, shape, fill_value_dtype, fill_value_shape, out_dtype): + rng = 
jtu.rand_default(self.rng()) + np_fun = lambda fill_value: np.full(shape, fill_value, dtype=out_dtype) + bm_fun = lambda fill_value: bm.full(shape, fill_value, dtype=out_dtype) + args_maker = lambda: [rng(fill_value_shape, fill_value_dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": "_shape={}_n={}_axis={}_prepend={}_append={}".format( + jtu.format_shape_dtype_string(shape, dtype), + n, axis, prepend, append), "shape": shape, "dtype": dtype, "n": n, "axis": axis, "prepend": prepend, "append": append } for shape, dtype in s(_shape_and_dtypes(nonempty_nonscalar_array_shapes, default_dtypes)) - for n in s([0, 1, 2]) - for axis in s(list(range(-len(shape), max(1, len(shape))))) - for prepend in s([None, 1, np.zeros(shape, dtype=dtype)]) - for append in s([None, 1, np.zeros(shape, dtype=dtype)]) - ))) - def testDiff(self, shape, dtype, n, axis, prepend, append): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - - def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): - if prepend is None: - prepend = np._NoValue - elif not np.isscalar(prepend) and prepend.dtype == jnp.bfloat16: - prepend = prepend.astype(np.float32) - - if append is None: - append = np._NoValue - elif not np.isscalar(append) and append.dtype == jnp.bfloat16: - append = append.astype(np.float32) - - if x.dtype == jnp.bfloat16: - return np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, append=append).astype(jnp.bfloat16) - else: - return np.diff(x, n=n, axis=axis, prepend=prepend, append=append) - - bm_fun = lambda x: bm.diff(x, n=n, axis=axis, prepend=prepend, append=append) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters( - jtu.cases_from_list( - 
{"testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype), - "np_op": getattr(np, op), "bm_op": getattr(bm, op), - "shape": shape, "dtype": dtype} - for op in ["zeros", "ones"] - for shape in [2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32), - np.array(4, dtype=np.int32)] - for dtype in all_dtypes)) - def testZerosOnes(self, np_op, bm_op, shape, dtype): - args_maker = lambda: [] - np_op = partial(np_op, shape, dtype) - bm_op = partial(bm_op, shape, dtype) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - def testOnesWithInvalidShape(self): - with self.assertRaises(TypeError): - bm.ones((-1, 1)) - - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format( - jtu.format_shape_dtype_string(shape, in_dtype), - np.dtype(fill_value_dtype).name, fill_value_shape, - np.dtype(out_dtype).name, out_shape), - "shape": shape, "in_dtype": in_dtype, - "fill_value_dtype": fill_value_dtype, "fill_value_shape": fill_value_shape, - "out_dtype": out_dtype, "out_shape": out_shape + for n in s([0, 1, 2]) + for axis in s(list(range(-len(shape), max(1, len(shape))))) + for prepend in s([None, 1, np.zeros(shape, dtype=dtype)]) + for append in s([None, 1, np.zeros(shape, dtype=dtype)]) + ))) + def testDiff(self, shape, dtype, n, axis, prepend, append): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): + if prepend is None: + prepend = np._NoValue + elif not np.isscalar(prepend) and prepend.dtype == jnp.bfloat16: + prepend = prepend.astype(np.float32) + + if append is None: + append = np._NoValue + elif not np.isscalar(append) and append.dtype == jnp.bfloat16: + append = append.astype(np.float32) + + if x.dtype == jnp.bfloat16: + return np.diff(x.astype(np.float32), n=n, axis=axis, prepend=prepend, 
append=append).astype( + jnp.bfloat16) + else: + return np.diff(x, n=n, axis=axis, prepend=prepend, append=append) + + bm_fun = lambda x: bm.diff(x, n=n, axis=axis, prepend=prepend, append=append) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": ("_op={}_shape={}_dtype={}").format(op, shape, dtype), + "np_op": getattr(np, op), "bm_op": getattr(bm, op), + "shape": shape, "dtype": dtype} + for op in ["zeros", "ones"] + for shape in [2, (), (2,), (3, 0), np.array((4, 5, 6), dtype=np.int32), + np.array(4, dtype=np.int32)] + for dtype in all_dtypes)) + def testZerosOnes(self, np_op, bm_op, shape, dtype): + args_maker = lambda: [] + np_op = partial(np_op, shape, dtype) + bm_op = partial(bm_op, shape, dtype) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + def testOnesWithInvalidShape(self): + with self.assertRaises(TypeError): + bm.ones((-1, 1)) + + @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": "_inshape={}_filldtype={}_fillshape={}_outdtype={}_outshape={}".format( + jtu.format_shape_dtype_string(shape, in_dtype), + np.dtype(fill_value_dtype).name, fill_value_shape, + np.dtype(out_dtype).name, out_shape), + "shape": shape, "in_dtype": in_dtype, + "fill_value_dtype": fill_value_dtype, "fill_value_shape": fill_value_shape, + "out_dtype": out_dtype, "out_shape": out_shape } for shape in s(array_shapes) - for out_shape in s([None] + array_shapes) - for in_dtype in s(default_dtypes) - for fill_value_dtype in s(default_dtypes) - for fill_value_shape in s(_compatible_shapes(shape if out_shape is None else out_shape)) - for out_dtype in s(default_dtypes)))) - def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape): - if numpy_version < (1, 19) and out_shape == 
(): - raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None") - rng = jtu.rand_default(self.rng()) - np_fun = lambda x, fill_value: np.full_like( - x, fill_value, dtype=out_dtype, shape=out_shape) - bm_fun = lambda x, fill_value: bm.full_like( - x, fill_value, dtype=out_dtype, shape=out_shape) - args_maker = lambda: [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format( - func, jtu.format_shape_dtype_string(shape, in_dtype), - out_shape, out_dtype), - "func": func, "shape": shape, "in_dtype": in_dtype, - "out_shape": out_shape, "out_dtype": out_dtype} - for shape in array_shapes - for out_shape in [None] + array_shapes - for in_dtype in default_dtypes - for func in ["ones_like", "zeros_like"] - for out_dtype in default_dtypes)) - def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype): - if numpy_version < (1, 19) and out_shape == (): - raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None") - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape) - bm_fun = lambda x: getattr(bm, func)(x, dtype=out_dtype, shape=out_shape) - args_maker = lambda: [rng(shape, in_dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format( - func, jtu.format_shape_dtype_string(shape, in_dtype), - weak_type, out_shape, out_dtype), - "func": func, "args": args, - "shape": shape, "in_dtype": in_dtype, "weak_type": weak_type, - "out_shape": out_shape, "out_dtype": out_dtype} - for shape in array_shapes - for in_dtype in [np.int32, 
np.float32, np.complex64] - for weak_type in [True, False] - for out_shape in [None, (), (10,)] - for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())] - for out_dtype in [None, float])) - def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype): - if numpy_version < (1, 19) and out_shape == (): - raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None") - rng = jtu.rand_default(self.rng()) - x = lax_internal._convert_element_type(rng(shape, in_dtype), - weak_type=weak_type) - fun = lambda x: getattr(bm, func)(x, *args, dtype=out_dtype, shape=out_shape) - expected_weak_type = weak_type and (out_dtype is None) - self.assertEqual(dtypes.is_weakly_typed(bm_func(fun)(x)), expected_weak_type) - self.assertEqual(dtypes.is_weakly_typed(jax.jit(bm_func(fun))(x)), expected_weak_type) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_funcname={}_input_type={}_val={}_dtype={}".format( - funcname, input_type, val, dtype), - "funcname": funcname, "input_type": input_type, "val": val, "dtype": dtype} - for funcname in ["array", "asarray"] - for dtype in [int, float, None] - for val in [0, 1] - for input_type in [int, float, np.int32, np.float32])) - def testArrayWeakType(self, funcname, input_type, val, dtype): - bm_fun = lambda x: getattr(bm, funcname)(x, dtype=dtype) - fjit = jax.jit(bm_func(bm_fun)) - val = input_type(val) - expected_weak_type = dtype is None and input_type in set(dtypes._weak_types) - self.assertEqual(dtypes.is_weakly_typed(bm_func(bm_fun)(val)), expected_weak_type) - self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_weak_type={}_slc={}".format( - jtu.format_shape_dtype_string(shape, dtype), weak_type, slc), - "shape": shape, "dtype": dtype, "weak_type": weak_type, "slc": slc} - for shape in nonempty_nonscalar_array_shapes - for dtype in [int, 
float, complex] - for weak_type in [True, False] - for slc in [slice(None), slice(0), slice(3), 0, ...])) - def testSliceWeakTypes(self, shape, dtype, weak_type, slc): - rng = jtu.rand_default(self.rng()) - x = lax_internal._convert_element_type(rng(shape, dtype), - weak_type=weak_type) - op = lambda x: x[slc] - self.assertEqual(op(x).aval.weak_type, weak_type) - self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_{}sections".format( - jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), - "shape": shape, "num_sections": num_sections, "axis": axis, - "dtype": dtype} - for shape, axis, num_sections in [ - ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), - ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] - for dtype in default_dtypes)) - def testSplitStaticInt(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.split(x, num_sections, axis=axis) - bm_fun = lambda x: bm.split(x, num_sections, axis=axis) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_{}sections".format( - jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), - "shape": shape, "num_sections": num_sections, "axis": axis, "dtype": dtype} - # All testcases split the specified axis unequally - for shape, axis, num_sections in [ - ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3), - ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)] - for dtype in default_dtypes)) - def testArraySplitStaticInt(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.array_split(x, num_sections, axis=axis) - bm_fun = lambda x: bm.array_split(x, num_sections, axis=axis) - args_maker = lambda: 
[rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testSplitTypeError(self): - # If we pass an ndarray for indices_or_sections -> no error - self.assertEqual(3, len(bm_func(bm.split)(bm.zeros(3), bm.array([1, 2])))) - - CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected." - with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): - # An abstract tracer for idx - jax.jit(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), idx))(2.) - with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): - # A list including an abstract tracer - jax.jit(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), [2, idx]))(2.) - - # A concrete tracer -> no error - jax.jvp(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), idx), - (2.,), (1.,)) - # A tuple including a concrete tracer -> no error - jax.jvp(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), (1, idx)), - (2.,), (1.,)) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_bins={}_range={}_weights={}".format( - jtu.format_shape_dtype_string(shape, dtype), bins, range, weights), - "shape": shape, - "dtype": dtype, - "bins": bins, - "range": range, - "weights": weights, - } - for shape in [(5,), (5, 5)] - for dtype in number_dtypes - for bins in [10, np.arange(-5, 6), np.array([-5, 0, 3])] - for range in [None, (0, 0), (0, 10)] - for weights in [True, False] - )) - def testHistogramBinEdges(self, shape, dtype, bins, range, weights): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = lambda a, w, r: np.histogram_bin_edges(a, bins=bins, range=r, - weights=_weights(w)) - bm_fun = lambda a, w, r: bm.histogram_bin_edges(a, bins=bins, range=r, - weights=_weights(w)) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2} - # linspace() compares poorly to numpy when using 
bfloat16 - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, - atol=tol, rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_bins={}_density={}_weights={}".format( - jtu.format_shape_dtype_string(shape, dtype), bins, density, weights), - "shape": shape, - "dtype": dtype, - "bins": bins, - "density": density, - "weights": weights, - } - for shape in [(5,), (5, 5)] - for dtype in default_dtypes - # We only test explicit integer-valued bin edges because in other cases - # rounding errors lead to flaky tests. - for bins in [np.arange(-5, 6), np.array([-5, 0, 3])] - for density in [True, False] - for weights in [True, False] - )) - def testHistogram(self, shape, dtype, bins, density, weights): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = lambda a, w: np.histogram(a, bins=bins, density=density, - weights=_weights(w)) - bm_fun = lambda a, w: bm.histogram(a, bins=bins, density=density, - weights=_weights(w)) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - # np.searchsorted errors on bfloat16 with - # "TypeError: invalid type promotion with custom data type" - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_bins={}_weights={}_density={}_range={}".format( - jtu.format_shape_dtype_string(shape, dtype), bins, weights, density, range), - "shape": shape, "dtype": dtype, "bins": bins, "weights": weights, "density": density, "range": range, - } - for shape in [(5,), (12,)] - for dtype in int_dtypes - for bins in [2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]] - for weights in [False, 
True] - for density in [False, True] - for range in [None, [(-1, 1), None], [(-1, 1), (-2, 2)]] - )) - def testHistogram2d(self, shape, dtype, bins, weights, density, range): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( - lambda a, b, w: np.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range)) - bm_fun = lambda a, b, w: bm.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - # np.searchsorted errors on bfloat16 with - # "TypeError: invalid type promotion with custom data type" - with np.errstate(divide='ignore', invalid='ignore'): - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_bins={}_weights={}_density={}_range={}".format( - jtu.format_shape_dtype_string(shape, dtype), bins, weights, density, range), - "shape": shape, "dtype": dtype, "bins": bins, "weights": weights, "density": density, "range": range, - } - for shape in [(5, 3), (10, 3)] - for dtype in int_dtypes - for bins in [(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]] - for weights in [False, True] - for density in [False, True] - for range in [None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]] - )) - def testHistogramdd(self, shape, dtype, bins, weights, density, range): - rng = jtu.rand_default(self.rng()) - _weights = lambda w: abs(w) if weights else None - np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( - lambda a, w: np.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range)) - bm_fun = lambda a, w: 
jnp.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range) - args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)] - tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} - # np.searchsorted errors on bfloat16 with - # "TypeError: invalid type promotion with custom data type" - if dtype != jnp.bfloat16: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_{}sections".format( - jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), - "shape": shape, "num_sections": num_sections, "axis": axis, - "dtype": dtype} - for shape, axis, num_sections in [ - ((12, 4), 0, 4), ((12, 4), 1, 2), - ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)] - for dtype in default_dtypes)) - def testHVDSplit(self, shape, num_sections, axis, dtype): - rng = jtu.rand_default(self.rng()) - def fn(module, axis): - if axis == 0: - return module.vsplit - elif axis == 1: - return module.hsplit - else: - assert axis == 2 - return module.dsplit - - np_fun = lambda x: fn(np, axis)(x, num_sections) - bm_fun = lambda x: fn(bm, axis)(x, num_sections) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}_order={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype), - order), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, - "order": order} - for dtype in default_dtypes - for order in ["C", "F"] - for arg_shape, out_shape in [ - (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), - ((), (1, 1, 1)), - ((7, 0), (0, 42, 101)), - ((3, 4), 12), - ((3, 4), (12,)), - ((3, 4), -1), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ])) - def 
testReshape(self, arg_shape, out_shape, dtype, order): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.reshape(x, out_shape, order=order) - bm_fun = lambda x: bm.reshape(x, out_shape, order=order) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype)), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype} - for dtype in default_dtypes - for arg_shape, out_shape in [ - ((7, 0), (0, 42, 101)), - ((2, 1, 4), (-1,)), - ((2, 2, 4), (2, 8)) - ])) - def testReshapeMethod(self, arg_shape, out_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.reshape(x, out_shape) - bm_fun = lambda x: bm.reshape(x, out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_outshape={}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), - jtu.format_shape_dtype_string(out_shape, dtype)), - "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype} - for dtype in default_dtypes - for arg_shape, out_shape in itertools.product(all_shapes, array_shapes))) - def testResize(self, arg_shape, out_shape, dtype): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.resize(x, out_shape) - bm_fun = lambda x: bm.resize(x, out_shape) - args_maker = lambda: [rng(arg_shape, dtype)] - if len(out_shape) > 0 or numpy_version >= (1, 20, 0): - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - 
{"testcase_name": "_inshape={}_expanddim={!r}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), dim), - "arg_shape": arg_shape, "dtype": dtype, "dim": dim} - for arg_shape in [(), (3,), (3, 4)] - for dtype in default_dtypes - for dim in (list(range(-len(arg_shape)+1, len(arg_shape))) - + [np.array(0), np.array(-1), (0,), [np.array(0)], - (len(arg_shape), len(arg_shape) + 1)]))) - def testExpandDimsStaticDim(self, arg_shape, dtype, dim): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.expand_dims(x, dim) - bm_fun = lambda x: bm.expand_dims(x, dim) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CompileAndCheck(bm_func(bm_fun), args_maker) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - - def testExpandDimsRepeatedAxisError(self): - x = bm.ones((2, 3)) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: bm.expand_dims(x, [1, 1])) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: bm.expand_dims(x, [3, -1])) - - # ensure this is numpy's behavior too, so that we remain consistent - x = np.ones((2, 3)) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: np.expand_dims(x, [1, 1])) - self.assertRaisesRegex( - ValueError, 'repeated axis.*', - lambda: np.expand_dims(x, [3, -1])) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_axes=({},{})".format( - jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2), - "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2} - for arg_shape, ax1, ax2 in [ - ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), - ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] - for dtype in default_dtypes)) - def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.swapaxes(x, ax1, ax2) - bm_fun = lambda x: bm.swapaxes(x, ax1, ax2) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - 
self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_inshape={}_axis={!r}".format( - jtu.format_shape_dtype_string(arg_shape, dtype), ax), - "arg_shape": arg_shape, "dtype": dtype, "ax": ax} - for arg_shape, ax in [ - ((3, 1), None), - ((3, 1), 1), - ((3, 1), -1), - ((3, 1), np.array(1)), - ((1, 3, 1), (0, 2)), - ((1, 3, 1), (0,)), - ((1, 4, 1), (np.array(0),))] - for dtype in default_dtypes)) - def testSqueeze(self, arg_shape, dtype, ax): - rng = jtu.rand_default(self.rng()) - np_fun = lambda x: np.squeeze(x, ax) - bm_fun = lambda x: bm.squeeze(x, ax) - args_maker = lambda: [rng(arg_shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format( - jtu.format_shape_dtype_string(shape, dtype), - axis, - (None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)), - returned), - "shape": shape, "dtype": dtype, "axis": axis, - "weights_shape": weights_shape, "returned": returned} - for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes) - for axis in list(range(-len(shape), len(shape))) + [None] - # `weights_shape` is either `None`, same as the averaged axis, or same as - # that of the input - for weights_shape in ([None, shape] if axis is None or len(shape) == 1 - else [None, (shape[axis],), shape]) - for returned in [False, True])) - def testAverage(self, shape, dtype, axis, weights_shape, returned): - rng = jtu.rand_default(self.rng()) - if weights_shape is None: - np_fun = lambda x: np.average(x, axis, returned=returned) - bm_fun = lambda x: bm.average(x, axis, returned=returned) - args_maker = lambda: [rng(shape, dtype)] - else: - np_fun = lambda x, weights: np.average(x, axis, weights, returned) - bm_fun = lambda x, weights: bm.average(x, axis, weights, 
returned) - args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)] - np_fun = _promote_like_jnp(np_fun, inexact=True) - tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5, - np.float64: 1e-12, np.complex64: 1e-5} - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - try: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - check_dtypes=check_dtypes, tol=tol) - except ZeroDivisionError: - self.skipTest("don't support checking for ZeroDivisionError") - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=check_dtypes, - rtol=tol, atol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": - f"_arg{i}_ndmin={ndmin}_dtype={np.dtype(dtype) if dtype else None}", - "arg": arg, "ndmin": ndmin, "dtype": dtype} - for i, (arg, dtypes) in enumerate([ - ([True, False, True], all_dtypes), - (3., all_dtypes), - ([1, 2, 3], all_dtypes), - (np.array([1, 2, 3], dtype=np.int64), all_dtypes), - ([1., 2., 3.], all_dtypes), - ([[1, 2], [3, 4], [5, 6]], all_dtypes), - ([[1, 2.], [3, 4], [5, 6]], all_dtypes), - ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes), - ([[3, np.array(2, dtype=bm.float_), 1], - np.arange(3., dtype=bm.float_)], all_dtypes), - ]) - for dtype in [None] + dtypes - for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2])) - def testArray(self, arg, ndmin, dtype): - args_maker = lambda: [arg] - canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype) - if ndmin is not None: - np_fun = partial(np.array, ndmin=ndmin, dtype=canonical_dtype) - bm_fun = partial(bm.array, ndmin=ndmin, dtype=dtype) - else: - np_fun = partial(np.array, dtype=canonical_dtype) - bm_fun = partial(bm.array, dtype=dtype) - - # We are testing correct canonicalization behavior here, so we turn off the - # permissive canonicalization logic in the test harness. 
- self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - canonicalize_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") - def testArrayDtypeInference(self): - def _check(obj, out_dtype, weak_type): - dtype_reference = np.array(obj, dtype=out_dtype) - - out = bm_func(bm.array)(obj) - self.assertDtypesMatch(out, dtype_reference) - self.assertEqual(dtypes.is_weakly_typed(out), weak_type) - - out_jit = jax.jit(bm_func(bm.array))(obj) - self.assertDtypesMatch(out_jit, dtype_reference) - self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type) - - # Python scalars become 64-bit weak types. - _check(1, np.int64, True) - _check(1.0, np.float64, True) - _check(1.0j, np.complex128, True) - - # Lists become strongly-typed defaults. - _check([1], bm.int_, False) - _check([1.0], bm.float_, False) - _check([1.0j], bm.complex_, False) - - # Lists of weakly-typed objects become strongly-typed defaults. - _check([bm.array(1).value], bm.int_, False) - _check([bm.array(1.0).value], bm.float_, False) - _check([bm.array(1.0j).value], bm.complex_, False) - - # Lists of strongly-typed objects maintain their strong type. - _check([bm.int64(1)], np.int64, False) - _check([bm.float64(1)], np.float64, False) - _check([bm.complex128(1)], np.complex128, False) - - # Mixed inputs use JAX-style promotion. 
- # (regression test for https://github.com/google/jax/issues/8945) - _check([0, np.int16(1)], np.int16, False) - _check([0.0, np.float16(1)], np.float16, False) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": f"_dtype={np.dtype(dtype)}_func={func}", - "dtype": dtype, "func": func} - for dtype in all_dtypes - for func in ["array", "copy"])) - def testArrayCopy(self, dtype, func): - x = bm_func(bm.ones)(10, dtype=dtype) - copy_func = getattr(bm, func) - - x_view = bm_func(bm.asarray)(x) - x_view_jit = jax.jit(bm_func(bm.asarray))(x) - x_copy = bm_func(copy_func)(x) - x_copy_jit = jax.jit(bm_func(copy_func))(x) - - _ptr = lambda x: x.device_buffer.unsafe_buffer_pointer() - - self.assertEqual(_ptr(x), _ptr(x_view)) - self.assertEqual(_ptr(x), _ptr(x_view_jit)) - self.assertNotEqual(_ptr(x), _ptr(x_copy)) - self.assertNotEqual(_ptr(x), _ptr(x_copy_jit)) - - x.delete() - - self.assertTrue(x_view.is_deleted()) - self.assertTrue(x_view_jit.is_deleted()) - - self.assertFalse(x_copy.is_deleted()) - self.assertFalse(x_copy_jit.is_deleted()) - - def testArrayCopyAutodiff(self): - f = lambda x: jnp.array(x, copy=True) - - x = jnp.ones(10) - xdot = jnp.ones(10) - y, ydot = jax.jvp(f, (x,), (xdot,)) - self.assertIsNot(x, y) - self.assertIsNot(xdot, ydot) - - ybar = jnp.ones(10) - y, f_vjp = jax.vjp(f, x) - xbar, = f_vjp(ybar) - self.assertIsNot(x, y) - self.assertIsNot(xbar, ybar) - - def testArrayCopyVmap(self): - f = lambda x: jnp.array(x, copy=True) - x = jnp.ones(10) - y = jax.vmap(f)(x) - self.assertIsNot(x, y) - - def testArrayUnsupportedDtypeError(self): - with self.assertRaisesRegex(TypeError, - "JAX only supports number and bool dtypes.*"): - bm.array(3, [('a',' 0.: - return x * 2 - else: - return x + 2 - - self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.)) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": 
shape, "dtype": dtype, "axis": axis} - for shape in [(3,), (2, 3)] - for dtype in default_dtypes - for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))] # Test negative axes and tuples + for out_shape in s([None] + array_shapes) + for in_dtype in s(default_dtypes) + for fill_value_dtype in s(default_dtypes) + for fill_value_shape in s(_compatible_shapes(shape if out_shape is None else out_shape)) + for out_dtype in s(default_dtypes)))) + def testFullLike(self, shape, in_dtype, fill_value_dtype, fill_value_shape, out_dtype, out_shape): + if numpy_version < (1, 19) and out_shape == (): + raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None") + rng = jtu.rand_default(self.rng()) + np_fun = lambda x, fill_value: np.full_like( + x, fill_value, dtype=out_dtype, shape=out_shape) + bm_fun = lambda x, fill_value: bm.full_like( + x, fill_value, dtype=out_dtype, shape=out_shape) + args_maker = lambda: [rng(shape, in_dtype), rng(fill_value_shape, fill_value_dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_func={}_inshape={}_outshape={}_outdtype={}".format( + func, jtu.format_shape_dtype_string(shape, in_dtype), + out_shape, out_dtype), + "func": func, "shape": shape, "in_dtype": in_dtype, + "out_shape": out_shape, "out_dtype": out_dtype} + for shape in array_shapes + for out_shape in [None] + array_shapes + for in_dtype in default_dtypes + for func in ["ones_like", "zeros_like"] + for out_dtype in default_dtypes)) + def testZerosOnesLike(self, func, shape, in_dtype, out_shape, out_dtype): + if numpy_version < (1, 19) and out_shape == (): + raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None") + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: getattr(np, func)(x, dtype=out_dtype, shape=out_shape) + bm_fun = lambda x: getattr(bm, func)(x, dtype=out_dtype, 
shape=out_shape) + args_maker = lambda: [rng(shape, in_dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_func={}_inshape={}_weak_type={}_outshape={}_outdtype={}".format( + func, jtu.format_shape_dtype_string(shape, in_dtype), + weak_type, out_shape, out_dtype), + "func": func, "args": args, + "shape": shape, "in_dtype": in_dtype, "weak_type": weak_type, + "out_shape": out_shape, "out_dtype": out_dtype} + for shape in array_shapes + for in_dtype in [np.int32, np.float32, np.complex64] + for weak_type in [True, False] + for out_shape in [None, (), (10,)] + for func, args in [("full_like", (-100,)), ("ones_like", ()), ("zeros_like", ())] + for out_dtype in [None, float])) + def testZerosOnesFullLikeWeakType(self, func, args, shape, in_dtype, weak_type, out_shape, out_dtype): + if numpy_version < (1, 19) and out_shape == (): + raise SkipTest("Numpy < 1.19 treats out_shape=() like out_shape=None") + rng = jtu.rand_default(self.rng()) + x = lax_internal._convert_element_type(rng(shape, in_dtype), + weak_type=weak_type) + fun = lambda x: getattr(bm, func)(x, *args, dtype=out_dtype, shape=out_shape) + expected_weak_type = weak_type and (out_dtype is None) + self.assertEqual(dtypes.is_weakly_typed(bm_func(fun)(x)), expected_weak_type) + self.assertEqual(dtypes.is_weakly_typed(jax.jit(bm_func(fun))(x)), expected_weak_type) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_funcname={}_input_type={}_val={}_dtype={}".format( + funcname, input_type, val, dtype), + "funcname": funcname, "input_type": input_type, "val": val, "dtype": dtype} + for funcname in ["array", "asarray"] + for dtype in [int, float, None] + for val in [0, 1] + for input_type in [int, float, np.int32, np.float32])) + def testArrayWeakType(self, funcname, input_type, val, dtype): + bm_fun = lambda x: getattr(bm, funcname)(x, 
dtype=dtype) + fjit = jax.jit(bm_func(bm_fun)) + val = input_type(val) + expected_weak_type = dtype is None and input_type in set(dtypes._weak_types) + self.assertEqual(dtypes.is_weakly_typed(bm_func(bm_fun)(val)), expected_weak_type) + self.assertEqual(dtypes.is_weakly_typed(fjit(val)), expected_weak_type) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_weak_type={}_slc={}".format( + jtu.format_shape_dtype_string(shape, dtype), weak_type, slc), + "shape": shape, "dtype": dtype, "weak_type": weak_type, "slc": slc} + for shape in nonempty_nonscalar_array_shapes + for dtype in [int, float, complex] + for weak_type in [True, False] + for slc in [slice(None), slice(0), slice(3), 0, ...])) + def testSliceWeakTypes(self, shape, dtype, weak_type, slc): + rng = jtu.rand_default(self.rng()) + x = lax_internal._convert_element_type(rng(shape, dtype), + weak_type=weak_type) + op = lambda x: x[slc] + self.assertEqual(op(x).aval.weak_type, weak_type) + self.assertEqual(jax.jit(op)(x).aval.weak_type, weak_type) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_{}sections".format( + jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), + "shape": shape, "num_sections": num_sections, "axis": axis, + "dtype": dtype} + for shape, axis, num_sections in [ + ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), + ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] + for dtype in default_dtypes)) + def testSplitStaticInt(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.split(x, num_sections, axis=axis) + bm_fun = lambda x: bm.split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_{}sections".format( + 
jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), + "shape": shape, "num_sections": num_sections, "axis": axis, "dtype": dtype} + # All testcases split the specified axis unequally + for shape, axis, num_sections in [ + ((3,), 0, 2), ((12,), 0, 5), ((12, 4), 0, 7), ((12, 4), 1, 3), + ((2, 3, 5), -1, 2), ((2, 4, 4), -2, 3), ((7, 2, 2), 0, 3)] + for dtype in default_dtypes)) + def testArraySplitStaticInt(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.array_split(x, num_sections, axis=axis) + bm_fun = lambda x: bm.array_split(x, num_sections, axis=axis) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testSplitTypeError(self): + # If we pass an ndarray for indices_or_sections -> no error + self.assertEqual(3, len(bm_func(bm.split)(bm.zeros(3), bm.array([1, 2])))) + + CONCRETIZATION_MSG = "Abstract tracer value encountered where concrete value is expected." + with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): + # An abstract tracer for idx + jax.jit(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), idx))(2.) + with self.assertRaisesRegex(TypeError, CONCRETIZATION_MSG): + # A list including an abstract tracer + jax.jit(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), [2, idx]))(2.) 
+ + # A concrete tracer -> no error + jax.jvp(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), idx), + (2.,), (1.,)) + # A tuple including a concrete tracer -> no error + jax.jvp(lambda idx: bm_func(bm.split)(bm.zeros((12, 2)), (1, idx)), + (2.,), (1.,)) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_bins={}_range={}_weights={}".format( + jtu.format_shape_dtype_string(shape, dtype), bins, range, weights), + "shape": shape, + "dtype": dtype, + "bins": bins, + "range": range, + "weights": weights, + } + for shape in [(5,), (5, 5)] + for dtype in number_dtypes + for bins in [10, np.arange(-5, 6), np.array([-5, 0, 3])] + for range in [None, (0, 0), (0, 10)] + for weights in [True, False] + )) + def testHistogramBinEdges(self, shape, dtype, bins, range, weights): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = lambda a, w, r: np.histogram_bin_edges(a, bins=bins, range=r, + weights=_weights(w)) + bm_fun = lambda a, w, r: bm.histogram_bin_edges(a, bins=bins, range=r, + weights=_weights(w)) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), range] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-2} + # linspace() compares poorly to numpy when using bfloat16 + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, + atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_bins={}_density={}_weights={}".format( + jtu.format_shape_dtype_string(shape, dtype), bins, density, weights), + "shape": shape, + "dtype": dtype, + "bins": bins, + "density": density, + "weights": weights, + } + for shape in [(5,), (5, 5)] + for dtype in default_dtypes + # We only test explicit integer-valued bin edges because in other cases + # rounding errors lead to flaky tests. 
+ for bins in [np.arange(-5, 6), np.array([-5, 0, 3])] + for density in [True, False] + for weights in [True, False] + )) + def testHistogram(self, shape, dtype, bins, density, weights): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = lambda a, w: np.histogram(a, bins=bins, density=density, + weights=_weights(w)) + bm_fun = lambda a, w: bm.histogram(a, bins=bins, density=density, + weights=_weights(w)) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + # np.searchsorted errors on bfloat16 with + # "TypeError: invalid type promotion with custom data type" + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_bins={}_weights={}_density={}_range={}".format( + jtu.format_shape_dtype_string(shape, dtype), bins, weights, density, range), + "shape": shape, "dtype": dtype, "bins": bins, "weights": weights, "density": density, "range": range, + } + for shape in [(5,), (12,)] + for dtype in int_dtypes + for bins in [2, [2, 2], [np.array([0, 1, 3, 5]), np.array([0, 2, 3, 4, 6])]] + for weights in [False, True] + for density in [False, True] + for range in [None, [(-1, 1), None], [(-1, 1), (-2, 2)]] + )) + def testHistogram2d(self, shape, dtype, bins, weights, density, range): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( + lambda a, b, w: np.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range)) + bm_fun = lambda a, b, w: bm.histogram2d(a, b, bins=bins, weights=_weights(w), density=density, range=range) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype), rng(shape, dtype)] + tol = {jnp.bfloat16: 2E-2, 
np.float16: 1E-1} + # np.searchsorted errors on bfloat16 with + # "TypeError: invalid type promotion with custom data type" + with np.errstate(divide='ignore', invalid='ignore'): + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_bins={}_weights={}_density={}_range={}".format( + jtu.format_shape_dtype_string(shape, dtype), bins, weights, density, range), + "shape": shape, "dtype": dtype, "bins": bins, "weights": weights, "density": density, "range": range, + } + for shape in [(5, 3), (10, 3)] + for dtype in int_dtypes + for bins in [(2, 2, 2), [np.array([-5, 0, 4]), np.array([-4, -1, 2]), np.array([-6, -1, 4])]] + for weights in [False, True] + for density in [False, True] + for range in [None, [(-1, 1), None, None], [(-1, 1), (-2, 2), (-3, 3)]] + )) + def testHistogramdd(self, shape, dtype, bins, weights, density, range): + rng = jtu.rand_default(self.rng()) + _weights = lambda w: abs(w) if weights else None + np_fun = jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*")( + lambda a, w: np.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range)) + bm_fun = lambda a, w: jnp.histogramdd(a, bins=bins, weights=_weights(w), density=density, range=range) + args_maker = lambda: [rng(shape, dtype), rng((shape[0],), dtype)] + tol = {jnp.bfloat16: 2E-2, np.float16: 1E-1} + # np.searchsorted errors on bfloat16 with + # "TypeError: invalid type promotion with custom data type" + if dtype != jnp.bfloat16: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_{}sections".format( + jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), + "shape": 
shape, "num_sections": num_sections, "axis": axis, + "dtype": dtype} + for shape, axis, num_sections in [ + ((12, 4), 0, 4), ((12, 4), 1, 2), + ((2, 3, 4), 2, 2), ((4, 3, 4), 0, 2)] + for dtype in default_dtypes)) + def testHVDSplit(self, shape, num_sections, axis, dtype): + rng = jtu.rand_default(self.rng()) + + def fn(module, axis): + if axis == 0: + return module.vsplit + elif axis == 1: + return module.hsplit + else: + assert axis == 2 + return module.dsplit + + np_fun = lambda x: fn(np, axis)(x, num_sections) + bm_fun = lambda x: fn(bm, axis)(x, num_sections) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outshape={}_order={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype), + order), + "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, + "order": order} + for dtype in default_dtypes + for order in ["C", "F"] + for arg_shape, out_shape in [ + (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)), + ((), (1, 1, 1)), + ((7, 0), (0, 42, 101)), + ((3, 4), 12), + ((3, 4), (12,)), + ((3, 4), -1), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)) + ])) + def testReshape(self, arg_shape, out_shape, dtype, order): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.reshape(x, out_shape, order=order) + bm_fun = lambda x: bm.reshape(x, out_shape, order=order) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outshape={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype)), + "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype} + for dtype in 
default_dtypes + for arg_shape, out_shape in [ + ((7, 0), (0, 42, 101)), + ((2, 1, 4), (-1,)), + ((2, 2, 4), (2, 8)) + ])) + def testReshapeMethod(self, arg_shape, out_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.reshape(x, out_shape) + bm_fun = lambda x: bm.reshape(x, out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_outshape={}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), + jtu.format_shape_dtype_string(out_shape, dtype)), + "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype} + for dtype in default_dtypes + for arg_shape, out_shape in itertools.product(all_shapes, array_shapes))) + def testResize(self, arg_shape, out_shape, dtype): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.resize(x, out_shape) + bm_fun = lambda x: bm.resize(x, out_shape) + args_maker = lambda: [rng(arg_shape, dtype)] + if len(out_shape) > 0 or numpy_version >= (1, 20, 0): + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_expanddim={!r}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), dim), + "arg_shape": arg_shape, "dtype": dtype, "dim": dim} + for arg_shape in [(), (3,), (3, 4)] + for dtype in default_dtypes + for dim in (list(range(-len(arg_shape) + 1, len(arg_shape))) + + [np.array(0), np.array(-1), (0,), [np.array(0)], + (len(arg_shape), len(arg_shape) + 1)]))) + def testExpandDimsStaticDim(self, arg_shape, dtype, dim): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.expand_dims(x, dim) + bm_fun = lambda x: bm.expand_dims(x, dim) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CompileAndCheck(bm_func(bm_fun), args_maker) + 
self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + + def testExpandDimsRepeatedAxisError(self): + x = bm.ones((2, 3)) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: bm.expand_dims(x, [1, 1])) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: bm.expand_dims(x, [3, -1])) + + # ensure this is numpy's behavior too, so that we remain consistent + x = np.ones((2, 3)) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: np.expand_dims(x, [1, 1])) + self.assertRaisesRegex( + ValueError, 'repeated axis.*', + lambda: np.expand_dims(x, [3, -1])) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_axes=({},{})".format( + jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2), + "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2} + for arg_shape, ax1, ax2 in [ + ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), + ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] + for dtype in default_dtypes)) + def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.swapaxes(x, ax1, ax2) + bm_fun = lambda x: bm.swapaxes(x, ax1, ax2) + args_maker = lambda: [rng(arg_shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_inshape={}_axis={!r}".format( + jtu.format_shape_dtype_string(arg_shape, dtype), ax), + "arg_shape": arg_shape, "dtype": dtype, "ax": ax} + for arg_shape, ax in [ + ((3, 1), None), + ((3, 1), 1), + ((3, 1), -1), + ((3, 1), np.array(1)), + ((1, 3, 1), (0, 2)), + ((1, 3, 1), (0,)), + ((1, 4, 1), (np.array(0),))] + for dtype in default_dtypes)) + def testSqueeze(self, arg_shape, dtype, ax): + rng = jtu.rand_default(self.rng()) + np_fun = lambda x: np.squeeze(x, ax) + bm_fun = lambda x: bm.squeeze(x, ax) + args_maker = lambda: [rng(arg_shape, dtype)] + 
self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_axis={}_weights={}_returned={}".format( + jtu.format_shape_dtype_string(shape, dtype), + axis, + (None if weights_shape is None else jtu.format_shape_dtype_string(weights_shape, dtype)), + returned), + "shape": shape, "dtype": dtype, "axis": axis, + "weights_shape": weights_shape, "returned": returned} + for shape, dtype in _shape_and_dtypes(nonempty_shapes, number_dtypes) + for axis in list(range(-len(shape), len(shape))) + [None] + # `weights_shape` is either `None`, same as the averaged axis, or same as + # that of the input + for weights_shape in ([None, shape] if axis is None or len(shape) == 1 + else [None, (shape[axis],), shape]) + for returned in [False, True])) + def testAverage(self, shape, dtype, axis, weights_shape, returned): + rng = jtu.rand_default(self.rng()) + if weights_shape is None: + np_fun = lambda x: np.average(x, axis, returned=returned) + bm_fun = lambda x: bm.average(x, axis, returned=returned) + args_maker = lambda: [rng(shape, dtype)] + else: + np_fun = lambda x, weights: np.average(x, axis, weights, returned) + bm_fun = lambda x, weights: bm.average(x, axis, weights, returned) + args_maker = lambda: [rng(shape, dtype), rng(weights_shape, dtype)] + np_fun = _promote_like_jnp(np_fun, inexact=True) + tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-2, np.float32: 1e-5, + np.float64: 1e-12, np.complex64: 1e-5} + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + try: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + check_dtypes=check_dtypes, tol=tol) + except ZeroDivisionError: + self.skipTest("don't support checking for ZeroDivisionError") + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=check_dtypes, + rtol=tol, atol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": + 
f"_arg{i}_ndmin={ndmin}_dtype={np.dtype(dtype) if dtype else None}", + "arg": arg, "ndmin": ndmin, "dtype": dtype} + for i, (arg, dtypes) in enumerate([ + ([True, False, True], all_dtypes), + (3., all_dtypes), + ([1, 2, 3], all_dtypes), + (np.array([1, 2, 3], dtype=np.int64), all_dtypes), + ([1., 2., 3.], all_dtypes), + ([[1, 2], [3, 4], [5, 6]], all_dtypes), + ([[1, 2.], [3, 4], [5, 6]], all_dtypes), + ([[1., 2j], [3., 4.], [5., 6.]], complex_dtypes), + ([[3, np.array(2, dtype=bm.float_), 1], + np.arange(3., dtype=bm.float_)], all_dtypes), + ]) + for dtype in [None] + dtypes + for ndmin in [None, np.ndim(arg), np.ndim(arg) + 1, np.ndim(arg) + 2])) + def testArray(self, arg, ndmin, dtype): + args_maker = lambda: [arg] + canonical_dtype = dtypes.canonicalize_dtype(dtype or np.array(arg).dtype) + if ndmin is not None: + np_fun = partial(np.array, ndmin=ndmin, dtype=canonical_dtype) + bm_fun = partial(bm.array, ndmin=ndmin, dtype=dtype) + else: + np_fun = partial(np.array, dtype=canonical_dtype) + bm_fun = partial(bm.array, dtype=dtype) + + # We are testing correct canonicalization behavior here, so we turn off the + # permissive canonicalization logic in the test harness. + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + canonicalize_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*") + def testArrayDtypeInference(self): + def _check(obj, out_dtype, weak_type): + dtype_reference = np.array(obj, dtype=out_dtype) + + out = bm_func(bm.array)(obj) + self.assertDtypesMatch(out, dtype_reference) + self.assertEqual(dtypes.is_weakly_typed(out), weak_type) + + out_jit = jax.jit(bm_func(bm.array))(obj) + self.assertDtypesMatch(out_jit, dtype_reference) + self.assertEqual(dtypes.is_weakly_typed(out_jit), weak_type) + + # Python scalars become 64-bit weak types. 
+ _check(1, np.int64, True) + _check(1.0, np.float64, True) + _check(1.0j, np.complex128, True) + + # Lists become strongly-typed defaults. + _check([1], bm.int_, False) + _check([1.0], bm.float_, False) + _check([1.0j], bm.complex_, False) + + # Lists of weakly-typed objects become strongly-typed defaults. + _check([bm.array(1).value], bm.int_, False) + _check([bm.array(1.0).value], bm.float_, False) + _check([bm.array(1.0j).value], bm.complex_, False) + + # Lists of strongly-typed objects maintain their strong type. + _check([bm.int64(1)], np.int64, False) + _check([bm.float64(1)], np.float64, False) + _check([bm.complex128(1)], np.complex128, False) + + # Mixed inputs use JAX-style promotion. + # (regression test for https://github.com/google/jax/issues/8945) + _check([0, np.int16(1)], np.int16, False) + _check([0.0, np.float16(1)], np.float16, False) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": f"_dtype={np.dtype(dtype)}_func={func}", + "dtype": dtype, "func": func} + for dtype in all_dtypes + for func in ["array", "copy"])) + def testArrayCopy(self, dtype, func): + x = bm_func(bm.ones)(10, dtype=dtype) + copy_func = getattr(bm, func) + + x_view = bm_func(bm.asarray)(x) + x_view_jit = jax.jit(bm_func(bm.asarray))(x) + x_copy = bm_func(copy_func)(x) + x_copy_jit = jax.jit(bm_func(copy_func))(x) + + _ptr = lambda x: x.device_buffer.unsafe_buffer_pointer() + + self.assertEqual(_ptr(x), _ptr(x_view)) + self.assertEqual(_ptr(x), _ptr(x_view_jit)) + self.assertNotEqual(_ptr(x), _ptr(x_copy)) + self.assertNotEqual(_ptr(x), _ptr(x_copy_jit)) + + x.delete() + + self.assertTrue(x_view.is_deleted()) + self.assertTrue(x_view_jit.is_deleted()) + + self.assertFalse(x_copy.is_deleted()) + self.assertFalse(x_copy_jit.is_deleted()) + + def testArrayCopyAutodiff(self): + f = lambda x: jnp.array(x, copy=True) + + x = jnp.ones(10) + xdot = jnp.ones(10) + y, ydot = jax.jvp(f, (x,), (xdot,)) + self.assertIsNot(x, y) + self.assertIsNot(xdot, ydot) + + 
ybar = jnp.ones(10) + y, f_vjp = jax.vjp(f, x) + xbar, = f_vjp(ybar) + self.assertIsNot(x, y) + self.assertIsNot(xbar, ybar) + + def testArrayCopyVmap(self): + f = lambda x: jnp.array(x, copy=True) + x = jnp.ones(10) + y = jax.vmap(f)(x) + self.assertIsNot(x, y) + + def testArrayUnsupportedDtypeError(self): + with self.assertRaisesRegex(TypeError, + "JAX only supports number and bool dtypes.*"): + bm.array(3, [('a', ' 0.: + return x * 2 + else: + return x + 2 + + self.assertRaises(jax.errors.ConcretizationTypeError, lambda: g(3.)) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for shape in [(3,), (2, 3)] + for dtype in default_dtypes + for axis in list(range(-len(shape), len(shape))) + [None] + [tuple(range(len(shape)))] + # Test negative axes and tuples )) - def testFlip(self, shape, dtype, axis): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - bm_op = lambda x: bm.flip(x, axis) - np_op = lambda x: np.flip(x, axis) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in [(3,), (2, 3), (3, 2, 4)] - for dtype in default_dtypes)) - def testFlipud(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - bm_op = lambda x: bm.flipud(x) - np_op = lambda x: np.flipud(x) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in [(3, 
2), (2, 3), (3, 2, 4)] - for dtype in default_dtypes)) - def testFliplr(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - bm_op = lambda x: bm.fliplr(x) - np_op = lambda x: np.fliplr(x) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_k={}_axes={}".format( - jtu.format_shape_dtype_string(shape, dtype), k, axes), - "shape": shape, "dtype": dtype, "k": k, "axes": axes} - for shape, axes in [ - [(2, 3), (0, 1)], - [(2, 3), (1, 0)], - [(4, 3, 2), (0, 2)], - [(4, 3, 2), (2, 1)], - ] - for k in range(-3, 4) - for dtype in default_dtypes)) - def testRot90(self, shape, dtype, k, axes): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - bm_op = lambda x: bm.rot90(x, k, axes) - np_op = lambda x: np.rot90(x, k, axes) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - # TODO(mattjj): test infix operator overrides - - def testRavel(self): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - self._CompileAndCheck(lambda x: x.ravel(), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_order={}_mode={}".format( - shape, order, mode), - "shape": shape, "order": order, "mode": mode} - for shape in nonempty_nonscalar_array_shapes - for order in ['C', 'F'] - for mode in ['wrap', 'clip', 'raise'])) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testRavelMultiIndex(self, shape, order, mode): - # generate indices in each dimension with a few out of bounds. - rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1) - for dim in shape] - # generate multi_indices of different dimensions that broadcast. 
- args_maker = lambda: [tuple(rng(ndim * (3,), bm.int_) - for ndim, rng in enumerate(rngs))] - def np_fun(x): - try: - return np.ravel_multi_index(x, shape, order=order, mode=mode) - except ValueError as err: - if str(err).startswith('invalid entry'): - # sentinel indicating expected error. - return -999 + def testFlip(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + bm_op = lambda x: bm.flip(x, axis) + np_op = lambda x: np.flip(x, axis) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in [(3,), (2, 3), (3, 2, 4)] + for dtype in default_dtypes)) + def testFlipud(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + bm_op = lambda x: bm.flipud(x) + np_op = lambda x: np.flipud(x) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in [(3, 2), (2, 3), (3, 2, 4)] + for dtype in default_dtypes)) + def testFliplr(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + bm_op = lambda x: bm.fliplr(x) + np_op = lambda x: np.fliplr(x) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_k={}_axes={}".format( + jtu.format_shape_dtype_string(shape, dtype), k, axes), + "shape": shape, "dtype": dtype, "k": k, "axes": axes} + for shape, axes in [ + [(2, 3), (0, 
1)], + [(2, 3), (1, 0)], + [(4, 3, 2), (0, 2)], + [(4, 3, 2), (2, 1)], + ] + for k in range(-3, 4) + for dtype in default_dtypes)) + def testRot90(self, shape, dtype, k, axes): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + bm_op = lambda x: bm.rot90(x, k, axes) + np_op = lambda x: np.rot90(x, k, axes) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + # TODO(mattjj): test infix operator overrides + + def testRavel(self): + rng = self.rng() + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + self._CompileAndCheck(lambda x: x.ravel(), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_order={}_mode={}".format( + shape, order, mode), + "shape": shape, "order": order, "mode": mode} + for shape in nonempty_nonscalar_array_shapes + for order in ['C', 'F'] + for mode in ['wrap', 'clip', 'raise'])) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testRavelMultiIndex(self, shape, order, mode): + # generate indices in each dimension with a few out of bounds. + rngs = [jtu.rand_int(self.rng(), low=-1, high=dim + 1) + for dim in shape] + # generate multi_indices of different dimensions that broadcast. + args_maker = lambda: [tuple(rng(ndim * (3,), bm.int_) + for ndim, rng in enumerate(rngs))] + + def np_fun(x): + try: + return np.ravel_multi_index(x, shape, order=order, mode=mode) + except ValueError as err: + if str(err).startswith('invalid entry'): + # sentinel indicating expected error. + return -999 + else: + raise + + def bm_fun(x): + try: + return bm.ravel_multi_index(x, shape, order=order, mode=mode) + except ValueError as err: + if str(err).startswith('invalid entry'): + # sentinel indicating expected error. 
+ return -999 + else: + raise + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + if mode == 'raise': + msg = ("The error occurred because ravel_multi_index was jit-compiled " + "with mode='raise'. Use mode='wrap' or mode='clip' instead.") + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg): + jax.jit(bm_fun)(*args_maker()) else: - raise - def bm_fun(x): - try: - return bm.ravel_multi_index(x, shape, order=order, mode=mode) - except ValueError as err: - if str(err).startswith('invalid entry'): - # sentinel indicating expected error. - return -999 + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_ashape={}{}_cshapes={}{}_mode={}".format( + adtype.__name__, ashape, cdtype.__name__, cshapes, mode), + "ashape": ashape, "adtype": adtype, "cshapes": cshapes, "cdtype": cdtype, "mode": mode} + for ashape in ((), (4,), (3, 4)) + for cshapes in [ + [(), (4,)], + [(3, 4), (4,), (3, 1)] + ] + for adtype in int_dtypes + for cdtype in default_dtypes + for mode in ['wrap', 'clip', 'raise'])) + def testChoose(self, ashape, adtype, cshapes, cdtype, mode): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]] + + def np_fun(a, c): + try: + return np.choose(a, c, mode=mode) + except ValueError as err: + if mode == 'raise' and str(err).startswith('invalid entry'): + return -999 # sentinel indicating expected error. + else: + raise + + def bm_fun(a, c): + try: + return bm.choose(a, c, mode=mode) + except ValueError as err: + if mode == 'raise' and str(err).startswith('invalid entry'): + return -999 # sentinel indicating expected error. + else: + raise + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + if mode == 'raise': + msg = ("The error occurred because jnp.choose was jit-compiled" + " with mode='raise'. 
Use mode='wrap' or mode='clip' instead.") + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg): + jax.jit(bm_fun)(*args_maker()) else: - raise - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - if mode == 'raise': - msg = ("The error occurred because ravel_multi_index was jit-compiled " - "with mode='raise'. Use mode='wrap' or mode='clip' instead.") - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg): - jax.jit(bm_fun)(*args_maker()) - else: - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_ashape={}{}_cshapes={}{}_mode={}".format( - adtype.__name__, ashape, cdtype.__name__, cshapes, mode), - "ashape": ashape, "adtype": adtype, "cshapes": cshapes, "cdtype": cdtype, "mode": mode} - for ashape in ((), (4,), (3, 4)) - for cshapes in [ - [(), (4,)], - [(3, 4), (4,), (3, 1)] - ] - for adtype in int_dtypes - for cdtype in default_dtypes - for mode in ['wrap', 'clip', 'raise'])) - def testChoose(self, ashape, adtype, cshapes, cdtype, mode): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(ashape, adtype), [rng(s, cdtype) for s in cshapes]] - def np_fun(a, c): - try: - return np.choose(a, c, mode=mode) - except ValueError as err: - if mode == 'raise' and str(err).startswith('invalid entry'): - return -999 # sentinel indicating expected error. 
+ self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): + def f(): + out = [rng(shape, dtype or jnp.float_) + for shape, dtype in zip(shapes, dtypes)] + if np_arrays: + return out + return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a + for a in out] + + return f + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}_idx={}".format(shape, + jtu.format_shape_dtype_string(idx_shape, dtype)), + "shape": shape, "idx_shape": idx_shape, "dtype": dtype} + for shape in nonempty_nonscalar_array_shapes + for dtype in int_dtypes + for idx_shape in all_shapes)) + def testUnravelIndex(self, shape, idx_shape, dtype): + size = prod(shape) + rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) + + def np_fun(index, shape): + # JAX's version outputs the same dtype as the input in the typical case + # where shape is weakly-typed. + out_dtype = index.dtype + # Adjust out-of-bounds behavior to match jax's documented behavior. 
+ index = np.clip(index, -size, size - 1) + index = np.where(index < 0, index + size, index) + return [i.astype(out_dtype) for i in np.unravel_index(index, shape)] + + bm_fun = bm.unravel_index + args_maker = lambda: [rng(idx_shape, dtype), shape] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testAstype(self): + rng = self.rng() + args_maker = lambda: [rng.randn(3, 4).astype("float32")] + np_op = lambda x: np.asarray(x).astype(bm.int32) + bm_op = lambda x: bm.asarray(x).astype(bm.int32) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + def testAstypeNone(self): + rng = self.rng() + args_maker = lambda: [rng.randn(3, 4).astype("int32")] + np_op = jtu.with_jax_dtype_defaults(lambda x: np.asarray(x).astype(None)) + bm_op = lambda x: bm.asarray(x).astype(None) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in array_shapes + for dtype in all_dtypes)) + def testNbytes(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + np_op = lambda x: np.asarray(x).nbytes + bm_op = lambda x: bm.asarray(x).value.nbytes + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in array_shapes + for dtype in all_dtypes)) + def testItemsize(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + np_op = lambda x: np.asarray(x).itemsize + bm_op = lambda x: bm.asarray(x).value.itemsize + args_maker = lambda: 
[rng(shape, dtype)] + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_dtype={}".format( + jtu.format_shape_dtype_string(shape, a_dtype), dtype), + "shape": shape, "a_dtype": a_dtype, "dtype": dtype} + for shape in [(8,), (3, 8)] # last dim = 8 to ensure shape compatibility + for a_dtype in (default_dtypes + unsigned_dtypes + bool_dtypes) + for dtype in (default_dtypes + unsigned_dtypes + bool_dtypes))) + def testView(self, shape, a_dtype, dtype): + if jtu.device_under_test() == 'tpu': + if bm.dtype(a_dtype).itemsize in [1, 2] or bm.dtype(dtype).itemsize in [1, 2]: + self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") + if not config.x64_enabled: + if bm.dtype(a_dtype).itemsize == 8 or bm.dtype(dtype).itemsize == 8: + self.skipTest("x64 types are disabled by jax_enable_x64") + rng = jtu.rand_fullrange(self.rng()) + args_maker = lambda: [rng(shape, a_dtype)] + np_op = lambda x: np.asarray(x).view(dtype) + bm_op = lambda x: bm.asarray(x).view(dtype) + # Above may produce signaling nans; ignore warnings from invalid values. 
+ with np.errstate(invalid='ignore'): + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + def testPathologicalFloats(self): + args_maker = lambda: [np.array([ + 0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf + 0b_1111_1111_1000_0000_0000_0000_0000_0000, # -inf + 0b_0111_1111_1100_0000_0000_0000_0000_0000, # qnan + 0b_1111_1111_1100_0000_0000_0000_0000_0000, # -qnan + 0b_0111_1111_1000_0000_0000_0000_0000_0001, # snan + 0b_1111_1111_1000_0000_0000_0000_0000_0001, # -snan + 0b_0111_1111_1000_0000_0000_1100_0000_0000, # nonstandard nan + 0b_1111_1111_1000_0000_0000_1100_0000_0000, # -nonstandard nan + 0b_0000_0000_0000_0000_0000_0000_0000_0000, # zero + 0b_1000_0000_0000_0000_0000_0000_0000_0000, # -zero + ], dtype='uint32')] + + np_op = lambda x: np.asarray(x).view('float32').view('uint32') + bm_op = lambda x: bm.asarray(x).view('float32').view('uint32') + + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + # TODO(mattjj): test other ndarray-like method overrides + + def testNpMean(self): + # from https://github.com/google/jax/issues/125 + x = bm.eye(3, dtype=float).value + 0. + ans = np.mean(x) + self.assertAllClose(ans, np.array(1. / 3), check_dtypes=False) + + def testArangeOnFloats(self): + np_arange = jtu.with_jax_dtype_defaults(np.arange) + # from https://github.com/google/jax/issues/145 + self.assertAllClose(np_arange(0.0, 1.0, 0.1), + bm.arange(0.0, 1.0, 0.1).value) + # from https://github.com/google/jax/issues/3450 + self.assertAllClose(np_arange(2.5), + bm.arange(2.5).value) + self.assertAllClose(np_arange(0., 2.5), + bm.arange(0., 2.5).value) + + def testArangeTypes(self): + # Test that arange() output type is equal to the default types. 
+ int_ = dtypes.canonicalize_dtype(bm.int_) + float_ = dtypes.canonicalize_dtype(bm.float_) + + self.assertEqual(bm.arange(10).value.dtype, int_) + self.assertEqual(bm.arange(10.).value.dtype, float_) + self.assertEqual(bm.arange(10, dtype='uint16').value.dtype, np.uint16) + self.assertEqual(bm.arange(10, dtype='bfloat16').value.dtype, jnp.bfloat16) + + self.assertEqual(bm.arange(0, 10, 1).value.dtype, int_) + self.assertEqual(bm.arange(0, 10, 1.).value.dtype, float_) + self.assertEqual(bm.arange(0., 10, 1).value.dtype, float_) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for dtype in all_dtypes + for shape in nonzerodim_shapes + for axis in (None, *range(len(shape))))) + def testSort(self, dtype, shape, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + bm_fun = bm.sort + np_fun = np.sort + if axis is not None: + bm_fun = partial(bm_fun, axis=axis) + np_fun = partial(np_fun, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for dtype in all_dtypes + for shape in one_dim_array_shapes + for axis in [None])) + def testSortComplex(self, dtype, shape, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np.sort_complex, bm_func(bm.sort_complex), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm.sort_complex), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_input_type={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), + input_type.__name__, axis), + "shape": 
shape, "dtype": dtype, "input_type": input_type, "axis": axis} + for dtype in all_dtypes + for shape in nonempty_nonscalar_array_shapes + for input_type in [np.array, tuple] + for axis in (-1, *range(len(shape) - 1)))) + def testLexsort(self, dtype, shape, input_type, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [input_type(rng(shape, dtype))] + bm_op = lambda x: bm.lexsort(x, axis=axis) + np_op = jtu.with_jax_dtype_defaults(lambda x: np.lexsort(x, axis=axis)) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "shape": shape, "dtype": dtype, "axis": axis} + for dtype in all_dtypes + for shape in nonzerodim_shapes + for axis in (None, *range(len(shape))))) + def testArgsort(self, dtype, shape, axis): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + bm_fun = bm.argsort + np_fun = jtu.with_jax_dtype_defaults(np.argsort) + if axis is not None: + bm_fun = partial(bm_fun, axis=axis) + np_fun = partial(np_fun, axis=axis) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for dtype in all_dtypes + for shape in nonzerodim_shapes)) + def testMsort(self, dtype, shape): + rng = jtu.rand_some_equal(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np.msort, bm_func(bm.msort), args_maker) + self._CompileAndCheck(bm_func(bm.msort), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_shifts={}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), + shifts, axis), + "shape": shape, "dtype": 
dtype, "shifts": shifts, "axis": axis} + for dtype in all_dtypes + for shape in [(3, 4), (3, 4, 5), (7, 4, 0)] + for shifts, axis in [ + (3, None), + (1, 1), + ((3,), (0,)), + ((-2,), (-2,)), + ((1, 2), (0, -1)), + ((4, 2, 5, 5, 2, 4), None), + (100, None), + ])) + def testRoll(self, shape, dtype, shifts, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), np.array(shifts)] + bm_op = partial(bm.roll, axis=axis) + np_op = partial(np.roll, axis=axis) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_start={}".format( + jtu.format_shape_dtype_string(shape, dtype), + axis, start), + "shape": shape, "dtype": dtype, "axis": axis, + "start": start} + for dtype in all_dtypes + for shape in [(1, 2, 3, 4)] + for axis in [-3, 0, 2, 3] + for start in [-4, -1, 2, 4])) + def testRollaxis(self, shape, dtype, start, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + bm_op = partial(bm.rollaxis, axis=axis, start=start) + np_op = partial(np.rollaxis, axis=axis, start=start) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_bitorder={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, bitorder), + "shape": shape, "dtype": dtype, "axis": axis, + "bitorder": bitorder} + for dtype in [np.uint8, np.bool_] + for bitorder in ['big', 'little'] + for shape in [(1, 2, 3, 4)] + for axis in [None, 0, 1, -2, -1])) + def testPackbits(self, shape, dtype, axis, bitorder): + rng = jtu.rand_some_zero(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + bm_op = partial(bm.packbits, axis=axis, bitorder=bitorder) + np_op = partial(np.packbits, axis=axis, bitorder=bitorder) + self._CheckAgainstNumpy(np_op, 
bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_bitorder={}_count={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, bitorder, count), + "shape": shape, "dtype": dtype, "axis": axis, "bitorder": bitorder, + "count": count} + for dtype in [np.uint8] + for bitorder in ['big', 'little'] + for shape in [(1, 2, 3, 4)] + for axis in [None, 0, 1, -2, -1] + for count in [None, 20])) + def testUnpackbits(self, shape, dtype, axis, bitorder, count): + rng = jtu.rand_int(self.rng(), 0, 256) + args_maker = lambda: [rng(shape, dtype)] + bm_op = partial(bm.unpackbits, axis=axis, bitorder=bitorder) + np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): + def f(): + out = [rng(shape, dtype or jnp.float_) + for shape, dtype in zip(shapes, dtypes)] + if np_arrays: + return out + return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a + for a in out] + + return f + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_index={}_axis={}_mode={}".format( + jtu.format_shape_dtype_string(shape, dtype), + jtu.format_shape_dtype_string(index_shape, index_dtype), + axis, mode), + "shape": shape, "index_shape": index_shape, "dtype": dtype, + "index_dtype": index_dtype, "axis": axis, "mode": mode} + for shape in [(3,), (3, 4), (3, 4, 5)] + for index_shape in scalar_shapes + [(3,), (2, 1, 3)] + for axis in itertools.chain(range(-len(shape), len(shape)), + [cast(Optional[int], None)]) + for dtype in all_dtypes + for index_dtype in int_dtypes + for mode in [None, 'wrap', 'clip'])) + def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode): + def args_maker(): + x = rng(shape, dtype) + i = rng_indices(index_shape, 
index_dtype) + return x, i + + rng = jtu.rand_default(self.rng()) + if mode is None: + rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0]) else: - raise - def bm_fun(a, c): - try: - return bm.choose(a, c, mode=mode) - except ValueError as err: - if mode == 'raise' and str(err).startswith('invalid entry'): - return -999 # sentinel indicating expected error. + rng_indices = jtu.rand_int(self.rng(), -5, 5) + bm_op = lambda x, i: bm.take(x, i, axis=axis, mode=mode) + np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + def testTakeEmpty(self): + np.testing.assert_array_equal( + bm.array([], dtype=jnp.float32).value, + bm.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32)).value) + + np.testing.assert_array_equal( + bm.ones((2, 0, 4), dtype=bm.float32).value, + bm.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32), + axis=1).value) + + with self.assertRaisesRegex(IndexError, "non-empty jnp.take"): + bm.take(jnp.ones((2, 0, 4), dtype=jnp.float32), + jnp.array([0], jnp.int32), axis=1) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_index={}_axis={}".format( + jtu.format_shape_dtype_string(x_shape, dtype), + jtu.format_shape_dtype_string(i_shape, index_dtype), axis), + "x_shape": x_shape, "i_shape": i_shape, "dtype": dtype, + "index_dtype": index_dtype, "axis": axis} + for x_shape, i_shape in filter( + _shapes_are_equal_length, + filter(_shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2))) + for axis in itertools.chain(range(len(x_shape)), [-1], + [cast(Optional[int], None)]) + for dtype in default_dtypes + for index_dtype in int_dtypes)) + def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis): + rng = jtu.rand_default(self.rng()) + + i_shape = np.array(i_shape) + if axis is None: + i_shape = 
[np.prod(i_shape, dtype=np.int64)] else: - raise - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - if mode == 'raise': - msg = ("The error occurred because jnp.choose was jit-compiled" - " with mode='raise'. Use mode='wrap' or mode='clip' instead.") - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg): - jax.jit(bm_fun)(*args_maker()) - else: - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): - def f(): - out = [rng(shape, dtype or jnp.float_) - for shape, dtype in zip(shapes, dtypes)] - if np_arrays: - return out - return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a - for a in out] - return f - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}_idx={}".format(shape, - jtu.format_shape_dtype_string(idx_shape, dtype)), - "shape": shape, "idx_shape": idx_shape, "dtype": dtype} - for shape in nonempty_nonscalar_array_shapes - for dtype in int_dtypes - for idx_shape in all_shapes)) - def testUnravelIndex(self, shape, idx_shape, dtype): - size = prod(shape) - rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) - - def np_fun(index, shape): - # JAX's version outputs the same dtype as the input in the typical case - # where shape is weakly-typed. - out_dtype = index.dtype - # Adjust out-of-bounds behavior to match jax's documented behavior. 
- index = np.clip(index, -size, size - 1) - index = np.where(index < 0, index + size, index) - return [i.astype(out_dtype) for i in np.unravel_index(index, shape)] - - bm_fun = bm.unravel_index - args_maker = lambda: [rng(idx_shape, dtype), shape] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testAstype(self): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype("float32")] - np_op = lambda x: np.asarray(x).astype(bm.int32) - bm_op = lambda x: bm.asarray(x).astype(bm.int32) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - def testAstypeNone(self): - rng = self.rng() - args_maker = lambda: [rng.randn(3, 4).astype("int32")] - np_op = jtu.with_jax_dtype_defaults(lambda x: np.asarray(x).astype(None)) - bm_op = lambda x: bm.asarray(x).astype(None) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in array_shapes - for dtype in all_dtypes)) - def testNbytes(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - np_op = lambda x: np.asarray(x).nbytes - bm_op = lambda x: bm.asarray(x).value.nbytes - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in array_shapes - for dtype in all_dtypes)) - def testItemsize(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - np_op = lambda x: np.asarray(x).itemsize - bm_op = lambda x: bm.asarray(x).value.itemsize - args_maker = lambda: 
[rng(shape, dtype)] - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_dtype={}".format( - jtu.format_shape_dtype_string(shape, a_dtype), dtype), - "shape": shape, "a_dtype": a_dtype, "dtype": dtype} - for shape in [(8,), (3, 8)] # last dim = 8 to ensure shape compatibility - for a_dtype in (default_dtypes + unsigned_dtypes + bool_dtypes) - for dtype in (default_dtypes + unsigned_dtypes + bool_dtypes))) - def testView(self, shape, a_dtype, dtype): - if jtu.device_under_test() == 'tpu': - if bm.dtype(a_dtype).itemsize in [1, 2] or bm.dtype(dtype).itemsize in [1, 2]: - self.skipTest("arr.view() not supported on TPU for 8- or 16-bit types.") - if not config.x64_enabled: - if bm.dtype(a_dtype).itemsize == 8 or bm.dtype(dtype).itemsize == 8: - self.skipTest("x64 types are disabled by jax_enable_x64") - rng = jtu.rand_fullrange(self.rng()) - args_maker = lambda: [rng(shape, a_dtype)] - np_op = lambda x: np.asarray(x).view(dtype) - bm_op = lambda x: bm.asarray(x).view(dtype) - # Above may produce signaling nans; ignore warnings from invalid values. 
- with np.errstate(invalid='ignore'): - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - def testPathologicalFloats(self): - args_maker = lambda: [np.array([ - 0b_0111_1111_1000_0000_0000_0000_0000_0000, # inf - 0b_1111_1111_1000_0000_0000_0000_0000_0000, # -inf - 0b_0111_1111_1100_0000_0000_0000_0000_0000, # qnan - 0b_1111_1111_1100_0000_0000_0000_0000_0000, # -qnan - 0b_0111_1111_1000_0000_0000_0000_0000_0001, # snan - 0b_1111_1111_1000_0000_0000_0000_0000_0001, # -snan - 0b_0111_1111_1000_0000_0000_1100_0000_0000, # nonstandard nan - 0b_1111_1111_1000_0000_0000_1100_0000_0000, # -nonstandard nan - 0b_0000_0000_0000_0000_0000_0000_0000_0000, # zero - 0b_1000_0000_0000_0000_0000_0000_0000_0000, # -zero - ], dtype='uint32')] - - np_op = lambda x: np.asarray(x).view('float32').view('uint32') - bm_op = lambda x: bm.asarray(x).view('float32').view('uint32') - - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - # TODO(mattjj): test other ndarray-like method overrides - - def testNpMean(self): - # from https://github.com/google/jax/issues/125 - x = bm.eye(3, dtype=float).value + 0. - ans = np.mean(x) - self.assertAllClose(ans, np.array(1./3), check_dtypes=False) - - def testArangeOnFloats(self): - np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/google/jax/issues/145 - self.assertAllClose(np_arange(0.0, 1.0, 0.1), - bm.arange(0.0, 1.0, 0.1).value) - # from https://github.com/google/jax/issues/3450 - self.assertAllClose(np_arange(2.5), - bm.arange(2.5).value) - self.assertAllClose(np_arange(0., 2.5), - bm.arange(0., 2.5).value) - - def testArangeTypes(self): - # Test that arange() output type is equal to the default types. 
- int_ = dtypes.canonicalize_dtype(bm.int_) - float_ = dtypes.canonicalize_dtype(bm.float_) - - self.assertEqual(bm.arange(10).value.dtype, int_) - self.assertEqual(bm.arange(10.).value.dtype, float_) - self.assertEqual(bm.arange(10, dtype='uint16').value.dtype, np.uint16) - self.assertEqual(bm.arange(10, dtype='bfloat16').value.dtype, jnp.bfloat16) - - self.assertEqual(bm.arange(0, 10, 1).value.dtype, int_) - self.assertEqual(bm.arange(0, 10, 1.).value.dtype, float_) - self.assertEqual(bm.arange(0., 10, 1).value.dtype, float_) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for dtype in all_dtypes - for shape in nonzerodim_shapes - for axis in (None, *range(len(shape))))) - def testSort(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - bm_fun = bm.sort - np_fun = np.sort - if axis is not None: - bm_fun = partial(bm_fun, axis=axis) - np_fun = partial(np_fun, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for dtype in all_dtypes - for shape in one_dim_array_shapes - for axis in [None])) - def testSortComplex(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.sort_complex, bm_func(bm.sort_complex), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm.sort_complex), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_input_type={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), - input_type.__name__, axis), - "shape": 
shape, "dtype": dtype, "input_type": input_type, "axis": axis} - for dtype in all_dtypes - for shape in nonempty_nonscalar_array_shapes - for input_type in [np.array, tuple] - for axis in (-1, *range(len(shape) - 1)))) - def testLexsort(self, dtype, shape, input_type, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [input_type(rng(shape, dtype))] - bm_op = lambda x: bm.lexsort(x, axis=axis) - np_op = jtu.with_jax_dtype_defaults(lambda x: np.lexsort(x, axis=axis)) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis), - "shape": shape, "dtype": dtype, "axis": axis} - for dtype in all_dtypes - for shape in nonzerodim_shapes - for axis in (None, *range(len(shape))))) - def testArgsort(self, dtype, shape, axis): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - bm_fun = bm.argsort - np_fun = jtu.with_jax_dtype_defaults(np.argsort) - if axis is not None: - bm_fun = partial(bm_fun, axis=axis) - np_fun = partial(np_fun, axis=axis) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for dtype in all_dtypes - for shape in nonzerodim_shapes)) - def testMsort(self, dtype, shape): - rng = jtu.rand_some_equal(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.msort, bm_func(bm.msort), args_maker) - self._CompileAndCheck(bm_func(bm.msort), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_shifts={}_axis={}".format( - jtu.format_shape_dtype_string(shape, dtype), - shifts, axis), - "shape": shape, "dtype": 
dtype, "shifts": shifts, "axis": axis} - for dtype in all_dtypes - for shape in [(3, 4), (3, 4, 5), (7, 4, 0)] - for shifts, axis in [ - (3, None), - (1, 1), - ((3,), (0,)), - ((-2,), (-2,)), - ((1, 2), (0, -1)), - ((4, 2, 5, 5, 2, 4), None), - (100, None), - ])) - def testRoll(self, shape, dtype, shifts, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), np.array(shifts)] - bm_op = partial(bm.roll, axis=axis) - np_op = partial(np.roll, axis=axis) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_start={}".format( - jtu.format_shape_dtype_string(shape, dtype), - axis, start), - "shape": shape, "dtype": dtype, "axis": axis, - "start": start} - for dtype in all_dtypes - for shape in [(1, 2, 3, 4)] - for axis in [-3, 0, 2, 3] - for start in [-4, -1, 2, 4])) - def testRollaxis(self, shape, dtype, start, axis): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - bm_op = partial(bm.rollaxis, axis=axis, start=start) - np_op = partial(np.rollaxis, axis=axis, start=start) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_bitorder={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, bitorder), - "shape": shape, "dtype": dtype, "axis": axis, - "bitorder": bitorder} - for dtype in [np.uint8, np.bool_] - for bitorder in ['big', 'little'] - for shape in [(1, 2, 3, 4)] - for axis in [None, 0, 1, -2, -1])) - def testPackbits(self, shape, dtype, axis, bitorder): - rng = jtu.rand_some_zero(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - bm_op = partial(bm.packbits, axis=axis, bitorder=bitorder) - np_op = partial(np.packbits, axis=axis, bitorder=bitorder) - self._CheckAgainstNumpy(np_op, 
bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_axis={}_bitorder={}_count={}".format( - jtu.format_shape_dtype_string(shape, dtype), axis, bitorder, count), - "shape": shape, "dtype": dtype, "axis": axis, "bitorder": bitorder, - "count": count} - for dtype in [np.uint8] - for bitorder in ['big', 'little'] - for shape in [(1, 2, 3, 4)] - for axis in [None, 0, 1, -2, -1] - for count in [None, 20])) - def testUnpackbits(self, shape, dtype, axis, bitorder, count): - rng = jtu.rand_int(self.rng(), 0, 256) - args_maker = lambda: [rng(shape, dtype)] - bm_op = partial(bm.unpackbits, axis=axis, bitorder=bitorder) - np_op = partial(np.unpackbits, axis=axis, bitorder=bitorder) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - def _GetArgsMaker(self, rng, shapes, dtypes, np_arrays=True): - def f(): - out = [rng(shape, dtype or jnp.float_) - for shape, dtype in zip(shapes, dtypes)] - if np_arrays: - return out - return [jnp.asarray(a) if isinstance(a, (np.ndarray, np.generic)) else a - for a in out] - return f - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_index={}_axis={}_mode={}".format( - jtu.format_shape_dtype_string(shape, dtype), - jtu.format_shape_dtype_string(index_shape, index_dtype), - axis, mode), - "shape": shape, "index_shape": index_shape, "dtype": dtype, - "index_dtype": index_dtype, "axis": axis, "mode": mode} - for shape in [(3,), (3, 4), (3, 4, 5)] - for index_shape in scalar_shapes + [(3,), (2, 1, 3)] - for axis in itertools.chain(range(-len(shape), len(shape)), - [cast(Optional[int], None)]) - for dtype in all_dtypes - for index_dtype in int_dtypes - for mode in [None, 'wrap', 'clip'])) - def testTake(self, shape, dtype, index_shape, index_dtype, axis, mode): - def args_maker(): - x = rng(shape, dtype) - i = rng_indices(index_shape, index_dtype) - 
return x, i - - rng = jtu.rand_default(self.rng()) - if mode is None: - rng_indices = jtu.rand_int(self.rng(), -shape[axis or 0], shape[axis or 0]) - else: - rng_indices = jtu.rand_int(self.rng(), -5, 5) - bm_op = lambda x, i: bm.take(x, i, axis=axis, mode=mode) - np_op = lambda x, i: np.take(x, i, axis=axis, mode=mode) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - def testTakeEmpty(self): - np.testing.assert_array_equal( - bm.array([], dtype=jnp.float32).value, - bm.take(jnp.array([], jnp.float32), jnp.array([], jnp.int32)).value) - - np.testing.assert_array_equal( - bm.ones((2, 0, 4), dtype=bm.float32).value, - bm.take(jnp.ones((2, 0, 4), dtype=jnp.float32), jnp.array([], jnp.int32), - axis=1).value) - - with self.assertRaisesRegex(IndexError, "non-empty jnp.take"): - bm.take(jnp.ones((2, 0, 4), dtype=jnp.float32), - jnp.array([0], jnp.int32), axis=1) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_index={}_axis={}".format( - jtu.format_shape_dtype_string(x_shape, dtype), - jtu.format_shape_dtype_string(i_shape, index_dtype), axis), - "x_shape": x_shape, "i_shape": i_shape, "dtype": dtype, - "index_dtype": index_dtype, "axis": axis} - for x_shape, i_shape in filter( - _shapes_are_equal_length, - filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(nonempty_nonscalar_array_shapes, 2))) - for axis in itertools.chain(range(len(x_shape)), [-1], - [cast(Optional[int], None)]) - for dtype in default_dtypes - for index_dtype in int_dtypes)) - def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis): - rng = jtu.rand_default(self.rng()) - - i_shape = np.array(i_shape) - if axis is None: - i_shape = [np.prod(i_shape, dtype=np.int64)] - else: - # Test the case where the size of the axis doesn't necessarily broadcast. 
- i_shape[axis] *= 3 - i_shape = list(i_shape) - def args_maker(): - x = rng(x_shape, dtype) - n = np.prod(x_shape, dtype=np.int32) if axis is None else x_shape[axis] - if np.issubdtype(index_dtype, np.unsignedinteger): - index_rng = jtu.rand_int(self.rng(), 0, n) - else: - index_rng = jtu.rand_int(self.rng(), -n, n) - i = index_rng(i_shape, index_dtype) - return x, i - - bm_op = lambda x, i: bm.take_along_axis(x, i, axis=axis) - - if hasattr(np, "take_along_axis"): - np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/google/jax/issues/5088 - h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) - g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) - q0 = bm.take_along_axis(h, g, axis=-1).value - q1 = np.take_along_axis( h, g, axis=-1) - np.testing.assert_equal(q0, q1) - - def testTakeAlongAxisOutOfBounds(self): - x = jnp.arange(10, dtype=jnp.float32) - idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) - out = jnp.take_along_axis(x, idx, axis=0) - expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, - jnp.nan], np.float32) - np.testing.assert_array_equal(expected_fill, out) - out = bm.take_along_axis(x, idx, axis=0, mode="fill").value - np.testing.assert_array_equal(expected_fill, out) - - expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) - out = bm.take_along_axis(x, idx, axis=0, mode="clip").value - np.testing.assert_array_equal(expected_clip, out) - - def testTakeAlongAxisRequiresIntIndices(self): - x = jnp.arange(5) - idx = jnp.array([3.], jnp.float32) - with self.assertRaisesRegex( - TypeError, - "take_along_axis indices must be of integer type, got float32"): - bm.take_along_axis(x, idx, axis=0).value - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": 
"_shape={}_n={}_increasing={}".format( - jtu.format_shape_dtype_string([shape], dtype), - n, increasing), - "dtype": dtype, "shape": shape, "n": n, "increasing": increasing} - for dtype in inexact_dtypes - for shape in [0, 5] - for n in [2, 4] - for increasing in [False, True])) - def testVander(self, shape, dtype, n, increasing): - rng = jtu.rand_default(self.rng()) - def np_fun(arg): - arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg - return np.vander(arg, N=n, increasing=increasing) - bm_fun = lambda arg: bm.vander(arg, N=n, increasing=increasing) - args_maker = lambda: [rng([shape], dtype)] - # np.vander seems to return float64 for all floating types. We could obey - # those semantics, but they seem like a bug. - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol={np.float32: 1e-3}) - self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=False) - - @parameterized.named_parameters(jtu.cases_from_list( + # Test the case where the size of the axis doesn't necessarily broadcast. 
+ i_shape[axis] *= 3 + i_shape = list(i_shape) + + def args_maker(): + x = rng(x_shape, dtype) + n = np.prod(x_shape, dtype=np.int32) if axis is None else x_shape[axis] + if np.issubdtype(index_dtype, np.unsignedinteger): + index_rng = jtu.rand_int(self.rng(), 0, n) + else: + index_rng = jtu.rand_int(self.rng(), -n, n) + i = index_rng(i_shape, index_dtype) + return x, i + + bm_op = lambda x, i: bm.take_along_axis(x, i, axis=axis) + + if hasattr(np, "take_along_axis"): + np_op = lambda x, i: np.take_along_axis(x, i, axis=axis) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): + # https://github.com/google/jax/issues/5088 + h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) + g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) + q0 = bm.take_along_axis(h, g, axis=-1).value + q1 = np.take_along_axis(h, g, axis=-1) + np.testing.assert_equal(q0, q1) + + def testTakeAlongAxisOutOfBounds(self): + x = jnp.arange(10, dtype=jnp.float32) + idx = jnp.array([-11, -10, -9, -5, -1, 0, 1, 5, 9, 10, 11]) + out = jnp.take_along_axis(x, idx, axis=0) + expected_fill = np.array([jnp.nan, 0, 1, 5, 9, 0, 1, 5, 9, jnp.nan, + jnp.nan], np.float32) + np.testing.assert_array_equal(expected_fill, out) + out = bm.take_along_axis(x, idx, axis=0, mode="fill").value + np.testing.assert_array_equal(expected_fill, out) + + expected_clip = np.array([0, 0, 1, 5, 9, 0, 1, 5, 9, 9, 9], np.float32) + out = bm.take_along_axis(x, idx, axis=0, mode="clip").value + np.testing.assert_array_equal(expected_clip, out) + + def testTakeAlongAxisRequiresIntIndices(self): + x = jnp.arange(5) + idx = jnp.array([3.], jnp.float32) + with self.assertRaisesRegex( + TypeError, + "take_along_axis indices must be of integer type, got float32"): + bm.take_along_axis(x, idx, axis=0).value + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": 
"_shape={}_n={}_increasing={}".format( + jtu.format_shape_dtype_string([shape], dtype), + n, increasing), + "dtype": dtype, "shape": shape, "n": n, "increasing": increasing} + for dtype in inexact_dtypes + for shape in [0, 5] + for n in [2, 4] + for increasing in [False, True])) + def testVander(self, shape, dtype, n, increasing): + rng = jtu.rand_default(self.rng()) + + def np_fun(arg): + arg = arg.astype(np.float32) if dtype == jnp.bfloat16 else arg + return np.vander(arg, N=n, increasing=increasing) + + bm_fun = lambda arg: bm.vander(arg, N=n, increasing=increasing) + args_maker = lambda: [rng([shape], dtype)] + # np.vander seems to return float64 for all floating types. We could obey + # those semantics, but they seem like a bug. + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, + tol={np.float32: 1e-3}) + self._CompileAndCheck(bm_func(bm_fun), args_maker, check_dtypes=False) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( "nan_to_num", [shape], [dtype]), - "shape": shape, "dtype": dtype} + "shape": shape, "dtype": dtype} for shape in array_shapes for dtype in inexact_dtypes)) - def testNanToNum(self, shape, dtype): - rng = jtu.rand_some_inf_and_nan(self.rng()) - dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type - def np_fun(x): - if dtype == jnp.bfloat16: - x = np.where(np.isnan(x), dtype(0), x) - x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x) - x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x) - return x - else: - return np.nan_to_num(x).astype(dtype) - - args_maker = lambda: [rng(shape, dtype)] - check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE - self._CheckAgainstNumpy(np_fun, bm_func(bm.nan_to_num), args_maker, - check_dtypes=check_dtypes) - self._CompileAndCheck(bm_func(bm.nan_to_num), args_maker, - check_dtypes=check_dtypes) - - @parameterized.named_parameters(jtu.cases_from_list( + def testNanToNum(self, shape, dtype): + rng = 
jtu.rand_some_inf_and_nan(self.rng()) + dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type + + def np_fun(x): + if dtype == jnp.bfloat16: + x = np.where(np.isnan(x), dtype(0), x) + x = np.where(np.isposinf(x), jnp.finfo(dtype).max, x) + x = np.where(np.isneginf(x), jnp.finfo(dtype).min, x) + return x + else: + return np.nan_to_num(x).astype(dtype) + + args_maker = lambda: [rng(shape, dtype)] + check_dtypes = shape is not jtu.PYTHON_SCALAR_SHAPE + self._CheckAgainstNumpy(np_fun, bm_func(bm.nan_to_num), args_maker, + check_dtypes=check_dtypes) + self._CompileAndCheck(bm_func(bm.nan_to_num), args_maker, + check_dtypes=check_dtypes) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix("ix_", shapes, dtypes), "shapes": shapes, "dtypes": dtypes} for shapes, dtypes in ( - ((), ()), - (((7,),), (np.int32,)), - (((3,), (4,)), (np.int32, np.int32)), - (((3,), (1,), (4,)), (np.int32, np.int32, np.int32)), + ((), ()), + (((7,),), (np.int32,)), + (((3,), (4,)), (np.int32, np.int32)), + (((3,), (1,), (4,)), (np.int32, np.int32, np.int32)), ))) - def testIx_(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype) - for shape, dtype in zip(shapes, dtypes)] - self._CheckAgainstNumpy(np.ix_, bm_func(bm.ix_), args_maker) - self._CompileAndCheck(bm_func(bm.ix_), args_maker) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": "_dimensions={}_dtype={}_sparse={}".format( - dimensions, dtype, sparse), - "dimensions": dimensions, "dtype": dtype, "sparse": sparse} - for dimensions in [(), (2,), (3, 0), (4, 5, 6)] - for dtype in number_dtypes - for sparse in [True, False])) - def testIndices(self, dimensions, dtype, sparse): - def args_maker(): return [] - np_fun = partial(np.indices, dimensions=dimensions, - dtype=dtype, sparse=sparse) - bm_fun = partial(bm.indices, dimensions=dimensions, - dtype=dtype, sparse=sparse) - self._CheckAgainstNumpy(np_fun, 
bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( + def testIx_(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype) + for shape, dtype in zip(shapes, dtypes)] + self._CheckAgainstNumpy(np.ix_, bm_func(bm.ix_), args_maker) + self._CompileAndCheck(bm_func(bm.ix_), args_maker) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": "_dimensions={}_dtype={}_sparse={}".format( + dimensions, dtype, sparse), + "dimensions": dimensions, "dtype": dtype, "sparse": sparse} + for dimensions in [(), (2,), (3, 0), (4, 5, 6)] + for dtype in number_dtypes + for sparse in [True, False])) + def testIndices(self, dimensions, dtype, sparse): + def args_maker(): return [] + + np_fun = partial(np.indices, dimensions=dimensions, + dtype=dtype, sparse=sparse) + bm_fun = partial(bm.indices, dimensions=dimensions, + dtype=dtype, sparse=sparse) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": - "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}_method={}".format( - op, - jtu.format_shape_dtype_string(a_shape, a_dtype), - jtu.format_shape_dtype_string(q_shape, q_dtype), - axis, keepdims, method), - "a_rng": jtu.rand_some_nan, - "q_rng": q_rng, "op": op, - "a_shape": a_shape, "a_dtype": a_dtype, - "q_shape": q_shape, "q_dtype": q_dtype, "axis": axis, - "keepdims": keepdims, - "method": method} + "_op={}_a_shape={}_q_shape={}_axis={}_keepdims={}_method={}".format( + op, + jtu.format_shape_dtype_string(a_shape, a_dtype), + jtu.format_shape_dtype_string(q_shape, q_dtype), + axis, keepdims, method), + "a_rng": jtu.rand_some_nan, + "q_rng": q_rng, "op": op, + "a_shape": a_shape, "a_dtype": a_dtype, + "q_shape": q_shape, "q_dtype": q_dtype, "axis": axis, + "keepdims": keepdims, + "method": method} for 
(op, q_rng) in ( - ("percentile", partial(jtu.rand_uniform, low=0., high=100.)), - ("quantile", partial(jtu.rand_uniform, low=0., high=1.)), - ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)), - ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)), + ("percentile", partial(jtu.rand_uniform, low=0., high=100.)), + ("quantile", partial(jtu.rand_uniform, low=0., high=1.)), + ("nanpercentile", partial(jtu.rand_uniform, low=0., high=100.)), + ("nanquantile", partial(jtu.rand_uniform, low=0., high=1.)), ) for a_dtype in default_dtypes for a_shape, axis in ( - ((7,), None), - ((47, 7), 0), - ((47, 7), ()), - ((4, 101), 1), - ((4, 47, 7), (1, 2)), - ((4, 47, 7), (0, 2)), - ((4, 47, 7), (1, 0, 2)), + ((7,), None), + ((47, 7), 0), + ((47, 7), ()), + ((4, 101), 1), + ((4, 47, 7), (1, 2)), + ((4, 47, 7), (0, 2)), + ((4, 47, 7), (1, 0, 2)), ) for q_dtype in [np.float32] for q_shape in scalar_shapes + [(1,), (4,)] for keepdims in [False, True] for method in ['linear', 'lower', 'higher', 'nearest', 'midpoint'])) - def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype, - axis, keepdims, method): - a_rng = a_rng(self.rng()) - q_rng = q_rng(self.rng()) - if "median" in op: - args_maker = lambda: [a_rng(a_shape, a_dtype)] - else: - args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] - - def np_fun(*args): - args = [x if jnp.result_type(x) != jnp.bfloat16 else - np.asarray(x, np.float32) for x in args] - if numpy_version <= (1, 22): - return getattr(np, op)(*args, axis=axis, keepdims=keepdims, - interpolation=method) - else: - return getattr(np, op)(*args, axis=axis, keepdims=keepdims, - method=method) - bm_fun = partial(getattr(bm, op), axis=axis, keepdims=keepdims, - method=method) - - # TODO(phawkins): we currently set dtype=False because we aren't as - # aggressive about promoting to float64. It's not clear we want to mimic - # Numpy here. 
- tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} - tol = max(jtu.tolerance(a_dtype, tol_spec), - jtu.tolerance(q_dtype, tol_spec)) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol) - - @unittest.skipIf(not config.jax_enable_x64, "test requires X64") - @unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision") - def testPercentilePrecision(self): - # Regression test for https://github.com/google/jax/issues/8513 - x = jnp.float64([1, 2, 3, 4, 7, 10]) - self.assertEqual(bm.percentile(x, 50).value, 3.5) - - @parameterized.named_parameters(jtu.cases_from_list( + def testQuantile(self, op, a_rng, q_rng, a_shape, a_dtype, q_shape, q_dtype, + axis, keepdims, method): + a_rng = a_rng(self.rng()) + q_rng = q_rng(self.rng()) + if "median" in op: + args_maker = lambda: [a_rng(a_shape, a_dtype)] + else: + args_maker = lambda: [a_rng(a_shape, a_dtype), q_rng(q_shape, q_dtype)] + + def np_fun(*args): + args = [x if jnp.result_type(x) != jnp.bfloat16 else + np.asarray(x, np.float32) for x in args] + if numpy_version <= (1, 22): + return getattr(np, op)(*args, axis=axis, keepdims=keepdims, + interpolation=method) + else: + return getattr(np, op)(*args, axis=axis, keepdims=keepdims, + method=method) + + bm_fun = partial(getattr(bm, op), axis=axis, keepdims=keepdims, + method=method) + + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. It's not clear we want to mimic + # Numpy here. 
+ tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} + tol = max(jtu.tolerance(a_dtype, tol_spec), + jtu.tolerance(q_dtype, tol_spec)) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol) + + @unittest.skipIf(not config.jax_enable_x64, "test requires X64") + @unittest.skipIf(jtu.device_under_test() != 'cpu', "test is for CPU float64 precision") + def testPercentilePrecision(self): + # Regression test for https://github.com/google/jax/issues/8513 + x = jnp.float64([1, 2, 3, 4, 7, 10]) + self.assertEqual(bm.percentile(x, 50).value, 3.5) + + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": - "_{}_a_shape={}_axis={}_keepdims={}".format( - op, jtu.format_shape_dtype_string(a_shape, a_dtype), - axis, keepdims), - "op": op, "a_shape": a_shape, "a_dtype": a_dtype, - "axis": axis, - "keepdims": keepdims} + "_{}_a_shape={}_axis={}_keepdims={}".format( + op, jtu.format_shape_dtype_string(a_shape, a_dtype), + axis, keepdims), + "op": op, "a_shape": a_shape, "a_dtype": a_dtype, + "axis": axis, + "keepdims": keepdims} for a_dtype in default_dtypes for a_shape, axis in ( - ((7,), None), - ((47, 7), 0), - ((4, 101), 1), + ((7,), None), + ((47, 7), 0), + ((4, 101), 1), ) for keepdims in [False, True] for op in ["median", "nanmedian"])) - def testMedian(self, op, a_shape, a_dtype, axis, keepdims): - if op == "median": - a_rng = jtu.rand_default(self.rng()) - else: - a_rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: [a_rng(a_shape, a_dtype)] - def np_fun(*args): - args = [x if jnp.result_type(x) != jnp.bfloat16 else - np.asarray(x, np.float32) for x in args] - return getattr(np, op)(*args, axis=axis, keepdims=keepdims) - bm_fun = partial(getattr(bm, op), axis=axis, keepdims=keepdims) - # TODO(phawkins): we currently set dtype=False because we aren't as - # aggressive about promoting to float64. 
It's not clear we want to mimic - # Numpy here. - tol_spec = {np.float32: 2e-4, np.float64: 5e-6} - tol = jtu.tolerance(a_dtype, tol_spec) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol) - - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_shape={}".format( - jtu.format_shape_dtype_string(shape, dtype)), - "shape": shape, "dtype": dtype} - for shape in all_shapes for dtype in all_dtypes)) - def testWhereOneArgument(self, shape, dtype): - rng = jtu.rand_some_zero(self.rng()) - np_fun = lambda x: np.where(x) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np_fun) - bm_fun = lambda x: bm.where(x) - args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - - # JIT compilation requires specifying a size statically. Full test of - # this behavior is in testNonzeroSize(). - bm_fun = lambda x: bm.where(x, size=np.size(x) // 2) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": "_{}".format("_".join( - jtu.format_shape_dtype_string(shape, dtype) - for shape, dtype in zip(shapes, dtypes))), - "shapes": shapes, "dtypes": dtypes + def testMedian(self, op, a_shape, a_dtype, axis, keepdims): + if op == "median": + a_rng = jtu.rand_default(self.rng()) + else: + a_rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: [a_rng(a_shape, a_dtype)] + + def np_fun(*args): + args = [x if jnp.result_type(x) != jnp.bfloat16 else + np.asarray(x, np.float32) for x in args] + return getattr(np, op)(*args, axis=axis, keepdims=keepdims) + + bm_fun = partial(getattr(bm, op), axis=axis, keepdims=keepdims) + # TODO(phawkins): we currently set dtype=False because we aren't as + # aggressive about promoting to float64. 
It's not clear we want to mimic + # Numpy here. + tol_spec = {np.float32: 2e-4, np.float64: 5e-6} + tol = jtu.tolerance(a_dtype, tol_spec) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_shape={}".format( + jtu.format_shape_dtype_string(shape, dtype)), + "shape": shape, "dtype": dtype} + for shape in all_shapes for dtype in all_dtypes)) + def testWhereOneArgument(self, shape, dtype): + rng = jtu.rand_some_zero(self.rng()) + np_fun = lambda x: np.where(x) + np_fun = jtu.ignore_warning( + category=DeprecationWarning, + message="Calling nonzero on 0d arrays.*")(np_fun) + bm_fun = lambda x: bm.where(x) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + + # JIT compilation requires specifying a size statically. Full test of + # this behavior is in testNonzeroSize(). 
+ bm_fun = lambda x: bm.where(x, size=np.size(x) // 2) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": "_{}".format("_".join( + jtu.format_shape_dtype_string(shape, dtype) + for shape, dtype in zip(shapes, dtypes))), + "shapes": shapes, "dtypes": dtypes } for shapes in s(filter(_shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 3))) - for dtypes in s(itertools.combinations_with_replacement(all_dtypes, 3))))) - def testWhereThreeArgument(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, shapes, dtypes) - def np_fun(cond, x, y): - return _promote_like_jnp(partial(np.where, cond))(x, y) - self._CheckAgainstNumpy(np_fun, bm_func(bm.where), args_maker) - self._CompileAndCheck(bm_func(bm.where), args_maker) - - def testWhereScalarPromotion(self): - x = bm.where(jnp.array([True, False]), 3, - jnp.ones((2,), dtype=jnp.float32)).value - self.assertEqual(x.dtype, np.dtype(np.float32)) - - @parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ - "testcase_name": jtu.format_test_name_suffix("", shapes, (np.bool_,) * n + dtypes), - "shapes": shapes, "dtypes": dtypes + itertools.combinations_with_replacement(all_shapes, 3))) + for dtypes in s(itertools.combinations_with_replacement(all_dtypes, 3))))) + def testWhereThreeArgument(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, shapes, dtypes) + + def np_fun(cond, x, y): + return _promote_like_jnp(partial(np.where, cond))(x, y) + + self._CheckAgainstNumpy(np_fun, bm_func(bm.where), args_maker) + self._CompileAndCheck(bm_func(bm.where), args_maker) + + def testWhereScalarPromotion(self): + x = bm.where(jnp.array([True, False]), 3, + jnp.ones((2,), dtype=jnp.float32)).value + self.assertEqual(x.dtype, np.dtype(np.float32)) + + 
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({ + "testcase_name": jtu.format_test_name_suffix("", shapes, (np.bool_,) * n + dtypes), + "shapes": shapes, "dtypes": dtypes } for n in s(range(1, 3)) - for shapes in s(filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2 * n + 1))) - for dtypes in s(itertools.combinations_with_replacement(all_dtypes, n + 1))))) - def testSelect(self, shapes, dtypes): - rng = jtu.rand_default(self.rng()) - n = len(dtypes) - 1 - def args_maker(): - condlist = [rng(shape, np.bool_) for shape in shapes[:n]] - choicelist = [rng(shape, dtype) - for shape, dtype in zip(shapes[n:-1], dtypes[:n])] - default = rng(shapes[-1], dtypes[-1]) - return condlist, choicelist, default - # TODO(phawkins): float32/float64 type mismatches - def np_fun(condlist, choicelist, default): - choicelist = [x if jnp.result_type(x) != jnp.bfloat16 - else x.astype(np.float32) for x in choicelist] - dtype = jnp.result_type(default, *choicelist) - return np.select(condlist, - [np.asarray(x, dtype=dtype) for x in choicelist], - np.asarray(default, dtype=dtype)) - self._CheckAgainstNumpy(np_fun, bm_func(bm.select), args_maker, - check_dtypes=False) - self._CompileAndCheck(bm_func(bm.select), args_maker, - rtol={np.float64: 1e-7, np.complex128: 1e-7}) - - - def testIssue330(self): - x = bm.full((1, 1), jnp.array([1])[0]).value # doesn't crash - self.assertEqual(x[0, 0], 1) - - def testScalarDtypePromotion(self): - orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype - jax_numpy_result = (1 + bm.eye(1, dtype=jnp.float32).value).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - def testSymmetrizeDtypePromotion(self): - x = np.eye(3, dtype=np.float32) - orig_numpy_result = ((x + x.T) / 2).dtype - - x = bm.eye(3, dtype=jnp.float32).value - jax_numpy_result = ((x + x.T) / 2).dtype - self.assertEqual(orig_numpy_result, jax_numpy_result) - - def testIssue453(self): - # 
https://github.com/google/jax/issues/453 - a = np.arange(6) + 1 - ans = bm.reshape(a, (3, 2), order='F').value - expected = np.reshape(a, (3, 2), order='F') - self.assertAllClose(ans, expected) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_op={}_dtype={}".format(op, dtype.__name__), - "dtype": dtype, "op": op} - for dtype in [int, float, bool, complex] - for op in ["atleast_1d", "atleast_2d", "atleast_3d"])) - def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 - np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) - bm_fun = lambda arg: getattr(bm, op)(arg) - args_maker = lambda: [dtype(2)] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - { - "testcase_name": "_shape={}_dtype={}_weights={}_minlength={}_length={}".format( - shape, dtype, weights, minlength, length - ), - "shape": shape, - "dtype": dtype, - "weights": weights, - "minlength": minlength, - "length": length} - for shape in [(0,), (5,), (10,)] - for dtype in int_dtypes - for weights in [True, False] - for minlength in [0, 20] - for length in [None, 8] - )) - def testBincount(self, shape, dtype, weights, minlength, length): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None)) - - def np_fun(x, *args): - x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero. 
- out = np.bincount(x, *args, minlength=minlength) - if length and length > out.size: - return np.pad(out, (0, length - out.size)) - return out[:length] - bm_fun = partial(bm.bincount, minlength=minlength, length=length) - - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - if length is not None: - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testBincountNegative(self): - # Test that jnp.bincount ignores negative values. - x_rng = jtu.rand_int(self.rng(), -100, 100) - w_rng = jtu.rand_uniform(self.rng()) - shape = (1000,) - x = x_rng(shape, 'int32') - w = w_rng(shape, 'float32') - - xn = np.array(x) - xn[xn < 0] = 0 - wn = np.array(w) - np_result = np.bincount(xn[xn >= 0], wn[xn >= 0]) - bm_result = bm.bincount(x, w).value - self.assertAllClose(np_result, bm_result, check_dtypes=False) - - @parameterized.named_parameters(*jtu.cases_from_list( - {"testcase_name": "_case={}".format(i), - "input": input} - for i, input in enumerate([ - 3, - [3], - [np.array(3)], - [np.array([3])], - [[np.array(3)]], - [[np.array([3])]], - [3, 4, 5], - [ - [np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)], - [np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3], - ], - [np.array([1, 2, 3]), np.array([2, 3, 4]), 10], - [np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)], - [[np.array([1, 2, 3])], [np.array([2, 3, 4])]], - ]))) - def testBlock(self, input): - args_maker = lambda: [input] - self._CheckAgainstNumpy(np.block, bm_func(bm.block), args_maker) - self._CompileAndCheck(bm_func(bm.block), args_maker) - - def testLongLong(self): - self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7))) - - @jtu.ignore_warning(category=UserWarning, - message="Explicitly requested dtype.*") - def testArange(self): - # test cases inspired by dask tests at - # https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92 - np_arange = jtu.with_jax_dtype_defaults(np.arange) - 
self.assertAllClose(bm.arange(77).value, - np_arange(77)) - self.assertAllClose(bm.arange(2, 13).value, - np_arange(2, 13)) - self.assertAllClose(bm.arange(4, 21, 9).value, - np_arange(4, 21, 9)) - self.assertAllClose(bm.arange(53, 5, -3).value, - np_arange(53, 5, -3)) - self.assertAllClose(bm.arange(77, dtype=float).value, - np_arange(77, dtype=float)) - self.assertAllClose(bm.arange(2, 13, dtype=int).value, - np_arange(2, 13, dtype=int)) - self.assertAllClose(bm.arange(0, 1, -0.5).value, - np_arange(0, 1, -0.5)) - - self.assertRaises(TypeError, lambda: bm.arange()) - - # test that jnp.arange(N) doesn't instantiate an ndarray - self.assertNotEqual(type(bm.arange(77).value), type(np.arange(77))) - self.assertEqual(type(bm.arange(77).value), type(lax.iota(np.int32, 77))) - - # test that bm.arange(N, dtype=int32) doesn't instantiate an ndarray - self.assertNotEqual(type(bm.arange(77, dtype=bm.int32).value), - type(np.arange(77, dtype=np.int32))) - self.assertEqual(type(bm.arange(77, dtype=bm.int32).value), - type(lax.iota(np.int32, 77))) - - def testArangeJit(self): - ans = jax.jit(lambda: bm.arange(5).value)() - expected = jtu.with_jax_dtype_defaults(np.arange)(5) - self.assertAllClose(ans, expected) - - @parameterized.named_parameters( - {"testcase_name": f"_{args}", "args": args} for args in [(5,), (0, 5)]) - def testArangeJaxpr(self, args): - jaxpr = jax.make_jaxpr(lambda: bm.arange(*args).value)() - self.assertEqual(len(jaxpr.jaxpr.eqns), 1) - self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]), + for shapes in s(filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2 * n + 1))) + for dtypes in s(itertools.combinations_with_replacement(all_dtypes, n + 1))))) + def testSelect(self, shapes, dtypes): + rng = jtu.rand_default(self.rng()) + n = len(dtypes) - 1 + + def args_maker(): + 
condlist = [rng(shape, np.bool_) for shape in shapes[:n]] + choicelist = [rng(shape, dtype) + for shape, dtype in zip(shapes[n:-1], dtypes[:n])] + default = rng(shapes[-1], dtypes[-1]) + return condlist, choicelist, default + + # TODO(phawkins): float32/float64 type mismatches + def np_fun(condlist, choicelist, default): + choicelist = [x if jnp.result_type(x) != jnp.bfloat16 + else x.astype(np.float32) for x in choicelist] + dtype = jnp.result_type(default, *choicelist) + return np.select(condlist, + [np.asarray(x, dtype=dtype) for x in choicelist], + np.asarray(default, dtype=dtype)) + + self._CheckAgainstNumpy(np_fun, bm_func(bm.select), args_maker, + check_dtypes=False) + self._CompileAndCheck(bm_func(bm.select), args_maker, + rtol={np.float64: 1e-7, np.complex128: 1e-7}) + + def testIssue330(self): + x = bm.full((1, 1), jnp.array([1])[0]).value # doesn't crash + self.assertEqual(x[0, 0], 1) + + def testScalarDtypePromotion(self): + orig_numpy_result = (1 + np.eye(1, dtype=np.float32)).dtype + jax_numpy_result = (1 + bm.eye(1, dtype=jnp.float32).value).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + def testSymmetrizeDtypePromotion(self): + x = np.eye(3, dtype=np.float32) + orig_numpy_result = ((x + x.T) / 2).dtype + + x = bm.eye(3, dtype=jnp.float32).value + jax_numpy_result = ((x + x.T) / 2).dtype + self.assertEqual(orig_numpy_result, jax_numpy_result) + + def testIssue453(self): + # https://github.com/google/jax/issues/453 + a = np.arange(6) + 1 + ans = bm.reshape(a, (3, 2), order='F').value + expected = np.reshape(a, (3, 2), order='F') + self.assertAllClose(ans, expected) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_op={}_dtype={}".format(op, dtype.__name__), "dtype": dtype, "op": op} - for dtype in float_dtypes - for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan", - "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp", - "log", "expm1", "log1p"))) - def 
testMathSpecialFloatValues(self, op, dtype): - np_op = getattr(np, op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="invalid value.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="overflow.*")(np_op) - - bm_op = getattr(bm, op) - dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type - for x in (np.nan, -np.inf, -100., -2., -1., 0., 1., 2., 100., np.inf, - jnp.finfo(dtype).max, np.sqrt(jnp.finfo(dtype).max), - np.sqrt(jnp.finfo(dtype).max) * 2.): - if (op in ("sin", "cos", "tan") and - jtu.device_under_test() == "tpu"): - continue # TODO(b/132196789): fix and reenable. - x = dtype(x) - expected = np_op(x) - actual = bm_op(x) - tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7}) - self.assertAllClose(expected, actual.value, atol=tol, - rtol=tol) - - - def testReductionOfOutOfBoundsAxis(self): # Issue 888 - x = bm.ones((3, 4)) - self.assertRaises(ValueError, lambda: bm.sum(x, axis=2).value) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True])) - def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.") - def np_fun(x): - out = np.var(x.astype(jnp.promote_types(np.float32, dtype)), - axis=axis, ddof=ddof, keepdims=keepdims) - return out.astype(out_dtype) - bm_fun = partial(bm.var, 
dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) - tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, - np.float64: 1e-3, np.complex128: 1e-6}) - if (jnp.issubdtype(dtype, jnp.complexfloating) and - not jnp.issubdtype(out_dtype, jnp.complexfloating)): - self.assertRaises(ValueError, lambda: bm_fun(*args_maker())) - else: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol, - atol=tol) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" - .format(shape, dtype, out_dtype, axis, ddof, keepdims), - "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, - "ddof": ddof, "keepdims": keepdims} - for shape in [(5,), (10, 5)] - for dtype in all_dtypes - for out_dtype in inexact_dtypes - for axis in [None, 0, -1] - for ddof in [0, 1, 2] - for keepdims in [False, True])) - def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): - rng = jtu.rand_some_nan(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - @jtu.ignore_warning(category=RuntimeWarning, - message="Degrees of freedom <= 0 for slice.") - def np_fun(x): - out = np.nanvar(x.astype(jnp.promote_types(np.float32, dtype)), - axis=axis, ddof=ddof, keepdims=keepdims) - return out.astype(out_dtype) - bm_fun = partial(bm.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) - tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, - np.float64: 1e-3, np.complex128: 1e-6}) - if (jnp.issubdtype(dtype, jnp.complexfloating) and - not jnp.issubdtype(out_dtype, jnp.complexfloating)): - self.assertRaises(ValueError, lambda: bm_fun(*args_maker())) - else: - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol, - atol=tol) - - def testNanStdGrad(self): - # Regression test for 
https://github.com/google/jax/issues/8128 - x = bm.arange(5.0).at[0].set(jnp.nan) - y = jax.grad(bm_func(bm.nanvar))(x) - self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75])) - - z = jax.grad(bm_func(bm.nanstd))(x) - self.assertEqual(jnp.isnan(z).sum(), 0) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": - "_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format( - shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights), - "shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype,"rowvar": rowvar, "ddof": ddof, - "bias": bias, "fweights": fweights, "aweights": aweights} - for shape in [(5,), (10, 5), (5, 10)] - for dtype in all_dtypes - for y_dtype in [None, dtype] - for rowvar in [True, False] - for y_shape in _get_y_shapes(y_dtype, shape, rowvar) - for bias in [True, False] - for ddof in [None, 2, 3] - for fweights in [True, False] - for aweights in [True, False])) - def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights): - rng = jtu.rand_default(self.rng()) - wrng = jtu.rand_positive(self.rng()) - wdtype = np.real(dtype(0)).dtype - wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1] - - args_maker = lambda: [rng(shape, dtype), - rng(y_shape, y_dtype) if y_dtype else None, - wrng(wshape, int) if fweights else None, - wrng(wshape, wdtype) if aweights else None] - kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias) - np_fun = lambda m, y, f, a: np.cov(m, y, fweights=f, aweights=a, **kwargs) - bm_fun = lambda m, y, f, a: bm.cov(m, y, fweights=f, aweights=a, **kwargs) - tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5, - np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13} - tol = 7e-2 if jtu.device_under_test() == "tpu" else tol - tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) - self._CheckAgainstNumpy( - np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) - 
self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, - rtol=tol) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": "_shape={}_dtype={}_rowvar={}".format( - shape, dtype.__name__, rowvar), - "shape": shape, "dtype": dtype, "rowvar": rowvar} - for shape in [(5,), (10, 5), (3, 10)] - for dtype in number_dtypes - for rowvar in [True, False])) - def testCorrCoef(self, shape, dtype, rowvar): - rng = jtu.rand_default(self.rng()) - def args_maker(): - ok = False - while not ok: - x = rng(shape, dtype) - ok = not np.any(np.isclose(np.std(x), 0.0)) - return (x,) - np_fun = partial(np.corrcoef, rowvar=rowvar) - np_fun = jtu.ignore_warning( - category=RuntimeWarning, message="invalid value encountered.*")(np_fun) - bm_fun = partial(bm.corrcoef, rowvar=rowvar) - tol = 1e-2 if jtu.device_under_test() == "tpu" else None - self._CheckAgainstNumpy( - np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, rtol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": "_{}_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype), - "None" if end_dtype is None else jtu.format_shape_dtype_string(end_shape, end_dtype), - "None" if begin_dtype is None else jtu.format_shape_dtype_string(begin_shape, begin_dtype)), - "shape": shape, "dtype": dtype, "end_shape": end_shape, - "end_dtype": end_dtype, "begin_shape": begin_shape, - "begin_dtype": begin_dtype} - for dtype in number_dtypes - for end_dtype in [None] + [dtype] - for begin_dtype in [None] + [dtype] - for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE] - for begin_shape in ( - [None] if begin_dtype is None - else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) - for end_shape in ( - [None] if end_dtype is None - else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]))) - def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape, - begin_dtype): - rng = 
jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), - (None if end_dtype is None else rng(end_shape, end_dtype)), - (None if begin_dtype is None else rng(begin_shape, begin_dtype))] - np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) - bm_fun = lambda x, to_end, to_begin: bm.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testEDiff1dWithDtypeCast(self): - rng = jtu.rand_default(self.rng()) - shape = jtu.NUMPY_SCALAR_SHAPE - dtype = jnp.float32 - end_dtype = jnp.int32 - args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)] - np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) - bm_fun = lambda x, to_end, to_begin: bm.ediff1d(x, to_end, to_begin) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": "_shapes={}_dtype={}_indexing={}_sparse={}".format( - shapes, dtype, indexing, sparse), - "shapes": shapes, "dtype": dtype, "indexing": indexing, - "sparse": sparse} - for shapes in [(), (5,), (5, 3)] + for dtype in [int, float, bool, complex] + for op in ["atleast_1d", "atleast_2d", "atleast_3d"])) + def testAtLeastNdLiterals(self, dtype, op): + # Fixes: https://github.com/google/jax/issues/634 + np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) + bm_fun = lambda arg: getattr(bm, op)(arg) + args_maker = lambda: [dtype(2)] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + { + "testcase_name": "_shape={}_dtype={}_weights={}_minlength={}_length={}".format( + shape, dtype, weights, minlength, length + ), + "shape": shape, + "dtype": dtype, + "weights": weights, + "minlength": minlength, + 
"length": length} + for shape in [(0,), (5,), (10,)] + for dtype in int_dtypes + for weights in [True, False] + for minlength in [0, 20] + for length in [None, 8] + )) + def testBincount(self, shape, dtype, weights, minlength, length): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype), (rng(shape, 'float32') if weights else None)) + + def np_fun(x, *args): + x = np.clip(x, 0, None) # jnp.bincount clips negative values to zero. + out = np.bincount(x, *args, minlength=minlength) + if length and length > out.size: + return np.pad(out, (0, length - out.size)) + return out[:length] + + bm_fun = partial(bm.bincount, minlength=minlength, length=length) + + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + if length is not None: + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testBincountNegative(self): + # Test that jnp.bincount ignores negative values. + x_rng = jtu.rand_int(self.rng(), -100, 100) + w_rng = jtu.rand_uniform(self.rng()) + shape = (1000,) + x = x_rng(shape, 'int32') + w = w_rng(shape, 'float32') + + xn = np.array(x) + xn[xn < 0] = 0 + wn = np.array(w) + np_result = np.bincount(xn[xn >= 0], wn[xn >= 0]) + bm_result = bm.bincount(x, w).value + self.assertAllClose(np_result, bm_result, check_dtypes=False) + + @parameterized.named_parameters(*jtu.cases_from_list( + {"testcase_name": "_case={}".format(i), + "input": input} + for i, input in enumerate([ + 3, + [3], + [np.array(3)], + [np.array([3])], + [[np.array(3)]], + [[np.array([3])]], + [3, 4, 5], + [ + [np.eye(2, dtype=np.int32) * 2, np.zeros((2, 3), dtype=np.int32)], + [np.ones((3, 2), dtype=np.int32), np.eye(3, dtype=np.int32) * 3], + ], + [np.array([1, 2, 3]), np.array([2, 3, 4]), 10], + [np.ones((2, 2), dtype=np.int32), np.zeros((2, 2), dtype=np.int32)], + [[np.array([1, 2, 3])], [np.array([2, 3, 4])]], + ]))) + def testBlock(self, input): + args_maker = lambda: [input] + self._CheckAgainstNumpy(np.block, bm_func(bm.block), 
args_maker) + self._CompileAndCheck(bm_func(bm.block), args_maker) + + def testLongLong(self): + self.assertAllClose(np.int64(7), jax.jit(lambda x: x)(np.longlong(7))) + + @jtu.ignore_warning(category=UserWarning, + message="Explicitly requested dtype.*") + def testArange(self): + # test cases inspired by dask tests at + # https://github.com/dask/dask/blob/main/dask/array/tests/test_creation.py#L92 + np_arange = jtu.with_jax_dtype_defaults(np.arange) + self.assertAllClose(bm.arange(77).value, + np_arange(77)) + self.assertAllClose(bm.arange(2, 13).value, + np_arange(2, 13)) + self.assertAllClose(bm.arange(4, 21, 9).value, + np_arange(4, 21, 9)) + self.assertAllClose(bm.arange(53, 5, -3).value, + np_arange(53, 5, -3)) + self.assertAllClose(bm.arange(77, dtype=float).value, + np_arange(77, dtype=float)) + self.assertAllClose(bm.arange(2, 13, dtype=int).value, + np_arange(2, 13, dtype=int)) + self.assertAllClose(bm.arange(0, 1, -0.5).value, + np_arange(0, 1, -0.5)) + + self.assertRaises(TypeError, lambda: bm.arange()) + + # test that jnp.arange(N) doesn't instantiate an ndarray + self.assertNotEqual(type(bm.arange(77).value), type(np.arange(77))) + self.assertEqual(type(bm.arange(77).value), type(lax.iota(np.int32, 77))) + + # test that bm.arange(N, dtype=int32) doesn't instantiate an ndarray + self.assertNotEqual(type(bm.arange(77, dtype=bm.int32).value), + type(np.arange(77, dtype=np.int32))) + self.assertEqual(type(bm.arange(77, dtype=bm.int32).value), + type(lax.iota(np.int32, 77))) + + def testArangeJit(self): + ans = jax.jit(lambda: bm.arange(5).value)() + expected = jtu.with_jax_dtype_defaults(np.arange)(5) + self.assertAllClose(ans, expected) + + @parameterized.named_parameters( + {"testcase_name": f"_{args}", "args": args} for args in [(5,), (0, 5)]) + def testArangeJaxpr(self, args): + jaxpr = jax.make_jaxpr(lambda: bm.arange(*args).value)() + self.assertEqual(len(jaxpr.jaxpr.eqns), 1) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, lax.iota_p) + + 
@parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]), + "dtype": dtype, "op": op} + for dtype in float_dtypes + for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan", + "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp", + "log", "expm1", "log1p"))) + def testMathSpecialFloatValues(self, op, dtype): + np_op = getattr(np, op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="invalid value.*")(np_op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="divide by zero.*")(np_op) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="overflow.*")(np_op) + + bm_op = getattr(bm, op) + dtype = np.dtype(dtypes.canonicalize_dtype(dtype)).type + for x in (np.nan, -np.inf, -100., -2., -1., 0., 1., 2., 100., np.inf, + jnp.finfo(dtype).max, np.sqrt(jnp.finfo(dtype).max), + np.sqrt(jnp.finfo(dtype).max) * 2.): + if (op in ("sin", "cos", "tan") and + jtu.device_under_test() == "tpu"): + continue # TODO(b/132196789): fix and reenable. 
+ x = dtype(x) + expected = np_op(x) + actual = bm_op(x) + tol = jtu.tolerance(dtype, {np.float32: 1e-3, np.float64: 1e-7}) + self.assertAllClose(expected, actual.value, atol=tol, + rtol=tol) + + def testReductionOfOutOfBoundsAxis(self): # Issue 888 + x = bm.ones((3, 4)) + self.assertRaises(ValueError, lambda: bm.sum(x, axis=2).value) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": + "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" + .format(shape, dtype.__name__, out_dtype.__name__, axis, ddof, keepdims), + "shape": shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, + "ddof": ddof, "keepdims": keepdims} + for shape in [(5,), (10, 5)] + for dtype in all_dtypes + for out_dtype in inexact_dtypes + for axis in [None, 0, -1] + for ddof in [0, 1, 2] + for keepdims in [False, True])) + def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.") + def np_fun(x): + out = np.var(x.astype(jnp.promote_types(np.float32, dtype)), + axis=axis, ddof=ddof, keepdims=keepdims) + return out.astype(out_dtype) + + bm_fun = partial(bm.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, + np.float64: 1e-3, np.complex128: 1e-6}) + if (jnp.issubdtype(dtype, jnp.complexfloating) and + not jnp.issubdtype(out_dtype, jnp.complexfloating)): + self.assertRaises(ValueError, lambda: bm_fun(*args_maker())) + else: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol, + atol=tol) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": + "_shape={}_dtype={}_out_dtype={}_axis={}_ddof={}_keepdims={}" + .format(shape, dtype, out_dtype, axis, ddof, keepdims), + "shape": 
shape, "dtype": dtype, "out_dtype": out_dtype, "axis": axis, + "ddof": ddof, "keepdims": keepdims} + for shape in [(5,), (10, 5)] + for dtype in all_dtypes + for out_dtype in inexact_dtypes + for axis in [None, 0, -1] + for ddof in [0, 1, 2] + for keepdims in [False, True])) + def testNanVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): + rng = jtu.rand_some_nan(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + + @jtu.ignore_warning(category=RuntimeWarning, + message="Degrees of freedom <= 0 for slice.") + def np_fun(x): + out = np.nanvar(x.astype(jnp.promote_types(np.float32, dtype)), + axis=axis, ddof=ddof, keepdims=keepdims) + return out.astype(out_dtype) + + bm_fun = partial(bm.nanvar, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, + np.float64: 1e-3, np.complex128: 1e-6}) + if (jnp.issubdtype(dtype, jnp.complexfloating) and + not jnp.issubdtype(out_dtype, jnp.complexfloating)): + self.assertRaises(ValueError, lambda: bm_fun(*args_maker())) + else: + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, rtol=tol, + atol=tol) + + def testNanStdGrad(self): + # Regression test for https://github.com/google/jax/issues/8128 + x = bm.arange(5.0).at[0].set(jnp.nan) + y = jax.grad(bm_func(bm.nanvar))(x) + self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75])) + + z = jax.grad(bm_func(bm.nanstd))(x) + self.assertEqual(jnp.isnan(z).sum(), 0) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": + "_shape={}_dtype={}_y_shape={}_y_dtype={}_rowvar={}_ddof={}_bias={}_fweights={}_aweights={}".format( + shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights), + "shape": shape, "y_shape": y_shape, "dtype": dtype, "y_dtype": y_dtype, "rowvar": rowvar, "ddof": ddof, + "bias": bias, "fweights": fweights, "aweights": aweights} + for shape in [(5,), 
(10, 5), (5, 10)] + for dtype in all_dtypes + for y_dtype in [None, dtype] + for rowvar in [True, False] + for y_shape in _get_y_shapes(y_dtype, shape, rowvar) + for bias in [True, False] + for ddof in [None, 2, 3] + for fweights in [True, False] + for aweights in [True, False])) + def testCov(self, shape, dtype, y_shape, y_dtype, rowvar, ddof, bias, fweights, aweights): + rng = jtu.rand_default(self.rng()) + wrng = jtu.rand_positive(self.rng()) + wdtype = np.real(dtype(0)).dtype + wshape = shape[-1:] if rowvar or shape[0] == 1 else shape[:1] + + args_maker = lambda: [rng(shape, dtype), + rng(y_shape, y_dtype) if y_dtype else None, + wrng(wshape, int) if fweights else None, + wrng(wshape, wdtype) if aweights else None] + kwargs = dict(rowvar=rowvar, ddof=ddof, bias=bias) + np_fun = lambda m, y, f, a: np.cov(m, y, fweights=f, aweights=a, **kwargs) + bm_fun = lambda m, y, f, a: bm.cov(m, y, fweights=f, aweights=a, **kwargs) + tol = {jnp.bfloat16: 5E-2, np.float16: 1E-2, np.float32: 1e-5, + np.float64: 1e-13, np.complex64: 1e-5, np.complex128: 1e-13} + tol = 7e-2 if jtu.device_under_test() == "tpu" else tol + tol = jtu.join_tolerance(tol, jtu.tolerance(dtype)) + self._CheckAgainstNumpy( + np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, + rtol=tol) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": "_shape={}_dtype={}_rowvar={}".format( + shape, dtype.__name__, rowvar), + "shape": shape, "dtype": dtype, "rowvar": rowvar} + for shape in [(5,), (10, 5), (3, 10)] + for dtype in number_dtypes + for rowvar in [True, False])) + def testCorrCoef(self, shape, dtype, rowvar): + rng = jtu.rand_default(self.rng()) + + def args_maker(): + ok = False + while not ok: + x = rng(shape, dtype) + ok = not np.any(np.isclose(np.std(x), 0.0)) + return (x,) + + np_fun = partial(np.corrcoef, rowvar=rowvar) + np_fun = jtu.ignore_warning( + category=RuntimeWarning, message="invalid 
value encountered.*")(np_fun) + bm_fun = partial(bm.corrcoef, rowvar=rowvar) + tol = 1e-2 if jtu.device_under_test() == "tpu" else None + self._CheckAgainstNumpy( + np_fun, bm_func(bm_fun), args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(bm_func(bm_fun), args_maker, atol=tol, rtol=tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_{}_{}".format(jtu.format_shape_dtype_string(shape, dtype), + "None" if end_dtype is None else jtu.format_shape_dtype_string(end_shape, + end_dtype), + "None" if begin_dtype is None else jtu.format_shape_dtype_string( + begin_shape, begin_dtype)), + "shape": shape, "dtype": dtype, "end_shape": end_shape, + "end_dtype": end_dtype, "begin_shape": begin_shape, + "begin_dtype": begin_dtype} for dtype in number_dtypes - for indexing in ['xy', 'ij'] - for sparse in [True, False])) - def testMeshGrid(self, shapes, dtype, indexing, sparse): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes], - [dtype] * len(shapes)) - np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse) - bm_fun = partial(bm.meshgrid, indexing=indexing, sparse=sparse) - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testMgrid(self): - # wrap indexer for appropriate dtype defaults. 
- np_mgrid = _indexer_with_default_outputs(np.mgrid) - assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0) - assertAllEqual(np_mgrid[:4], bm.mgrid[:4]) - assertAllEqual(np_mgrid[:4,], bm.mgrid[:4,]) - assertAllEqual(np_mgrid[:4], jax.jit(lambda: bm.mgrid[:4])()) - assertAllEqual(np_mgrid[:5, :5], bm.mgrid[:5, :5]) - assertAllEqual(np_mgrid[:3, :2], bm.mgrid[:3, :2]) - assertAllEqual(np_mgrid[1:4:2], bm.mgrid[1:4:2]) - assertAllEqual(np_mgrid[1:5:3, :5], bm.mgrid[1:5:3, :5]) - assertAllEqual(np_mgrid[:3, :2, :5], bm.mgrid[:3, :2, :5]) - assertAllEqual(np_mgrid[:3:2, :2, :5], bm.mgrid[:3:2, :2, :5]) - # Corner cases - assertAllEqual(np_mgrid[:], bm.mgrid[:]) - # When the step length is a complex number, because of float calculation, - # the values between bm and np might slightly different. - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_mgrid[-1:1:5j], - bm.mgrid[-1:1:5j], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[3:4:7j], - bm.mgrid[3:4:7j], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[1:6:8j, 2:4], - bm.mgrid[1:6:8j, 2:4], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_mgrid[0:3.5:0.5], - bm.mgrid[0:3.5:0.5], - atol=atol, - rtol=rtol) - self.assertAllClose(np_mgrid[1.3:4.2:0.3], - bm.mgrid[1.3:4.2:0.3], - atol=atol, - rtol=rtol) - # abstract tracer value for bm.mgrid slice - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, - "slice start of jnp.mgrid"): - jax.jit(lambda a, b: bm.mgrid[a:b])(0, 2) - - def testOgrid(self): - # wrap indexer for appropriate dtype defaults. 
- np_ogrid = _indexer_with_default_outputs(np.ogrid) - def assertListOfArraysEqual(xs, ys): - self.assertIsInstance(xs, list) - self.assertIsInstance(ys, list) - self.assertEqual(len(xs), len(ys)) - for x, y in zip(xs, ys): - self.assertArraysEqual(x, y) - - self.assertArraysEqual(np_ogrid[:5], bm.ogrid[:5]) - self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: bm.ogrid[:5])()) - self.assertArraysEqual(np_ogrid[1:7:2], bm.ogrid[1:7:2]) - # List of arrays - assertListOfArraysEqual(np_ogrid[:5,], bm.ogrid[:5,]) - assertListOfArraysEqual(np_ogrid[0:5, 1:3], bm.ogrid[0:5, 1:3]) - assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], bm.ogrid[1:3:2, 2:9:3]) - assertListOfArraysEqual(np_ogrid[:5, :9, :11], bm.ogrid[:5, :9, :11]) - # Corner cases - self.assertArraysEqual(np_ogrid[:], bm.ogrid[:]) - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_ogrid[-1:1:5j], - bm.ogrid[-1:1:5j], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_ogrid[0:3.5:0.3], - bm.ogrid[0:3.5:0.3], - atol=atol, - rtol=rtol) - self.assertAllClose(np_ogrid[1.2:4.8:0.24], - bm.ogrid[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - # abstract tracer value for ogrid slice - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, - "slice start of jnp.ogrid"): - jax.jit(lambda a, b: bm.ogrid[a:b])(0, 2) - - def testR_(self): - a = np.arange(6).reshape((2,3)) - self.assertArraysEqual(np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])], - bm.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]) - self.assertArraysEqual(np.r_['-1', a, a], bm.r_['-1', a, a]) - - # wrap indexer for appropriate dtype defaults. 
- np_r_ = _indexer_with_default_outputs(np.r_) - self.assertArraysEqual(np_r_['0,2', [1,2,3], [4,5,6]], bm.r_['0,2', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np_r_['0,2,0', [1,2,3], [4,5,6]], bm.r_['0,2,0', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np_r_['1,2,0', [1,2,3], [4,5,6]], bm.r_['1,2,0', [1,2,3], [4,5,6]]) - # negative 1d axis start - self.assertArraysEqual(np_r_['0,4,-1', [1,2,3], [4,5,6]], bm.r_['0,4,-1', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np_r_['0,4,-2', [1,2,3], [4,5,6]], bm.r_['0,4,-2', [1,2,3], [4,5,6]]) - - # matrix directives - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) - self.assertArraysEqual(np_r_['r',[1,2,3], [4,5,6]], bm.r_['r',[1,2,3], [4,5,6]]) - self.assertArraysEqual(np_r_['c', [1, 2, 3], [4, 5, 6]], bm.r_['c', [1, 2, 3], [4, 5, 6]]) - - # bad directive - with self.assertRaisesRegex(ValueError, "could not understand directive.*"): - bm.r_["asdfgh",[1,2,3]] - # abstract tracer value for r_ slice - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, - "slice start of jnp.r_"): - jax.jit(lambda a, b: bm.r_[a:b])(0, 2) - - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_r_[-1:1:6j], - bm.r_[-1:1:6j], - atol=atol, - rtol=rtol) - self.assertAllClose(np_r_[-1:1:6j, [0]*3, 5, 6], - bm.r_[-1:1:6j, [0]*3, 5, 6], - atol=atol, - rtol=rtol) - # Non-integer steps - self.assertAllClose(np_r_[1.2:4.8:0.24], - bm.r_[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - - def testC_(self): - a = np.arange(6).reshape((2, 3)) - self.assertArraysEqual(np.c_[np.array([1,2,3]), np.array([4,5,6])], - bm.c_[np.array([1,2,3]), np.array([4,5,6])]) - self.assertArraysEqual(np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])], - bm.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]) - self.assertArraysEqual(np.c_['-1', a, a], bm.c_['-1', a, a]) - - # wrap indexer for appropriate dtype defaults. 
- np_c_ = _indexer_with_default_outputs(np.c_) - self.assertArraysEqual(np_c_['0,2', [1,2,3], [4,5,6]], bm.c_['0,2', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np_c_['0,2,0', [1,2,3], [4,5,6]], bm.c_['0,2,0', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np_c_['1,2,0', [1,2,3], [4,5,6]], bm.c_['1,2,0', [1,2,3], [4,5,6]]) - # negative 1d axis start - self.assertArraysEqual(np_c_['0,4,-1', [1,2,3], [4,5,6]], bm.c_['0,4,-1', [1,2,3], [4,5,6]]) - self.assertArraysEqual(np_c_['0,4,-2', [1,2,3], [4,5,6]], bm.c_['0,4,-2', [1,2,3], [4,5,6]]) - # matrix directives, avoid numpy deprecation warning - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=PendingDeprecationWarning) - self.assertArraysEqual(np_c_['r',[1,2,3], [4,5,6]], bm.c_['r',[1,2,3], [4,5,6]]) - self.assertArraysEqual(np_c_['c', [1, 2, 3], [4, 5, 6]], bm.c_['c', [1, 2, 3], [4, 5, 6]]) - - # bad directive - with self.assertRaisesRegex(ValueError, "could not understand directive.*"): - bm.c_["asdfgh",[1,2,3]] - # abstract tracer value for c_ slice - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, - "slice start of jnp.c_"): - jax.jit(lambda a, b: bm.c_[a:b])(0, 2) - - # Complex number steps - atol = 1e-6 - rtol = 1e-6 - self.assertAllClose(np_c_[-1:1:6j], - bm.c_[-1:1:6j], - atol=atol, - rtol=rtol) - - # Non-integer steps - self.assertAllClose(np_c_[1.2:4.8:0.24], - bm.c_[1.2:4.8:0.24], - atol=atol, - rtol=rtol) - - def testS_(self): - self.assertEqual(np.s_[1:2:20],bm.s_[1:2:20]) - - def testIndex_exp(self): - self.assertEqual(np.index_exp[5:3:2j],bm.index_exp[5:3:2j]) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": f"_start_shape={start_shape}_stop_shape={stop_shape}" - f"_num={num}_endpoint={endpoint}_retstep={retstep}" - f"_dtype={dtype.__name__ if dtype else 'None'}", - "start_shape": start_shape, "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, "retstep": retstep, - "dtype": dtype} - for start_shape in [(), (2,), (2, 2)] - 
for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - for retstep in [True, False] - # floating-point compute between jitted platforms and non-jit + rounding - # cause unavoidable variation in integer truncation for some inputs, so - # we currently only test inexact 'dtype' arguments. - for dtype in inexact_dtypes + [None,])) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype): - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = jtu.tolerance(dtype if dtype else np.float32) * 10 - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(np.shape(start + stop)) - for axis in range(-ndim, ndim): - bm_op = lambda start, stop: bm.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - # NumPy 1.20.0 changed the semantics of linspace to floor for integer - # dtypes. 
- if numpy_version >= (1, 20) or not np.issubdtype(dtype, np.integer): - np_op = lambda start, stop: np.linspace( - start, stop, num, - endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) - else: + for end_dtype in [None] + [dtype] + for begin_dtype in [None] + [dtype] + for shape in [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE] + for begin_shape in ( + [None] if begin_dtype is None + else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]) + for end_shape in ( + [None] if end_dtype is None + else [s for s in all_shapes if s != jtu.PYTHON_SCALAR_SHAPE]))) + def testEDiff1d(self, shape, dtype, end_shape, end_dtype, begin_shape, + begin_dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), + (None if end_dtype is None else rng(end_shape, end_dtype)), + (None if begin_dtype is None else rng(begin_shape, begin_dtype))] + np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) + bm_fun = lambda x, to_end, to_begin: bm.ediff1d(x, to_end, to_begin) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testEDiff1dWithDtypeCast(self): + rng = jtu.rand_default(self.rng()) + shape = jtu.NUMPY_SCALAR_SHAPE + dtype = jnp.float32 + end_dtype = jnp.int32 + args_maker = lambda: [rng(shape, dtype), rng(shape, end_dtype), rng(shape, dtype)] + np_fun = lambda x, to_end, to_begin: np.ediff1d(x, to_end, to_begin) + bm_fun = lambda x, to_end, to_begin: bm.ediff1d(x, to_end, to_begin) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": "_shapes={}_dtype={}_indexing={}_sparse={}".format( + shapes, dtype, indexing, sparse), + "shapes": shapes, "dtype": dtype, "indexing": indexing, + "sparse": sparse} + for shapes in [(), (5,), (5, 3)] + for dtype in number_dtypes + for indexing in ['xy', 'ij'] + for sparse in 
[True, False])) + def testMeshGrid(self, shapes, dtype, indexing, sparse): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [(x,) for x in shapes], + [dtype] * len(shapes)) + np_fun = partial(np.meshgrid, indexing=indexing, sparse=sparse) + bm_fun = partial(bm.meshgrid, indexing=indexing, sparse=sparse) + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testMgrid(self): + # wrap indexer for appropriate dtype defaults. + np_mgrid = _indexer_with_default_outputs(np.mgrid) + assertAllEqual = partial(self.assertAllClose, atol=0, rtol=0) + assertAllEqual(np_mgrid[:4], bm.mgrid[:4]) + assertAllEqual(np_mgrid[:4, ], bm.mgrid[:4, ]) + assertAllEqual(np_mgrid[:4], jax.jit(lambda: bm.mgrid[:4])()) + assertAllEqual(np_mgrid[:5, :5], bm.mgrid[:5, :5]) + assertAllEqual(np_mgrid[:3, :2], bm.mgrid[:3, :2]) + assertAllEqual(np_mgrid[1:4:2], bm.mgrid[1:4:2]) + assertAllEqual(np_mgrid[1:5:3, :5], bm.mgrid[1:5:3, :5]) + assertAllEqual(np_mgrid[:3, :2, :5], bm.mgrid[:3, :2, :5]) + assertAllEqual(np_mgrid[:3:2, :2, :5], bm.mgrid[:3:2, :2, :5]) + # Corner cases + assertAllEqual(np_mgrid[:], bm.mgrid[:]) + # When the step length is a complex number, because of float calculation, + # the values between bm and np might slightly different. 
+ atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_mgrid[-1:1:5j], + bm.mgrid[-1:1:5j], + atol=atol, + rtol=rtol) + self.assertAllClose(np_mgrid[3:4:7j], + bm.mgrid[3:4:7j], + atol=atol, + rtol=rtol) + self.assertAllClose(np_mgrid[1:6:8j, 2:4], + bm.mgrid[1:6:8j, 2:4], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_mgrid[0:3.5:0.5], + bm.mgrid[0:3.5:0.5], + atol=atol, + rtol=rtol) + self.assertAllClose(np_mgrid[1.3:4.2:0.3], + bm.mgrid[1.3:4.2:0.3], + atol=atol, + rtol=rtol) + # abstract tracer value for bm.mgrid slice + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, + "slice start of jnp.mgrid"): + jax.jit(lambda a, b: bm.mgrid[a:b])(0, 2) + + def testOgrid(self): + # wrap indexer for appropriate dtype defaults. + np_ogrid = _indexer_with_default_outputs(np.ogrid) + + def assertListOfArraysEqual(xs, ys): + self.assertIsInstance(xs, list) + self.assertIsInstance(ys, list) + self.assertEqual(len(xs), len(ys)) + for x, y in zip(xs, ys): + self.assertArraysEqual(x, y) + + self.assertArraysEqual(np_ogrid[:5], bm.ogrid[:5]) + self.assertArraysEqual(np_ogrid[:5], jax.jit(lambda: bm.ogrid[:5])()) + self.assertArraysEqual(np_ogrid[1:7:2], bm.ogrid[1:7:2]) + # List of arrays + assertListOfArraysEqual(np_ogrid[:5, ], bm.ogrid[:5, ]) + assertListOfArraysEqual(np_ogrid[0:5, 1:3], bm.ogrid[0:5, 1:3]) + assertListOfArraysEqual(np_ogrid[1:3:2, 2:9:3], bm.ogrid[1:3:2, 2:9:3]) + assertListOfArraysEqual(np_ogrid[:5, :9, :11], bm.ogrid[:5, :9, :11]) + # Corner cases + self.assertArraysEqual(np_ogrid[:], bm.ogrid[:]) + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_ogrid[-1:1:5j], + bm.ogrid[-1:1:5j], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_ogrid[0:3.5:0.3], + bm.ogrid[0:3.5:0.3], + atol=atol, + rtol=rtol) + self.assertAllClose(np_ogrid[1.2:4.8:0.24], + bm.ogrid[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + # abstract tracer value for ogrid slice + with 
self.assertRaisesRegex(jax.core.ConcretizationTypeError, + "slice start of jnp.ogrid"): + jax.jit(lambda a, b: bm.ogrid[a:b])(0, 2) + + def testR_(self): + a = np.arange(6).reshape((2, 3)) + self.assertArraysEqual(np.r_[np.array([1, 2, 3]), 0, 0, np.array([4, 5, 6])], + bm.r_[np.array([1, 2, 3]), 0, 0, np.array([4, 5, 6])]) + self.assertArraysEqual(np.r_['-1', a, a], bm.r_['-1', a, a]) + + # wrap indexer for appropriate dtype defaults. + np_r_ = _indexer_with_default_outputs(np.r_) + self.assertArraysEqual(np_r_['0,2', [1, 2, 3], [4, 5, 6]], bm.r_['0,2', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_r_['0,2,0', [1, 2, 3], [4, 5, 6]], bm.r_['0,2,0', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_r_['1,2,0', [1, 2, 3], [4, 5, 6]], bm.r_['1,2,0', [1, 2, 3], [4, 5, 6]]) + # negative 1d axis start + self.assertArraysEqual(np_r_['0,4,-1', [1, 2, 3], [4, 5, 6]], bm.r_['0,4,-1', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_r_['0,4,-2', [1, 2, 3], [4, 5, 6]], bm.r_['0,4,-2', [1, 2, 3], [4, 5, 6]]) + + # matrix directives + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + self.assertArraysEqual(np_r_['r', [1, 2, 3], [4, 5, 6]], bm.r_['r', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_r_['c', [1, 2, 3], [4, 5, 6]], bm.r_['c', [1, 2, 3], [4, 5, 6]]) + + # bad directive + with self.assertRaisesRegex(ValueError, "could not understand directive.*"): + bm.r_["asdfgh", [1, 2, 3]] + # abstract tracer value for r_ slice + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, + "slice start of jnp.r_"): + jax.jit(lambda a, b: bm.r_[a:b])(0, 2) + + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_r_[-1:1:6j], + bm.r_[-1:1:6j], + atol=atol, + rtol=rtol) + self.assertAllClose(np_r_[-1:1:6j, [0] * 3, 5, 6], + bm.r_[-1:1:6j, [0] * 3, 5, 6], + atol=atol, + rtol=rtol) + # Non-integer steps + self.assertAllClose(np_r_[1.2:4.8:0.24], + bm.r_[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + 
+ def testC_(self): + a = np.arange(6).reshape((2, 3)) + self.assertArraysEqual(np.c_[np.array([1, 2, 3]), np.array([4, 5, 6])], + bm.c_[np.array([1, 2, 3]), np.array([4, 5, 6])]) + self.assertArraysEqual(np.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])], + bm.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])]) + self.assertArraysEqual(np.c_['-1', a, a], bm.c_['-1', a, a]) + + # wrap indexer for appropriate dtype defaults. + np_c_ = _indexer_with_default_outputs(np.c_) + self.assertArraysEqual(np_c_['0,2', [1, 2, 3], [4, 5, 6]], bm.c_['0,2', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_c_['0,2,0', [1, 2, 3], [4, 5, 6]], bm.c_['0,2,0', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_c_['1,2,0', [1, 2, 3], [4, 5, 6]], bm.c_['1,2,0', [1, 2, 3], [4, 5, 6]]) + # negative 1d axis start + self.assertArraysEqual(np_c_['0,4,-1', [1, 2, 3], [4, 5, 6]], bm.c_['0,4,-1', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_c_['0,4,-2', [1, 2, 3], [4, 5, 6]], bm.c_['0,4,-2', [1, 2, 3], [4, 5, 6]]) + # matrix directives, avoid numpy deprecation warning + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=PendingDeprecationWarning) + self.assertArraysEqual(np_c_['r', [1, 2, 3], [4, 5, 6]], bm.c_['r', [1, 2, 3], [4, 5, 6]]) + self.assertArraysEqual(np_c_['c', [1, 2, 3], [4, 5, 6]], bm.c_['c', [1, 2, 3], [4, 5, 6]]) + + # bad directive + with self.assertRaisesRegex(ValueError, "could not understand directive.*"): + bm.c_["asdfgh", [1, 2, 3]] + # abstract tracer value for c_ slice + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, + "slice start of jnp.c_"): + jax.jit(lambda a, b: bm.c_[a:b])(0, 2) + + # Complex number steps + atol = 1e-6 + rtol = 1e-6 + self.assertAllClose(np_c_[-1:1:6j], + bm.c_[-1:1:6j], + atol=atol, + rtol=rtol) + + # Non-integer steps + self.assertAllClose(np_c_[1.2:4.8:0.24], + bm.c_[1.2:4.8:0.24], + atol=atol, + rtol=rtol) + + def testS_(self): + self.assertEqual(np.s_[1:2:20], bm.s_[1:2:20]) + + def 
testIndex_exp(self): + self.assertEqual(np.index_exp[5:3:2j], bm.index_exp[5:3:2j]) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": f"_start_shape={start_shape}_stop_shape={stop_shape}" + f"_num={num}_endpoint={endpoint}_retstep={retstep}" + f"_dtype={dtype.__name__ if dtype else 'None'}", + "start_shape": start_shape, "stop_shape": stop_shape, + "num": num, "endpoint": endpoint, "retstep": retstep, + "dtype": dtype} + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + for retstep in [True, False] + # floating-point compute between jitted platforms and non-jit + rounding + # cause unavoidable variation in integer truncation for some inputs, so + # we currently only test inexact 'dtype' arguments. + for dtype in inexact_dtypes + [None, ])) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testLinspace(self, start_shape, stop_shape, num, endpoint, retstep, dtype): + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = jtu.tolerance(dtype if dtype else np.float32) * 10 + args_maker = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype]) + start, stop = args_maker() + ndim = len(np.shape(start + stop)) + for axis in range(-ndim, ndim): + bm_op = lambda start, stop: bm.linspace( + start, stop, num, + endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) + # NumPy 1.20.0 changed the semantics of linspace to floor for integer + # dtypes. 
+ if numpy_version >= (1, 20) or not np.issubdtype(dtype, np.integer): + np_op = lambda start, stop: np.linspace( + start, stop, num, + endpoint=endpoint, retstep=retstep, dtype=dtype, axis=axis) + else: + def np_op(start, stop): + out = np.linspace(start, stop, num, endpoint=endpoint, + retstep=retstep, axis=axis) + if retstep: + return np.floor(out[0]).astype(dtype), out[1] + else: + return np.floor(out).astype(dtype) + + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, + check_dtypes=False, tol=tol) + self._CompileAndCheck(bm_func(bm_op), args_maker, + check_dtypes=False, atol=tol, rtol=tol) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype} + for dtype in number_dtypes)) + def testLinspaceEndpoints(self, dtype): + """Regression test for Issue #3014.""" + rng = jtu.rand_default(self.rng()) + endpoints = rng((2,), dtype) + out = bm.linspace(*endpoints, 10, dtype=dtype) + self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" + "_base={}_dtype={}").format( + start_shape, stop_shape, num, endpoint, base, + dtype.__name__ if dtype else "None"), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, "endpoint": endpoint, "brainpy_object": base, + "dtype": dtype} + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + for base in [10.0, 2, np.e] + # skip 16-bit floats due to insufficient precision for the test. + for dtype in jtu.dtypes.inexact + [None, ])) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testLogspace(self, start_shape, stop_shape, num, + endpoint, base, dtype): + if (dtype in int_dtypes and + jtu.device_under_test() in ("gpu", "tpu") and + not config.x64_enabled): + raise unittest.SkipTest("GPUx32 truncated exponentiation" + " doesn't exactly match other platforms.") + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3, np.complex128: 1e-6} + args_maker = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype]) + start, stop = args_maker() + ndim = len(np.shape(start + stop)) + for axis in range(-ndim, ndim): + bm_op = lambda start, stop: bm.logspace( + start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) + + @jtu.ignore_warning(category=RuntimeWarning, + message="overflow encountered in power") + def np_op(start, stop): + return np.logspace(start, stop, num, endpoint=endpoint, + base=base, dtype=dtype, axis=axis) + + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, + check_dtypes=False, tol=tol) + if dtype in (inexact_dtypes + [None, ]): + # Why do compiled and op-by-op float16 np.power numbers differ + # slightly more than expected? 
+ atol = {np.float16: 1e-2} + self._CompileAndCheck(bm_func(bm_op), args_maker, + check_dtypes=False, atol=atol, rtol=tol) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" + "_dtype={}_axis={}").format( + start_shape, stop_shape, num, endpoint, + dtype.__name__ if dtype else "None", axis), + "start_shape": start_shape, + "stop_shape": stop_shape, + "num": num, "endpoint": endpoint, + "dtype": dtype, "axis": axis} + for start_shape in [(), (2,), (2, 2)] + for stop_shape in [(), (2,), (2, 2)] + for num in [0, 1, 2, 5, 20] + for endpoint in [True, False] + # NB: numpy's geomspace gives nonsense results on integer types + for dtype in inexact_dtypes + [None, ] + for axis in range(-max(len(start_shape), len(stop_shape)), + max(len(start_shape), len(stop_shape))))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testGeomspace(self, start_shape, stop_shape, num, + endpoint, dtype, axis): + rng = jtu.rand_default(self.rng()) + # relax default tolerances slightly + tol = {np.float16: 4e-3, np.float32: 2e-3, np.float64: 1e-14, + np.complex128: 1e-14} + + def args_maker(): + """Test the set of inputs np.geomspace is well-defined on.""" + start, stop = self._GetArgsMaker(rng, + [start_shape, stop_shape], + [dtype, dtype])() + # np.geomspace can't handle differently ranked tensors + # w. negative numbers! 
+ start, stop = jnp.broadcast_arrays(start, stop) + if dtype in complex_dtypes: + return start, stop + # to avoid NaNs, non-complex start and stop cannot + # differ in sign, elementwise + start = start * jnp.sign(start) * jnp.sign(stop) + return start, stop + + start, stop = args_maker() + + def bm_op(start, stop): + return bm.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, + axis=axis) + def np_op(start, stop): - out = np.linspace(start, stop, num, endpoint=endpoint, - retstep=retstep, axis=axis) - if retstep: - return np.floor(out[0]).astype(dtype), out[1] - else: - return np.floor(out).astype(dtype) - - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, - check_dtypes=False, tol=tol) - self._CompileAndCheck(bm_func(bm_op), args_maker, - check_dtypes=False, atol=tol, rtol=tol) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype} - for dtype in number_dtypes)) - def testLinspaceEndpoints(self, dtype): - """Regression test for Issue #3014.""" - rng = jtu.rand_default(self.rng()) - endpoints = rng((2,), dtype) - out = bm.linspace(*endpoints, 10, dtype=dtype) - self.assertAllClose(out[np.array([0, -1])], endpoints, rtol=0, atol=0) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_base={}_dtype={}").format( - start_shape, stop_shape, num, endpoint, base, - dtype.__name__ if dtype else "None"), - "start_shape": start_shape, - "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, "brainpy_object": base, - "dtype": dtype} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - for base in [10.0, 2, np.e] - # skip 16-bit floats due to insufficient precision for the test. 
- for dtype in jtu.dtypes.inexact + [None, ])) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testLogspace(self, start_shape, stop_shape, num, - endpoint, base, dtype): - if (dtype in int_dtypes and - jtu.device_under_test() in ("gpu", "tpu") and - not config.x64_enabled): - raise unittest.SkipTest("GPUx32 truncated exponentiation" - " doesn't exactly match other platforms.") - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = {np.float32: 1e-2, np.float64: 1e-6, np.complex64: 1e-3, np.complex128: 1e-6} - args_maker = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype]) - start, stop = args_maker() - ndim = len(np.shape(start + stop)) - for axis in range(-ndim, ndim): - bm_op = lambda start, stop: bm.logspace( - start, stop, num, endpoint=endpoint, base=base, dtype=dtype, axis=axis) - - @jtu.ignore_warning(category=RuntimeWarning, - message="overflow encountered in power") - def np_op(start, stop): - return np.logspace(start, stop, num, endpoint=endpoint, - base=base, dtype=dtype, axis=axis) - - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None, ]): - # Why do compiled and op-by-op float16 np.power numbers differ - # slightly more than expected? 
- atol = {np.float16: 1e-2} - self._CompileAndCheck(bm_func(bm_op), args_maker, - check_dtypes=False, atol=atol, rtol=tol) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}" - "_dtype={}_axis={}").format( - start_shape, stop_shape, num, endpoint, - dtype.__name__ if dtype else "None", axis), - "start_shape": start_shape, - "stop_shape": stop_shape, - "num": num, "endpoint": endpoint, - "dtype": dtype, "axis": axis} - for start_shape in [(), (2,), (2, 2)] - for stop_shape in [(), (2,), (2, 2)] - for num in [0, 1, 2, 5, 20] - for endpoint in [True, False] - # NB: numpy's geomspace gives nonsense results on integer types - for dtype in inexact_dtypes + [None,] - for axis in range(-max(len(start_shape), len(stop_shape)), - max(len(start_shape), len(stop_shape))))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testGeomspace(self, start_shape, stop_shape, num, - endpoint, dtype, axis): - rng = jtu.rand_default(self.rng()) - # relax default tolerances slightly - tol = {np.float16: 4e-3, np.float32: 2e-3, np.float64: 1e-14, - np.complex128: 1e-14} - def args_maker(): - """Test the set of inputs np.geomspace is well-defined on.""" - start, stop = self._GetArgsMaker(rng, - [start_shape, stop_shape], - [dtype, dtype])() - # np.geomspace can't handle differently ranked tensors - # w. negative numbers! 
- start, stop = jnp.broadcast_arrays(start, stop) - if dtype in complex_dtypes: - return start, stop - # to avoid NaNs, non-complex start and stop cannot - # differ in sign, elementwise - start = start * jnp.sign(start) * jnp.sign(stop) - return start, stop - start, stop = args_maker() - def bm_op(start, stop): - return bm.geomspace(start, stop, num, endpoint=endpoint, dtype=dtype, - axis=axis) - def np_op(start, stop): - start = start.astype(np.float32) if dtype == jnp.bfloat16 else start - stop = stop.astype(np.float32) if dtype == jnp.bfloat16 else stop - return np.geomspace( - start, stop, num, endpoint=endpoint, - dtype=dtype if dtype != jnp.bfloat16 else np.float32, - axis=axis).astype(dtype) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, - check_dtypes=False, tol=tol) - if dtype in (inexact_dtypes + [None,]): - self._CompileAndCheck(bm_func(bm_op), args_maker, - check_dtypes=False, atol=tol, rtol=tol) - - def testDisableNumpyRankPromotionBroadcasting(self): - try: - prev_flag = config._read('jax_numpy_rank_promotion') - FLAGS.jax_numpy_rank_promotion = "allow" - bm.ones(2) + bm.ones((1, 2)) # works just fine - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag + start = start.astype(np.float32) if dtype == jnp.bfloat16 else start + stop = stop.astype(np.float32) if dtype == jnp.bfloat16 else stop + return np.geomspace( + start, stop, num, endpoint=endpoint, + dtype=dtype if dtype != jnp.bfloat16 else np.float32, + axis=axis).astype(dtype) + + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, + check_dtypes=False, tol=tol) + if dtype in (inexact_dtypes + [None, ]): + self._CompileAndCheck(bm_func(bm_op), args_maker, + check_dtypes=False, atol=tol, rtol=tol) + + def testDisableNumpyRankPromotionBroadcasting(self): + try: + prev_flag = config._read('jax_numpy_rank_promotion') + FLAGS.jax_numpy_rank_promotion = "allow" + bm.ones(2) + bm.ones((1, 2)) # works just fine + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + try: + 
prev_flag = config._read('jax_numpy_rank_promotion') + FLAGS.jax_numpy_rank_promotion = "raise" + self.assertRaises(ValueError, lambda: bm.ones(2) + bm.ones((1, 2))) + bm.ones(2) + 3 # don't want to raise for scalars + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + try: + prev_flag = config._read('jax_numpy_rank_promotion') + FLAGS.jax_numpy_rank_promotion = "warn" + self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " + r"shapes \(2,\) \(1, 2\).*", lambda: bm.ones(2) + bm.ones((1, 2))) + bm.ones(2) + 3 # don't want to warn for scalars + finally: + FLAGS.jax_numpy_rank_promotion = prev_flag + + def testStackArrayArgument(self): + # tests https://github.com/google/jax/issues/1271 + @jax.jit + def foo(x): + return bm.stack(x) + + foo(np.zeros(2)) # doesn't crash + + @jax.jit + def foo(x): + return bm.concatenate(x) + + foo(np.zeros((2, 2))) # doesn't crash + + def testReluGradientConstants(self): + # This is a regression test that verifies that constants associated with the + # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the + # outermost jaxpr. This was producing some large materialized constants for + # every relu activation in a model. 
+ def body(i, xy): + x, y = xy + y = y + jax.grad(lambda z: bm.sum(bm.maximum(z, 0.)))(x) + return x, y + + f = lambda y: lax.fori_loop(0, 5, body, (y, y)) + jaxpr = jax.make_jaxpr(f)(np.zeros((3, 4), np.float32)) + self.assertFalse( + any(np.array_equal(x, np.full((3, 4), 2., dtype=np.float32)) + for x in jaxpr.consts)) + + @parameterized.named_parameters( + {"testcase_name": "_from={}_to={}".format(from_shape, to_shape), + "from_shape": from_shape, "to_shape": to_shape} + for from_shape, to_shape in [ + [(1, 3), (4, 3)], + [(3,), (2, 1, 3)], + [(3,), (3, 3)], + [(1,), (3,)], + [(1,), 3], + ]) + def testBroadcastTo(self, from_shape, to_shape): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32]) + np_op = lambda x: np.broadcast_to(x, to_shape) + bm_op = lambda x: bm.broadcast_to(x, to_shape) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) + self._CompileAndCheck(bm_func(bm_op), args_maker) + + @parameterized.named_parameters( + {"testcase_name": f"_{shapes}", "shapes": shapes, "broadcasted_shape": broadcasted_shape} + for shapes, broadcasted_shape in [ + [[], ()], + [[()], ()], + [[(1, 3), (4, 3)], (4, 3)], + [[(3,), (2, 1, 3)], (2, 1, 3)], + [[(3,), (3, 3)], (3, 3)], + [[(1,), (3,)], (3,)], + [[(1,), 3], (3,)], + [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)], + [[[1], [0, 1]], (0, 1)], + [[(1,), np.array([0, 1])], (0, 1)], + ]) + def testBroadcastShapes(self, shapes, broadcasted_shape): + # Test against np.broadcast_shapes once numpy 1.20 is minimum required version + np.testing.assert_equal(bm.broadcast_shapes(*shapes), broadcasted_shape) + + def testBroadcastToOnScalar(self): + self.assertIsInstance(bm.broadcast_to(10.0, ()), bm.ndarray) + self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray) + + def testPrecision(self): + + ones_1d = np.ones((2,)) + ones_2d = np.ones((2, 2)) + ones_3d = np.ones((2, 2, 2)) + HIGHEST = lax.Precision.HIGHEST + + jtu.assert_dot_precision(None, bm.dot, 
ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.dot, precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.dot, precision=HIGHEST), + ones_3d, ones_3d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.matmul, precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.vdot, precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.tensordot, axes=2, precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.tensordot, axes=(0, 0), precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.tensordot, axes=((0,), (0,)), precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.einsum, 'i,i', precision=HIGHEST), + ones_1d, ones_1d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.einsum, 'ij,ij', precision=HIGHEST), + ones_2d, ones_2d) + jtu.assert_dot_precision( + HIGHEST, + partial(bm.inner, precision=HIGHEST), + ones_1d, ones_1d) + + @parameterized.named_parameters( + jtu.cases_from_list( + {"testcase_name": "_shape={}_varargs={} axis={}_dtype={}".format( + shape, varargs, axis, dtype), + "shape": shape, "varargs": varargs, "axis": axis, "dtype": dtype} + for shape in [(10,), (10, 15), (10, 15, 20)] + for _num_axes in range(len(shape)) + for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes) + for axis in itertools.combinations(range(len(shape)), _num_axes) + for dtype in inexact_dtypes)) + def testGradient(self, shape, varargs, axis, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = self._GetArgsMaker(rng, [shape], [dtype]) + bm_fun = lambda y: bm.gradient(y, *varargs, axis=axis) + np_fun = lambda y: np.gradient(y, *varargs, axis=axis) + self._CheckAgainstNumpy( + np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + def testTraceMethod(self): + x 
= self.rng().randn(3, 4).astype(bm.float_) + self.assertAllClose(x.trace(), bm.array(x).value.trace()) + self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) + + def testIntegerPowersArePrecise(self): + # See https://github.com/google/jax/pull/3036 + # Checks if the squares of float32 integers have no numerical errors. + # It should be satisfied with all integers less than sqrt(2**24). + x = bm.arange(-2 ** 12, 2 ** 12, dtype=bm.int32) + np.testing.assert_array_equal(bm.square(x.astype(bm.float32)).value, x * x) + np.testing.assert_array_equal(x.astype(bm.float32) ** 2, x * x) + + # Similarly for cubes. + x = bm.arange(-2 ** 8, 2 ** 8, dtype=bm.int32) + np.testing.assert_array_equal(x.astype(bm.float32) ** 3, x * x * x) + + x = np.arange(10, dtype=np.float32) + for i in range(10): + self.assertAllClose(x.astype(bm.float32) ** i, x ** i, + check_dtypes=False) + + def testToBytes(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + for order in ['C', 'F']: + self.assertEqual(bm.asarray(v).tobytes(order), v.tobytes(order)) + + def testToList(self): + v = np.arange(12, dtype=np.int32).reshape(3, 4) + self.assertEqual(bm.asarray(v).tolist(), v.tolist()) + + def testReductionWithRepeatedAxisError(self): + with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"): + bm.sum(bm.arange(3), (0, 0)) + + def testArangeConcretizationError(self): + msg = r"It arose in jax.numpy.arange argument `{}`".format + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')): + jax.jit(bm.arange)(3) + + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('start')): + jax.jit(lambda start: bm.arange(start, 3))(0) + + with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')): + jax.jit(lambda stop: bm.arange(0, stop))(3) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": str(dtype), "dtype": dtype} + for dtype in [None] + float_dtypes)) + def testArange64Bit(self, dtype): + # Test 
that jnp.arange uses 64-bit arithmetic to define its range, even if the + # output has another dtype. The issue here is that if python scalar inputs to + # jnp.arange are cast to float32 before the range is computed, it changes the + # number of elements output by the range. It's unclear whether this was deliberate + # behavior in the initial implementation, but it's behavior that downstream users + # have come to rely on. + args = (1.2, 4.8, 0.24) + + # Ensure that this test case leads to differing lengths if cast to float32. + self.assertLen(np.arange(*args), 15) + self.assertLen(np.arange(*map(np.float32, args)), 16) + + bm_fun = lambda: bm.arange(*args, dtype=dtype) + np_fun = jtu.with_jax_dtype_defaults(lambda: np.arange(*args, dtype=dtype), dtype is None) + args_maker = lambda: [] + self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) + self._CompileAndCheck(bm_func(bm_fun), args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes), + "shapes": shapes, "dtypes": dtypes} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes)))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testLogaddexpComplex(self, shapes, dtypes): + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + def np_op(x1, x2): + return np.log(np.exp(x1) + np.exp(x2)) + + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) + if jtu.device_under_test() == 'tpu': + tol = {np.complex64: 1e-3, np.complex128: 1e-10} + else: + tol = {np.complex64: 1e-5, np.complex128: 1e-14} + self._CheckAgainstNumpy(_promote_like_jnp(np_op), bm_func(bm.logaddexp), args_maker, tol=tol) + self._CompileAndCheck(bm_func(bm.logaddexp), args_maker, rtol=tol, atol=tol) - try: - prev_flag = config._read('jax_numpy_rank_promotion') - FLAGS.jax_numpy_rank_promotion = "raise" - self.assertRaises(ValueError, lambda: bm.ones(2) + bm.ones((1, 2))) - bm.ones(2) + 3 # don't want to raise for scalars - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes), + "shapes": shapes, "dtypes": dtypes} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(all_shapes, 2)) + for dtypes in itertools.product( + *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes)))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testLogaddexp2Complex(self, shapes, dtypes): + @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") + def np_op(x1, x2): + return np.log2(np.exp2(x1) + np.exp2(x2)) + + rng = jtu.rand_some_nan(self.rng()) + args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) + if jtu.device_under_test() == 'tpu': + tol = {np.complex64: 1e-3, np.complex128: 1e-10} + else: + tol = {np.complex64: 1e-5, np.complex128: 1e-14} + self._CheckAgainstNumpy(_promote_like_jnp(np_op), bm_func(bm.logaddexp2), args_maker, tol=tol) + self._CompileAndCheck(bm_func(bm.logaddexp2), args_maker, rtol=tol, atol=tol) - try: - prev_flag = config._read('jax_numpy_rank_promotion') - FLAGS.jax_numpy_rank_promotion = "warn" - self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " - r"shapes \(2,\) \(1, 2\).*", lambda: bm.ones(2) + bm.ones((1, 2))) - bm.ones(2) + 3 # don't want to warn for scalars - finally: - FLAGS.jax_numpy_rank_promotion = prev_flag - - def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 - @jax.jit - def foo(x): - return bm.stack(x) - foo(np.zeros(2)) # doesn't crash - - @jax.jit - def foo(x): - return bm.concatenate(x) - foo(np.zeros((2, 2))) # doesn't crash - - def testReluGradientConstants(self): - # This is a regression test that verifies that constants associated with the - # gradient of np.maximum (from lax._balanced_eq) aren't hoisted into the - # outermost jaxpr. This was producing some large materialized constants for - # every relu activation in a model. 
- def body(i, xy): - x, y = xy - y = y + jax.grad(lambda z: bm.sum(bm.maximum(z, 0.)))(x) - return x, y - - f = lambda y: lax.fori_loop(0, 5, body, (y, y)) - jaxpr = jax.make_jaxpr(f)(np.zeros((3, 4), np.float32)) - self.assertFalse( - any(np.array_equal(x, np.full((3, 4), 2., dtype=np.float32)) - for x in jaxpr.consts)) - - @parameterized.named_parameters( - {"testcase_name": "_from={}_to={}".format(from_shape, to_shape), - "from_shape": from_shape, "to_shape": to_shape} - for from_shape, to_shape in [ - [(1, 3), (4, 3)], - [(3,), (2, 1, 3)], - [(3,), (3, 3)], - [(1,), (3,)], - [(1,), 3], - ]) - def testBroadcastTo(self, from_shape, to_shape): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [from_shape], [np.float32]) - np_op = lambda x: np.broadcast_to(x, to_shape) - bm_op = lambda x: bm.broadcast_to(x, to_shape) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker) - self._CompileAndCheck(bm_func(bm_op), args_maker) - - @parameterized.named_parameters( - {"testcase_name": f"_{shapes}", "shapes": shapes, "broadcasted_shape": broadcasted_shape} - for shapes, broadcasted_shape in [ - [[], ()], - [[()], ()], - [[(1, 3), (4, 3)], (4, 3)], - [[(3,), (2, 1, 3)], (2, 1, 3)], - [[(3,), (3, 3)], (3, 3)], - [[(1,), (3,)], (3,)], - [[(1,), 3], (3,)], - [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)], - [[[1], [0, 1]], (0, 1)], - [[(1,), np.array([0, 1])], (0, 1)], - ]) - def testBroadcastShapes(self, shapes, broadcasted_shape): - # Test against np.broadcast_shapes once numpy 1.20 is minimum required version - np.testing.assert_equal(bm.broadcast_shapes(*shapes), broadcasted_shape) - - def testBroadcastToOnScalar(self): - self.assertIsInstance(bm.broadcast_to(10.0, ()), bm.ndarray) - self.assertIsInstance(np.broadcast_to(10.0, ()), np.ndarray) - - def testPrecision(self): - - ones_1d = np.ones((2,)) - ones_2d = np.ones((2, 2)) - ones_3d = np.ones((2, 2, 2)) - HIGHEST = lax.Precision.HIGHEST - - jtu.assert_dot_precision(None, bm.dot, 
ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.dot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.dot, precision=HIGHEST), - ones_3d, ones_3d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.matmul, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.vdot, precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.tensordot, axes=2, precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.tensordot, axes=(0, 0), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.tensordot, axes=((0,), (0,)), precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.einsum, 'i,i', precision=HIGHEST), - ones_1d, ones_1d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.einsum, 'ij,ij', precision=HIGHEST), - ones_2d, ones_2d) - jtu.assert_dot_precision( - HIGHEST, - partial(bm.inner, precision=HIGHEST), - ones_1d, ones_1d) - - @parameterized.named_parameters( - jtu.cases_from_list( - {"testcase_name": "_shape={}_varargs={} axis={}_dtype={}".format( - shape, varargs, axis, dtype), - "shape": shape, "varargs": varargs, "axis": axis, "dtype": dtype} - for shape in [(10,), (10, 15), (10, 15, 20)] - for _num_axes in range(len(shape)) - for varargs in itertools.combinations(range(1, len(shape) + 1), _num_axes) - for axis in itertools.combinations(range(len(shape)), _num_axes) - for dtype in inexact_dtypes)) - def testGradient(self, shape, varargs, axis, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = self._GetArgsMaker(rng, [shape], [dtype]) - bm_fun = lambda y: bm.gradient(y, *varargs, axis=axis) - np_fun = lambda y: np.gradient(y, *varargs, axis=axis) - self._CheckAgainstNumpy( - np_fun, bm_func(bm_fun), args_maker, check_dtypes=False) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - def testTraceMethod(self): - x 
= self.rng().randn(3, 4).astype(bm.float_) - self.assertAllClose(x.trace(), bm.array(x).value.trace()) - self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) - - def testIntegerPowersArePrecise(self): - # See https://github.com/google/jax/pull/3036 - # Checks if the squares of float32 integers have no numerical errors. - # It should be satisfied with all integers less than sqrt(2**24). - x = bm.arange(-2**12, 2**12, dtype=bm.int32) - np.testing.assert_array_equal(bm.square(x.astype(bm.float32)).value, x * x) - np.testing.assert_array_equal(x.astype(bm.float32) ** 2, x * x) - - # Similarly for cubes. - x = bm.arange(-2**8, 2**8, dtype=bm.int32) - np.testing.assert_array_equal(x.astype(bm.float32) ** 3, x * x * x) - - x = np.arange(10, dtype=np.float32) - for i in range(10): - self.assertAllClose(x.astype(bm.float32) ** i, x ** i, - check_dtypes=False) - - def testToBytes(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - for order in ['C', 'F']: - self.assertEqual(bm.asarray(v).tobytes(order), v.tobytes(order)) - - def testToList(self): - v = np.arange(12, dtype=np.int32).reshape(3, 4) - self.assertEqual(bm.asarray(v).tolist(), v.tolist()) - - def testReductionWithRepeatedAxisError(self): - with self.assertRaisesRegex(ValueError, r"duplicate value in 'axis': \(0, 0\)"): - bm.sum(bm.arange(3), (0, 0)) - - def testArangeConcretizationError(self): - msg = r"It arose in jax.numpy.arange argument `{}`".format - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')): - jax.jit(bm.arange)(3) - - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('start')): - jax.jit(lambda start: bm.arange(start, 3))(0) - - with self.assertRaisesRegex(jax.core.ConcretizationTypeError, msg('stop')): - jax.jit(lambda stop: bm.arange(0, stop))(3) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": str(dtype), "dtype": dtype} - for dtype in [None] + float_dtypes)) - def testArange64Bit(self, dtype): - # Test that 
jnp.arange uses 64-bit arithmetic to define its range, even if the - # output has another dtype. The issue here is that if python scalar inputs to - # jnp.arange are cast to float32 before the range is computed, it changes the - # number of elements output by the range. It's unclear whether this was deliberate - # behavior in the initial implementation, but it's behavior that downstream users - # have come to rely on. - args = (1.2, 4.8, 0.24) - - # Ensure that this test case leads to differing lengths if cast to float32. - self.assertLen(np.arange(*args), 15) - self.assertLen(np.arange(*map(np.float32, args)), 16) - - bm_fun = lambda: bm.arange(*args, dtype=dtype) - np_fun = jtu.with_jax_dtype_defaults(lambda: np.arange(*args, dtype=dtype), dtype is None) - args_maker = lambda: [] - self._CheckAgainstNumpy(np_fun, bm_func(bm_fun), args_maker) - self._CompileAndCheck(bm_func(bm_fun), args_maker) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes), - "shapes": shapes, "dtypes": dtypes} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes)))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testLogaddexpComplex(self, shapes, dtypes): - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - def np_op(x1, x2): - return np.log(np.exp(x1) + np.exp(x2)) - - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) - if jtu.device_under_test() == 'tpu': - tol = {np.complex64: 1e-3, np.complex128: 1e-10} - else: - tol = {np.complex64: 1e-5, np.complex128: 1e-14} - self._CheckAgainstNumpy(_promote_like_jnp(np_op), bm_func(bm.logaddexp), args_maker, tol=tol) - self._CompileAndCheck(bm_func(bm.logaddexp), args_maker, rtol=tol, atol=tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes), - "shapes": shapes, "dtypes": dtypes} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(all_shapes, 2)) - for dtypes in itertools.product( - *(_valid_dtypes_for_shape(s, complex_dtypes) for s in shapes)))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testLogaddexp2Complex(self, shapes, dtypes): - @jtu.ignore_warning(category=RuntimeWarning, message="invalid value.*") - def np_op(x1, x2): - return np.log2(np.exp2(x1) + np.exp2(x2)) - - rng = jtu.rand_some_nan(self.rng()) - args_maker = lambda: tuple(rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)) - if jtu.device_under_test() == 'tpu': - tol = {np.complex64: 1e-3, np.complex128: 1e-10} - else: - tol = {np.complex64: 1e-5, np.complex128: 1e-14} - self._CheckAgainstNumpy(_promote_like_jnp(np_op), bm_func(bm.logaddexp2), args_maker, tol=tol) - self._CompileAndCheck(bm_func(bm.logaddexp2), args_maker, rtol=tol, atol=tol) - - def testFromBuffer(self): - buf = b'\x01\x02\x03' - expected = np.frombuffer(buf, dtype='uint8') - actual = bm.frombuffer(buf, dtype='uint8') - self.assertArraysEqual(expected, actual) - - def testFromFunction(self): - def f(x, y, z): - return x + 2 * y + 3 * z - shape = (3, 4, 5) - expected = np.fromfunction(f, shape=shape) - actual = bm.fromfunction(f, shape=shape) - self.assertArraysEqual(expected, actual) - - def testFromString(self): - s = "1,2,3" - expected = np.fromstring(s, sep=',', dtype=int) - actual = bm.fromstring(s, sep=',', dtype=int) - self.assertArraysEqual(expected, actual) + def testFromBuffer(self): + buf = b'\x01\x02\x03' + expected = np.frombuffer(buf, dtype='uint8') + actual = bm.frombuffer(buf, dtype='uint8') + self.assertArraysEqual(expected, actual) + + def testFromFunction(self): + def f(x, y, z): + return x + 2 * y + 3 * z + + shape = (3, 4, 5) + expected = np.fromfunction(f, shape=shape) + actual = bm.fromfunction(f, shape=shape) + self.assertArraysEqual(expected, actual) + + def testFromString(self): + s = "1,2,3" + expected = np.fromstring(s, sep=',', dtype=int) + actual = bm.fromstring(s, sep=',', dtype=int) + self.assertArraysEqual(expected, actual) # Most grad tests are at the lax level (see lax_test.py), but we add some here @@ -5959,9 +6056,12 @@ def testFromString(self): GradTestSpec = 
collections.namedtuple( "GradTestSpec", ["op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"]) + + def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): - return GradTestSpec( - op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) + return GradTestSpec( + op, nargs, order, rng_factory, dtypes, name or op.__name__, tol) + GRAD_TEST_RECORDS = [ grad_test_spec(bm.arcsinh, nargs=1, order=2, @@ -5994,83 +6094,84 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None): GradSpecialValuesTestSpec(bm.sinc, [0.], 1), ] + @pytest.mark.skipif(True, reason="No longer need to test.") @jtu.with_config(jax_numpy_dtype_promotion='standard') class NumpyGradTests(jtu.JaxTestCase): - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix( - rec.name, shapes, itertools.repeat(dtype)), - "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, - "order": rec.order, "tol": rec.tol} - for shapes in itertools.combinations_with_replacement(nonempty_shapes, rec.nargs) - for dtype in rec.dtypes) - for rec in GRAD_TEST_RECORDS)) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): - rng = rng_factory(self.rng()) - tol = jtu.join_tolerance(tol, {np.float32: 1e-1, np.float64: 1e-3, - np.complex64: 1e-1, np.complex128: 1e-3}) - args = tuple(rng(shape, dtype) for shape in shapes) - check_grads(op, args, order, ["fwd", "rev"], tol, tol) - - @parameterized.named_parameters(itertools.chain.from_iterable( - jtu.cases_from_list( - {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value), - "op": rec.op, "special_value": special_value, "order": rec.order} - for special_value in rec.values) - for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS)) - def testOpGradSpecialValue(self, op, special_value, order): - check_grads(op, (special_value,), order, ["fwd", "rev"], - atol={np.float32: 3e-3}) - - def testSincGradArrayInput(self): - # tests for a bug almost introduced in #5077 - jax.grad(lambda x: bm.sinc(x).sum())(jnp.arange(10.)) # doesn't crash - - def testTakeAlongAxisIssue1521(self): - # https://github.com/google/jax/issues/1521 - idx = bm.repeat(jnp.arange(3), 10).reshape((30, 1)) - - def f(x): - y = x * jnp.arange(3.).reshape((1, 3)) - return bm.take_along_axis(y, idx, -1).sum() - - check_grads(f, (1.,), order=1) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)), - "shapes": shapes, "dtype": dtype} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(nonempty_shapes, 2)) - for dtype in (np.complex128, ))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
- def testGradLogaddexpComplex(self, shapes, dtype): - rng = jtu.rand_default(self.rng()) - args = tuple(rng(shape, dtype) for shape in shapes) - if jtu.device_under_test() == "tpu": - tol = 5e-2 - else: - tol = 3e-2 - check_grads(bm.logaddexp, args, 1, ["fwd", "rev"], tol, tol) - - @parameterized.named_parameters(jtu.cases_from_list( - {"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)), - "shapes": shapes, "dtype": dtype} - for shapes in filter( - _shapes_are_broadcast_compatible, - itertools.combinations_with_replacement(nonempty_shapes, 2)) - for dtype in (np.complex128, ))) - @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def testGradLogaddexp2Complex(self, shapes, dtype): - rng = jtu.rand_default(self.rng()) - args = tuple(rng(shape, dtype) for shape in shapes) - if jtu.device_under_test() == "tpu": - tol = 5e-2 - else: - tol = 3e-2 - check_grads(bm.logaddexp2, args, 1, ["fwd", "rev"], tol, tol) + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix( + rec.name, shapes, itertools.repeat(dtype)), + "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype, + "order": rec.order, "tol": rec.tol} + for shapes in itertools.combinations_with_replacement(nonempty_shapes, rec.nargs) + for dtype in rec.dtypes) + for rec in GRAD_TEST_RECORDS)) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol): + rng = rng_factory(self.rng()) + tol = jtu.join_tolerance(tol, {np.float32: 1e-1, np.float64: 1e-3, + np.complex64: 1e-1, np.complex128: 1e-3}) + args = tuple(rng(shape, dtype) for shape in shapes) + check_grads(op, args, order, ["fwd", "rev"], tol, tol) + + @parameterized.named_parameters(itertools.chain.from_iterable( + jtu.cases_from_list( + {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value), + "op": rec.op, "special_value": special_value, "order": rec.order} + for special_value in rec.values) + for rec in GRAD_SPECIAL_VALUE_TEST_RECORDS)) + def testOpGradSpecialValue(self, op, special_value, order): + check_grads(op, (special_value,), order, ["fwd", "rev"], + atol={np.float32: 3e-3}) + + def testSincGradArrayInput(self): + # tests for a bug almost introduced in #5077 + jax.grad(lambda x: bm.sinc(x).sum())(jnp.arange(10.)) # doesn't crash + + def testTakeAlongAxisIssue1521(self): + # https://github.com/google/jax/issues/1521 + idx = bm.repeat(jnp.arange(3), 10).reshape((30, 1)) + + def f(x): + y = x * jnp.arange(3.).reshape((1, 3)) + return bm.take_along_axis(y, idx, -1).sum() + + check_grads(f, (1.,), order=1) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)), + "shapes": shapes, "dtype": dtype} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(nonempty_shapes, 2)) + for dtype in (np.complex128,))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def testGradLogaddexpComplex(self, shapes, dtype): + rng = jtu.rand_default(self.rng()) + args = tuple(rng(shape, dtype) for shape in shapes) + if jtu.device_under_test() == "tpu": + tol = 5e-2 + else: + tol = 3e-2 + check_grads(bm.logaddexp, args, 1, ["fwd", "rev"], tol, tol) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": jtu.format_test_name_suffix("", shapes, itertools.repeat(dtype)), + "shapes": shapes, "dtype": dtype} + for shapes in filter( + _shapes_are_broadcast_compatible, + itertools.combinations_with_replacement(nonempty_shapes, 2)) + for dtype in (np.complex128,))) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testGradLogaddexp2Complex(self, shapes, dtype): + rng = jtu.rand_default(self.rng()) + args = tuple(rng(shape, dtype) for shape in shapes) + if jtu.device_under_test() == "tpu": + tol = 5e-2 + else: + tol = 3e-2 + check_grads(bm.logaddexp2, args, 1, ["fwd", "rev"], tol, tol) _available_numpy_dtypes: List[str] = [dtype.__name__ for dtype in jtu.dtypes.all @@ -6078,59 +6179,60 @@ def testGradLogaddexp2Complex(self, shapes, dtype): def _all_numpy_ufuncs() -> Iterator[str]: - """Generate the names of all ufuncs in the top-level numpy namespace.""" - for name in dir(np): - f = getattr(np, name) - if isinstance(f, np.ufunc): - yield name + """Generate the names of all ufuncs in the top-level numpy namespace.""" + for name in dir(np): + f = getattr(np, name) + if isinstance(f, np.ufunc): + yield name def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]: - """Generate valid dtypes of inputs to the given numpy ufunc.""" - func = getattr(np, name) - for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): - args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) - try: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning) - _ = func(*args) - except TypeError: - pass - else: - yield 
arg_dtypes + """Generate valid dtypes of inputs to the given numpy ufunc.""" + func = getattr(np, name) + for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): + args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "divide by zero", RuntimeWarning) + _ = func(*args) + except TypeError: + pass + else: + yield arg_dtypes + @pytest.mark.skipif(True, reason="No longer need to test.") @jtu.with_config(jax_numpy_dtype_promotion='standard') class NumpyUfuncTests(jtu.JaxTestCase): - @parameterized.named_parameters( - {"testcase_name": f"_{name}_{','.join(arg_dtypes)}", - "name": name, "arg_dtypes": arg_dtypes} - for name in _all_numpy_ufuncs() - for arg_dtypes in jtu.cases_from_list(_dtypes_for_ufunc(name))) - def testUfuncInputTypes(self, name, arg_dtypes): - if (name in ['divmod', 'floor_divide', 'fmod', 'gcd', 'left_shift', 'mod', - 'power', 'remainder', 'right_shift', 'rint', 'square'] - and 'bool_' in arg_dtypes): - self.skipTest(f"jax.numpy does not support {name}{tuple(arg_dtypes)}") - if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating): - self.skipTest("np.arctanh & jnp.arctanh have mismatched NaNs for complex input.") - if name == 'spacing': - self.skipTest("No spacing operators.") - bm_op = getattr(bm, name) - np_op = getattr(np, name) - np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) - args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes) - - try: - bm_op(*args_maker()) - except NotImplementedError: - self.skipTest(f"jtu.{name} is not yet implemented.") - - # large tol comes from the fact that numpy returns float16 in places - # that jnp returns float32. e.g. 
np.cos(np.uint8(0)) - self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, check_dtypes=False, tol=1E-2) + @parameterized.named_parameters( + {"testcase_name": f"_{name}_{','.join(arg_dtypes)}", + "name": name, "arg_dtypes": arg_dtypes} + for name in _all_numpy_ufuncs() + for arg_dtypes in jtu.cases_from_list(_dtypes_for_ufunc(name))) + def testUfuncInputTypes(self, name, arg_dtypes): + if (name in ['divmod', 'floor_divide', 'fmod', 'gcd', 'left_shift', 'mod', + 'power', 'remainder', 'right_shift', 'rint', 'square'] + and 'bool_' in arg_dtypes): + self.skipTest(f"jax.numpy does not support {name}{tuple(arg_dtypes)}") + if name == 'arctanh' and jnp.issubdtype(arg_dtypes[0], jnp.complexfloating): + self.skipTest("np.arctanh & jnp.arctanh have mismatched NaNs for complex input.") + if name == 'spacing': + self.skipTest("No spacing operators.") + bm_op = getattr(bm, name) + np_op = getattr(np, name) + np_op = jtu.ignore_warning(category=RuntimeWarning, + message="divide by zero.*")(np_op) + args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes) + + try: + bm_op(*args_maker()) + except NotImplementedError: + self.skipTest(f"jtu.{name} is not yet implemented.") + + # large tol comes from the fact that numpy returns float16 in places + # that jnp returns float32. e.g. 
np.cos(np.uint8(0)) + self._CheckAgainstNumpy(np_op, bm_func(bm_op), args_maker, check_dtypes=False, tol=1E-2) if __name__ == "__main__": - absltest.main(testLoader=jtu.JaxTestLoader()) + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/brainpy/_src/math/tests/test_oprators.py b/brainpy/_src/math/tests/test_oprators.py index 019dc143f..0483f8b4c 100644 --- a/brainpy/_src/math/tests/test_oprators.py +++ b/brainpy/_src/math/tests/test_oprators.py @@ -30,61 +30,55 @@ class TestSyn2Post(unittest.TestCase): - def test_syn2post_sum(self): - data = bm.arange(5) - segment_ids = bm.array([0, 0, 1, 1, 2]) - self.assertTrue(bm.array_equal(bm.syn2post_sum(data, segment_ids, 3), - bm.asarray([1, 5, 4]))) - + def test_syn2post_sum(self): + data = bm.arange(5) + segment_ids = bm.array([0, 0, 1, 1, 2]) + self.assertTrue(bm.array_equal(bm.syn2post_sum(data, segment_ids, 3), + bm.asarray([1, 5, 4]))) - def test_syn2post_max(self): - data = bm.arange(5) - segment_ids = bm.array([0, 0, 1, 1, 2]) - self.assertTrue(bm.array_equal(bm.syn2post_max(data, segment_ids, 3), - bm.asarray([1, 3, 4]))) - + def test_syn2post_max(self): + data = bm.arange(5) + segment_ids = bm.array([0, 0, 1, 1, 2]) + self.assertTrue(bm.array_equal(bm.syn2post_max(data, segment_ids, 3), + bm.asarray([1, 3, 4]))) - def test_syn2post_min(self): - data = bm.arange(5) - segment_ids = bm.array([0, 0, 1, 1, 2]) - self.assertTrue(bm.array_equal(bm.syn2post_min(data, segment_ids, 3), - bm.asarray([0, 2, 4]))) - + def test_syn2post_min(self): + data = bm.arange(5) + segment_ids = bm.array([0, 0, 1, 1, 2]) + self.assertTrue(bm.array_equal(bm.syn2post_min(data, segment_ids, 3), + bm.asarray([0, 2, 4]))) - def test_syn2post_prod(self): - data = bm.arange(5) - segment_ids = bm.array([0, 0, 1, 1, 2]) - self.assertTrue(bm.array_equal(bm.syn2post_prod(data, segment_ids, 3), - bm.asarray([0, 6, 4]))) - + def test_syn2post_prod(self): + data = bm.arange(5) + segment_ids = bm.array([0, 0, 1, 1, 2]) + 
self.assertTrue(bm.array_equal(bm.syn2post_prod(data, segment_ids, 3), + bm.asarray([0, 6, 4]))) - def test_syn2post_mean(self): - data = bm.arange(5) - segment_ids = bm.array([0, 0, 1, 1, 2]) - self.assertTrue(bm.array_equal(bm.syn2post_mean(data, segment_ids, 3), - bm.asarray([0.5, 2.5, 4.]))) - + def test_syn2post_mean(self): + data = bm.arange(5) + segment_ids = bm.array([0, 0, 1, 1, 2]) + self.assertTrue(bm.array_equal(bm.syn2post_mean(data, segment_ids, 3), + bm.asarray([0.5, 2.5, 4.]))) - def test_syn2post_softmax(self): - data = bm.arange(5) - segment_ids = bm.array([0, 0, 1, 1, 2]) - f_ans = bm.syn2post_softmax(data, segment_ids, 3) - true_ans = bm.asarray([jnp.exp(data[0]) / (jnp.exp(data[0]) + jnp.exp(data[1])), - jnp.exp(data[1]) / (jnp.exp(data[0]) + jnp.exp(data[1])), - jnp.exp(data[2]) / (jnp.exp(data[2]) + jnp.exp(data[3])), - jnp.exp(data[3]) / (jnp.exp(data[2]) + jnp.exp(data[3])), - jnp.exp(data[4]) / jnp.exp(data[4])]) - print() - print(bm.asarray(f_ans)) - print(true_ans) - print(f_ans == true_ans) - # self.assertTrue(bm.array_equal(bm.syn2post_softmax(data, segment_ids, 3), - # true_ans)) + def test_syn2post_softmax(self): + data = bm.arange(5) + segment_ids = bm.array([0, 0, 1, 1, 2]) + f_ans = bm.syn2post_softmax(data, segment_ids, 3) + true_ans = bm.asarray([jnp.exp(data[0]) / (jnp.exp(data[0]) + jnp.exp(data[1])), + jnp.exp(data[1]) / (jnp.exp(data[0]) + jnp.exp(data[1])), + jnp.exp(data[2]) / (jnp.exp(data[2]) + jnp.exp(data[3])), + jnp.exp(data[3]) / (jnp.exp(data[2]) + jnp.exp(data[3])), + jnp.exp(data[4]) / jnp.exp(data[4])]) + print() + print(bm.asarray(f_ans)) + print(true_ans) + print(f_ans == true_ans) + # self.assertTrue(bm.array_equal(bm.syn2post_softmax(data, segment_ids, 3), + # true_ans)) - data = bm.arange(5) - segment_ids = bm.array([0, 0, 1, 1, 2]) - print(bm.syn2post_softmax(data, segment_ids, 4)) - + data = bm.arange(5) + segment_ids = bm.array([0, 0, 1, 1, 2]) + print(bm.syn2post_softmax(data, segment_ids, 4)) # # class 
TestSparseMatmul(unittest.TestCase): diff --git a/brainpy/_src/math/tests/test_others.py b/brainpy/_src/math/tests/test_others.py index 084b8664d..bfebf8a7e 100644 --- a/brainpy/_src/math/tests/test_others.py +++ b/brainpy/_src/math/tests/test_others.py @@ -1,21 +1,18 @@ +from unittest import TestCase -import brainpy.math as bm from scipy.special import exprel -from unittest import TestCase +import brainpy.math as bm class Test_exprel(TestCase): - def test1(self): - for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]: - print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}') - # self.assertEqual(exprel(x)) - - def test2(self): - bm.enable_x64() - for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]: - print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}') - # self.assertEqual(exprel(x)) - - - + def test1(self): + for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]: + print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}') + # self.assertEqual(exprel(x)) + + def test2(self): + bm.enable_x64() + for x in [1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9]: + print(f'{exprel(x)}, {bm.exprel(x)}, {exprel(x) - bm.exprel(x):.10f}') + # self.assertEqual(exprel(x)) diff --git a/brainpy/_src/math/tests/test_random.py b/brainpy/_src/math/tests/test_random.py index 6b0760f6f..46d2f498c 100644 --- a/brainpy/_src/math/tests/test_random.py +++ b/brainpy/_src/math/tests/test_random.py @@ -11,553 +11,553 @@ class TestRandom(unittest.TestCase): - def test_seed(self): - test_seed = 299 - br.seed(test_seed) - a = br.rand(3) - br.seed(test_seed) - b = br.rand(3) - self.assertTrue(bm.array_equal(a, b)) - - def test_rand(self): - br.seed() - a = br.rand(3, 2) - self.assertTupleEqual(a.shape, (3, 2)) - self.assertTrue((a >= 0).all() and (a < 1).all()) - - key = jr.PRNGKey(123) - jres = jr.uniform(key, shape=(10, 100)) - self.assertTrue(bm.allclose(jres, br.rand(10, 100, key=key))) - self.assertTrue(bm.allclose(jres, br.rand(10, 100, key=123))) - - def test_randint1(self): - 
br.seed() - a = br.randint(5) - self.assertTupleEqual(a.shape, ()) - self.assertTrue(0 <= a < 5) - - def test_randint2(self): - br.seed() - a = br.randint(2, 6, size=(4, 3)) - self.assertTupleEqual(a.shape, (4, 3)) - self.assertTrue((a >= 2).all() and (a < 6).all()) - - def test_randint3(self): - br.seed() - a = br.randint([1, 2, 3], [10, 7, 8]) - self.assertTupleEqual(a.shape, (3,)) - self.assertTrue((a - bm.array([1, 2, 3]) >= 0).all() - and (-a + bm.array([10, 7, 8]) > 0).all()) - - def test_randint4(self): - br.seed() - a = br.randint([1, 2, 3], [10, 7, 8], size=(2, 3)) - self.assertTupleEqual(a.shape, (2, 3)) - - def test_randn(self): - br.seed() - a = br.randn(3, 2) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_random1(self): - br.seed() - a = br.random() - self.assertTrue(0. <= a < 1) - - def test_random2(self): - br.seed() - a = br.random(size=(3, 2)) - self.assertTupleEqual(a.shape, (3, 2)) - self.assertTrue((a >= 0).all() and (a < 1).all()) - - def test_random_sample(self): - br.seed() - a = br.random_sample(size=(3, 2)) - self.assertTupleEqual(a.shape, (3, 2)) - self.assertTrue((a >= 0).all() and (a < 1).all()) - - def test_choice1(self): - br.seed() - a = bm.random.choice(5) - self.assertTupleEqual(jnp.shape(a), ()) - self.assertTrue(0 <= a < 5) - - def test_choice2(self): - br.seed() - a = bm.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3]) - self.assertTupleEqual(a.shape, (3,)) - self.assertTrue((a >= 0).all() and (a < 5).all()) - - def test_choice3(self): - br.seed() - a = bm.random.choice(bm.arange(2, 20), size=(4, 3), replace=False) - self.assertTupleEqual(a.shape, (4, 3)) - self.assertTrue((a >= 2).all() and (a < 20).all()) - self.assertEqual(len(bm.unique(a)), 12) - - def test_permutation1(self): - br.seed() - a = bm.random.permutation(10) - self.assertTupleEqual(a.shape, (10,)) - self.assertEqual(len(bm.unique(a)), 10) - - def test_permutation2(self): - br.seed() - a = bm.random.permutation(bm.arange(10)) - self.assertTupleEqual(a.shape, 
(10,)) - self.assertEqual(len(bm.unique(a)), 10) - - def test_shuffle1(self): - br.seed() - a = bm.arange(10) - bm.random.shuffle(a) - self.assertTupleEqual(a.shape, (10,)) - self.assertEqual(len(bm.unique(a)), 10) - - def test_shuffle2(self): - br.seed() - a = bm.arange(12).reshape(4, 3) - bm.random.shuffle(a, axis=1) - self.assertTupleEqual(a.shape, (4, 3)) - self.assertEqual(len(bm.unique(a)), 12) - - # test that a is only shuffled along axis 1 - uni = bm.unique(bm.diff(a, axis=0)) - self.assertEqual(uni, bm.Array([3])) - - def test_beta1(self): - br.seed() - a = bm.random.beta(2, 2) - self.assertTupleEqual(a.shape, ()) - - def test_beta2(self): - br.seed() - a = bm.random.beta([2, 2, 3], 2, size=(3,)) - self.assertTupleEqual(a.shape, (3,)) - - def test_exponential1(self): - br.seed() - a = bm.random.exponential(10., size=[3, 2]) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_exponential2(self): - br.seed() - a = bm.random.exponential([1., 2., 5.]) - self.assertTupleEqual(a.shape, (3,)) - - def test_gamma(self): - br.seed() - a = bm.random.gamma(2, 10., size=[3, 2]) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_gumbel(self): - br.seed() - a = bm.random.gumbel(0., 2., size=[3, 2]) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_laplace(self): - br.seed() - a = bm.random.laplace(0., 2., size=[3, 2]) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_logistic(self): - br.seed() - a = bm.random.logistic(0., 2., size=[3, 2]) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_normal1(self): - br.seed() - a = bm.random.normal() - self.assertTupleEqual(a.shape, ()) - - def test_normal2(self): - br.seed() - a = bm.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.]) - self.assertTupleEqual(a.shape, (3,)) - - def test_normal3(self): - br.seed() - a = bm.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]]) - print(a) - self.assertTupleEqual(a.shape, (2, 3)) - - def test_pareto(self): - br.seed() - a = bm.random.pareto([1, 2, 2]) - 
self.assertTupleEqual(a.shape, (3,)) - - def test_poisson(self): - br.seed() - a = bm.random.poisson([1., 2., 2.], size=3) - self.assertTupleEqual(a.shape, (3,)) - - def test_standard_cauchy(self): - br.seed() - a = bm.random.standard_cauchy(size=(3, 2)) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_standard_exponential(self): - br.seed() - a = bm.random.standard_exponential(size=(3, 2)) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_standard_gamma(self): - br.seed() - a = bm.random.standard_gamma(shape=[1, 2, 4], size=3) - self.assertTupleEqual(a.shape, (3,)) - - def test_standard_normal(self): - br.seed() - a = bm.random.standard_normal(size=(3, 2)) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_standard_t(self): - br.seed() - a = bm.random.standard_t(df=[1, 2, 4], size=3) - self.assertTupleEqual(a.shape, (3,)) - - def test_standard_uniform1(self): - br.seed() - a = bm.random.uniform() - self.assertTupleEqual(a.shape, ()) - self.assertTrue(0 <= a < 1) - - def test_uniform2(self): - br.seed() - a = bm.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3) - self.assertTupleEqual(a.shape, (3,)) - self.assertTrue((a - bm.array([-1., 5., 2.]) >= 0).all() - and (-a + bm.array([2., 6., 10.]) > 0).all()) - - def test_uniform3(self): - br.seed() - a = bm.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3)) - self.assertTupleEqual(a.shape, (2, 3)) - - def test_uniform4(self): - br.seed() - a = bm.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]]) - self.assertTupleEqual(a.shape, (2, 3)) - - def test_truncated_normal1(self): - br.seed() - a = bm.random.truncated_normal(-1., 1.) - self.assertTupleEqual(a.shape, ()) - self.assertTrue(-1. <= a <= 1.) 
- - def test_truncated_normal2(self): - br.seed() - a = bm.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3)) - self.assertTupleEqual(a.shape, (4, 3)) - - def test_truncated_normal3(self): - br.seed() - a = bm.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]]) - self.assertTupleEqual(a.shape, (2, 3)) - self.assertTrue((a - bm.array([-1., 0., 1.]) >= 0.).all() - and (- a + bm.array([2., 2., 4.]) >= 0.).all()) - - def test_bernoulli1(self): - br.seed() - a = bm.random.bernoulli() - self.assertTupleEqual(a.shape, ()) - self.assertTrue(a == 0 or a == 1) - - def test_bernoulli2(self): - br.seed() - a = bm.random.bernoulli([0.5, 0.6, 0.8]) - self.assertTupleEqual(a.shape, (3,)) - self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) - - def test_bernoulli3(self): - br.seed() - a = bm.random.bernoulli([0.5, 0.6], size=(3, 2)) - self.assertTupleEqual(a.shape, (3, 2)) - self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) - - def test_lognormal1(self): - br.seed() - a = bm.random.lognormal() - self.assertTupleEqual(a.shape, ()) - - def test_lognormal2(self): - br.seed() - a = bm.random.lognormal(sigma=[2., 1.], size=[3, 2]) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_lognormal3(self): - br.seed() - a = bm.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]]) - self.assertTupleEqual(a.shape, (2, 2)) - - def test_binomial1(self): - br.seed() - a = bm.random.binomial(5, 0.5) - b = np.random.binomial(5, 0.5) - print(a) - print(b) - self.assertTupleEqual(a.shape, ()) - self.assertTrue(a.dtype, int) - - def test_binomial2(self): - br.seed() - a = bm.random.binomial(5, 0.5, size=(3, 2)) - self.assertTupleEqual(a.shape, (3, 2)) - self.assertTrue((a >= 0).all() and (a <= 5).all()) - - def test_binomial3(self): - br.seed() - a = bm.random.binomial(n=bm.asarray([2, 3, 4]), p=bm.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]])) - self.assertTupleEqual(a.shape, (2, 3)) - - def test_chisquare1(self): - br.seed() - a = bm.random.chisquare(3) - 
self.assertTupleEqual(a.shape, ()) - self.assertTrue(a.dtype, float) - - def test_chisquare2(self): - br.seed() - with self.assertRaises(NotImplementedError): - a = bm.random.chisquare(df=[2, 3, 4]) - - def test_chisquare3(self): - br.seed() - a = bm.random.chisquare(df=2, size=100) - self.assertTupleEqual(a.shape, (100,)) - - def test_chisquare4(self): - br.seed() - a = bm.random.chisquare(df=2, size=(100, 10)) - self.assertTupleEqual(a.shape, (100, 10)) - - def test_dirichlet1(self): - br.seed() - a = bm.random.dirichlet((10, 5, 3)) - self.assertTupleEqual(a.shape, (3,)) - - def test_dirichlet2(self): - br.seed() - a = bm.random.dirichlet((10, 5, 3), 20) - self.assertTupleEqual(a.shape, (20, 3)) - - def test_f(self): - br.seed() - a = bm.random.f(1., 48., 100) - self.assertTupleEqual(a.shape, (100,)) - - def test_geometric(self): - br.seed() - a = bm.random.geometric([0.7, 0.5, 0.2]) - self.assertTupleEqual(a.shape, (3,)) - - def test_hypergeometric1(self): - br.seed() - a = bm.random.hypergeometric(10, 10, 10, 20) - self.assertTupleEqual(a.shape, (20,)) - - @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') - def test_hypergeometric2(self): - br.seed() - a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]]) - self.assertTupleEqual(a.shape, (2, 2)) - - @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') - def test_hypergeometric3(self): - br.seed() - a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2)) - self.assertTupleEqual(a.shape, (3, 2, 2)) - - def test_logseries(self): - br.seed() - a = bm.random.logseries([0.7, 0.5, 0.2], size=[4, 3]) - self.assertTupleEqual(a.shape, (4, 3)) - - def test_multinominal1(self): - br.seed() - a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) - print(a, a.shape) - b = bm.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) - print(b, b.shape) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (4, 2, 3)) 
- - def test_multinominal2(self): - br.seed() - a = bm.random.multinomial(100, (0.5, 0.2, 0.3)) - self.assertTupleEqual(a.shape, (3,)) - self.assertTrue(a.sum() == 100) - - def test_multivariate_normal1(self): - br.seed() - # self.skipTest('Windows jaxlib error') - a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) - b = bm.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) - print('test_multivariate_normal1') - print(a) - print(b) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(a.shape, (3, 2)) - - def test_multivariate_normal2(self): - br.seed() - a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) - b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd') - print(a) - print(b) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(a.shape, (2,)) - - def test_negative_binomial(self): - br.seed() - a = np.random.negative_binomial([3., 10.], 0.5) - b = bm.random.negative_binomial([3., 10.], 0.5) - print(a) - print(b) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (2,)) - - def test_negative_binomial2(self): - br.seed() - a = np.random.negative_binomial(3., 0.5, 10) - b = bm.random.negative_binomial(3., 0.5, 10) - print(a) - print(b) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (10,)) - - def test_noncentral_chisquare(self): - br.seed() - a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2)) - b = bm.random.noncentral_chisquare(3, [3., 2.], (4, 2)) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (4, 2)) - - def test_noncentral_chisquare2(self): - br.seed() - a = bm.random.noncentral_chisquare(3, [3., 2.]) - self.assertTupleEqual(a.shape, (2,)) - - def test_noncentral_f(self): - br.seed() - a = bm.random.noncentral_f(3, 20, 3., 100) - self.assertTupleEqual(a.shape, (100,)) - - def test_power(self): - br.seed() - a = np.random.power(2, (4, 2)) - b = bm.random.power(2, (4, 2)) - 
self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (4, 2)) - - def test_rayleigh(self): - br.seed() - a = bm.random.power(2., (4, 2)) - self.assertTupleEqual(a.shape, (4, 2)) - - def test_triangular(self): - br.seed() - a = bm.random.triangular((2, 2)) - self.assertTupleEqual(a.shape, (2, 2)) - - def test_vonmises(self): - br.seed() - a = np.random.vonmises(2., 2.) - b = bm.random.vonmises(2., 2.) - print(a, b) - self.assertTupleEqual(np.shape(a), b.shape) - self.assertTupleEqual(b.shape, ()) - - def test_vonmises2(self): - br.seed() - a = np.random.vonmises(2., 2., 10) - b = bm.random.vonmises(2., 2., 10) - print(a, b) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (10,)) - - def test_wald(self): - br.seed() - a = np.random.wald([2., 0.5], 2.) - b = bm.random.wald([2., 0.5], 2.) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (2,)) - - def test_wald2(self): - br.seed() - a = np.random.wald(2., 2., 100) - b = bm.random.wald(2., 2., 100) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (100,)) - - def test_weibull(self): - br.seed() - a = bm.random.weibull(2., (4, 2)) - self.assertTupleEqual(a.shape, (4, 2)) - - def test_weibull2(self): - br.seed() - a = bm.random.weibull(2., ) - self.assertTupleEqual(a.shape, ()) - - def test_weibull3(self): - br.seed() - a = bm.random.weibull([2., 3.], ) - self.assertTupleEqual(a.shape, (2,)) - - def test_weibull_min(self): - br.seed() - a = bm.random.weibull_min(2., 2., (4, 2)) - self.assertTupleEqual(a.shape, (4, 2)) - - def test_weibull_min2(self): - br.seed() - a = bm.random.weibull_min(2., 2.) - self.assertTupleEqual(a.shape, ()) - - def test_weibull_min3(self): - br.seed() - a = bm.random.weibull_min([2., 3.], 2.) 
- self.assertTupleEqual(a.shape, (2,)) - - def test_zipf(self): - br.seed() - a = bm.random.zipf(2., (4, 2)) - self.assertTupleEqual(a.shape, (4, 2)) - - def test_zipf2(self): - br.seed() - a = np.random.zipf([1.1, 2.]) - b = bm.random.zipf([1.1, 2.]) - self.assertTupleEqual(a.shape, b.shape) - self.assertTupleEqual(b.shape, (2,)) - - def test_maxwell(self): - br.seed() - a = bm.random.maxwell(10) - self.assertTupleEqual(a.shape, (10,)) - - def test_maxwell2(self): - br.seed() - a = bm.random.maxwell() - self.assertTupleEqual(a.shape, ()) - - def test_t(self): - br.seed() - a = bm.random.t(1., size=10) - self.assertTupleEqual(a.shape, (10,)) - - def test_t2(self): - br.seed() - a = bm.random.t([1., 2.], size=None) - self.assertTupleEqual(a.shape, (2,)) + def test_seed(self): + test_seed = 299 + br.seed(test_seed) + a = br.rand(3) + br.seed(test_seed) + b = br.rand(3) + self.assertTrue(bm.array_equal(a, b)) + + def test_rand(self): + br.seed() + a = br.rand(3, 2) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a < 1).all()) + + key = jr.PRNGKey(123) + jres = jr.uniform(key, shape=(10, 100)) + self.assertTrue(bm.allclose(jres, br.rand(10, 100, key=key))) + self.assertTrue(bm.allclose(jres, br.rand(10, 100, key=123))) + + def test_randint1(self): + br.seed() + a = br.randint(5) + self.assertTupleEqual(a.shape, ()) + self.assertTrue(0 <= a < 5) + + def test_randint2(self): + br.seed() + a = br.randint(2, 6, size=(4, 3)) + self.assertTupleEqual(a.shape, (4, 3)) + self.assertTrue((a >= 2).all() and (a < 6).all()) + + def test_randint3(self): + br.seed() + a = br.randint([1, 2, 3], [10, 7, 8]) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue((a - bm.array([1, 2, 3]) >= 0).all() + and (-a + bm.array([10, 7, 8]) > 0).all()) + + def test_randint4(self): + br.seed() + a = br.randint([1, 2, 3], [10, 7, 8], size=(2, 3)) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_randn(self): + br.seed() + a = br.randn(3, 2) + 
self.assertTupleEqual(a.shape, (3, 2)) + + def test_random1(self): + br.seed() + a = br.random() + self.assertTrue(0. <= a < 1) + + def test_random2(self): + br.seed() + a = br.random(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a < 1).all()) + + def test_random_sample(self): + br.seed() + a = br.random_sample(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a < 1).all()) + + def test_choice1(self): + br.seed() + a = bm.random.choice(5) + self.assertTupleEqual(jnp.shape(a), ()) + self.assertTrue(0 <= a < 5) + + def test_choice2(self): + br.seed() + a = bm.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3]) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue((a >= 0).all() and (a < 5).all()) + + def test_choice3(self): + br.seed() + a = bm.random.choice(bm.arange(2, 20), size=(4, 3), replace=False) + self.assertTupleEqual(a.shape, (4, 3)) + self.assertTrue((a >= 2).all() and (a < 20).all()) + self.assertEqual(len(bm.unique(a)), 12) + + def test_permutation1(self): + br.seed() + a = bm.random.permutation(10) + self.assertTupleEqual(a.shape, (10,)) + self.assertEqual(len(bm.unique(a)), 10) + + def test_permutation2(self): + br.seed() + a = bm.random.permutation(bm.arange(10)) + self.assertTupleEqual(a.shape, (10,)) + self.assertEqual(len(bm.unique(a)), 10) + + def test_shuffle1(self): + br.seed() + a = bm.arange(10) + bm.random.shuffle(a) + self.assertTupleEqual(a.shape, (10,)) + self.assertEqual(len(bm.unique(a)), 10) + + def test_shuffle2(self): + br.seed() + a = bm.arange(12).reshape(4, 3) + bm.random.shuffle(a, axis=1) + self.assertTupleEqual(a.shape, (4, 3)) + self.assertEqual(len(bm.unique(a)), 12) + + # test that a is only shuffled along axis 1 + uni = bm.unique(bm.diff(a, axis=0)) + self.assertEqual(uni, bm.Array([3])) + + def test_beta1(self): + br.seed() + a = bm.random.beta(2, 2) + self.assertTupleEqual(a.shape, ()) + + def test_beta2(self): + br.seed() + a = 
bm.random.beta([2, 2, 3], 2, size=(3,)) + self.assertTupleEqual(a.shape, (3,)) + + def test_exponential1(self): + br.seed() + a = bm.random.exponential(10., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_exponential2(self): + br.seed() + a = bm.random.exponential([1., 2., 5.]) + self.assertTupleEqual(a.shape, (3,)) + + def test_gamma(self): + br.seed() + a = bm.random.gamma(2, 10., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_gumbel(self): + br.seed() + a = bm.random.gumbel(0., 2., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_laplace(self): + br.seed() + a = bm.random.laplace(0., 2., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_logistic(self): + br.seed() + a = bm.random.logistic(0., 2., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_normal1(self): + br.seed() + a = bm.random.normal() + self.assertTupleEqual(a.shape, ()) + + def test_normal2(self): + br.seed() + a = bm.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.]) + self.assertTupleEqual(a.shape, (3,)) + + def test_normal3(self): + br.seed() + a = bm.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]]) + print(a) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_pareto(self): + br.seed() + a = bm.random.pareto([1, 2, 2]) + self.assertTupleEqual(a.shape, (3,)) + + def test_poisson(self): + br.seed() + a = bm.random.poisson([1., 2., 2.], size=3) + self.assertTupleEqual(a.shape, (3,)) + + def test_standard_cauchy(self): + br.seed() + a = bm.random.standard_cauchy(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_standard_exponential(self): + br.seed() + a = bm.random.standard_exponential(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_standard_gamma(self): + br.seed() + a = bm.random.standard_gamma(shape=[1, 2, 4], size=3) + self.assertTupleEqual(a.shape, (3,)) + + def test_standard_normal(self): + br.seed() + a = bm.random.standard_normal(size=(3, 2)) + 
self.assertTupleEqual(a.shape, (3, 2)) + + def test_standard_t(self): + br.seed() + a = bm.random.standard_t(df=[1, 2, 4], size=3) + self.assertTupleEqual(a.shape, (3,)) + + def test_standard_uniform1(self): + br.seed() + a = bm.random.uniform() + self.assertTupleEqual(a.shape, ()) + self.assertTrue(0 <= a < 1) + + def test_uniform2(self): + br.seed() + a = bm.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue((a - bm.array([-1., 5., 2.]) >= 0).all() + and (-a + bm.array([2., 6., 10.]) > 0).all()) + + def test_uniform3(self): + br.seed() + a = bm.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3)) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_uniform4(self): + br.seed() + a = bm.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]]) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_truncated_normal1(self): + br.seed() + a = bm.random.truncated_normal(-1., 1.) + self.assertTupleEqual(a.shape, ()) + self.assertTrue(-1. <= a <= 1.) 
+ + def test_truncated_normal2(self): + br.seed() + a = bm.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3)) + self.assertTupleEqual(a.shape, (4, 3)) + + def test_truncated_normal3(self): + br.seed() + a = bm.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]]) + self.assertTupleEqual(a.shape, (2, 3)) + self.assertTrue((a - bm.array([-1., 0., 1.]) >= 0.).all() + and (- a + bm.array([2., 2., 4.]) >= 0.).all()) + + def test_bernoulli1(self): + br.seed() + a = bm.random.bernoulli() + self.assertTupleEqual(a.shape, ()) + self.assertTrue(a == 0 or a == 1) + + def test_bernoulli2(self): + br.seed() + a = bm.random.bernoulli([0.5, 0.6, 0.8]) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) + + def test_bernoulli3(self): + br.seed() + a = bm.random.bernoulli([0.5, 0.6], size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) + + def test_lognormal1(self): + br.seed() + a = bm.random.lognormal() + self.assertTupleEqual(a.shape, ()) + + def test_lognormal2(self): + br.seed() + a = bm.random.lognormal(sigma=[2., 1.], size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_lognormal3(self): + br.seed() + a = bm.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]]) + self.assertTupleEqual(a.shape, (2, 2)) + + def test_binomial1(self): + br.seed() + a = bm.random.binomial(5, 0.5) + b = np.random.binomial(5, 0.5) + print(a) + print(b) + self.assertTupleEqual(a.shape, ()) + self.assertTrue(a.dtype, int) + + def test_binomial2(self): + br.seed() + a = bm.random.binomial(5, 0.5, size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a <= 5).all()) + + def test_binomial3(self): + br.seed() + a = bm.random.binomial(n=bm.asarray([2, 3, 4]), p=bm.asarray([[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]])) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_chisquare1(self): + br.seed() + a = bm.random.chisquare(3) + 
self.assertTupleEqual(a.shape, ()) + self.assertTrue(a.dtype, float) + + def test_chisquare2(self): + br.seed() + with self.assertRaises(NotImplementedError): + a = bm.random.chisquare(df=[2, 3, 4]) + + def test_chisquare3(self): + br.seed() + a = bm.random.chisquare(df=2, size=100) + self.assertTupleEqual(a.shape, (100,)) + + def test_chisquare4(self): + br.seed() + a = bm.random.chisquare(df=2, size=(100, 10)) + self.assertTupleEqual(a.shape, (100, 10)) + + def test_dirichlet1(self): + br.seed() + a = bm.random.dirichlet((10, 5, 3)) + self.assertTupleEqual(a.shape, (3,)) + + def test_dirichlet2(self): + br.seed() + a = bm.random.dirichlet((10, 5, 3), 20) + self.assertTupleEqual(a.shape, (20, 3)) + + def test_f(self): + br.seed() + a = bm.random.f(1., 48., 100) + self.assertTupleEqual(a.shape, (100,)) + + def test_geometric(self): + br.seed() + a = bm.random.geometric([0.7, 0.5, 0.2]) + self.assertTupleEqual(a.shape, (3,)) + + def test_hypergeometric1(self): + br.seed() + a = bm.random.hypergeometric(10, 10, 10, 20) + self.assertTupleEqual(a.shape, (20,)) + + @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') + def test_hypergeometric2(self): + br.seed() + a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]]) + self.assertTupleEqual(a.shape, (2, 2)) + + @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') + def test_hypergeometric3(self): + br.seed() + a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2)) + self.assertTupleEqual(a.shape, (3, 2, 2)) + + def test_logseries(self): + br.seed() + a = bm.random.logseries([0.7, 0.5, 0.2], size=[4, 3]) + self.assertTupleEqual(a.shape, (4, 3)) + + def test_multinominal1(self): + br.seed() + a = np.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) + print(a, a.shape) + b = bm.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) + print(b, b.shape) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (4, 2, 3)) 
+ + def test_multinominal2(self): + br.seed() + a = bm.random.multinomial(100, (0.5, 0.2, 0.3)) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue(a.sum() == 100) + + def test_multivariate_normal1(self): + br.seed() + # self.skipTest('Windows jaxlib error') + a = np.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) + b = bm.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) + print('test_multivariate_normal1') + print(a) + print(b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_multivariate_normal2(self): + br.seed() + a = np.random.multivariate_normal([1, 2], [[1, 3], [3, 1]]) + b = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], method='svd') + print(a) + print(b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(a.shape, (2,)) + + def test_negative_binomial(self): + br.seed() + a = np.random.negative_binomial([3., 10.], 0.5) + b = bm.random.negative_binomial([3., 10.], 0.5) + print(a) + print(b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (2,)) + + def test_negative_binomial2(self): + br.seed() + a = np.random.negative_binomial(3., 0.5, 10) + b = bm.random.negative_binomial(3., 0.5, 10) + print(a) + print(b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (10,)) + + def test_noncentral_chisquare(self): + br.seed() + a = np.random.noncentral_chisquare(3, [3., 2.], (4, 2)) + b = bm.random.noncentral_chisquare(3, [3., 2.], (4, 2)) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (4, 2)) + + def test_noncentral_chisquare2(self): + br.seed() + a = bm.random.noncentral_chisquare(3, [3., 2.]) + self.assertTupleEqual(a.shape, (2,)) + + def test_noncentral_f(self): + br.seed() + a = bm.random.noncentral_f(3, 20, 3., 100) + self.assertTupleEqual(a.shape, (100,)) + + def test_power(self): + br.seed() + a = np.random.power(2, (4, 2)) + b = bm.random.power(2, (4, 2)) + 
self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (4, 2)) + + def test_rayleigh(self): + br.seed() + a = bm.random.power(2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_triangular(self): + br.seed() + a = bm.random.triangular((2, 2)) + self.assertTupleEqual(a.shape, (2, 2)) + + def test_vonmises(self): + br.seed() + a = np.random.vonmises(2., 2.) + b = bm.random.vonmises(2., 2.) + print(a, b) + self.assertTupleEqual(np.shape(a), b.shape) + self.assertTupleEqual(b.shape, ()) + + def test_vonmises2(self): + br.seed() + a = np.random.vonmises(2., 2., 10) + b = bm.random.vonmises(2., 2., 10) + print(a, b) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (10,)) + + def test_wald(self): + br.seed() + a = np.random.wald([2., 0.5], 2.) + b = bm.random.wald([2., 0.5], 2.) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (2,)) + + def test_wald2(self): + br.seed() + a = np.random.wald(2., 2., 100) + b = bm.random.wald(2., 2., 100) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (100,)) + + def test_weibull(self): + br.seed() + a = bm.random.weibull(2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_weibull2(self): + br.seed() + a = bm.random.weibull(2., ) + self.assertTupleEqual(a.shape, ()) + + def test_weibull3(self): + br.seed() + a = bm.random.weibull([2., 3.], ) + self.assertTupleEqual(a.shape, (2,)) + + def test_weibull_min(self): + br.seed() + a = bm.random.weibull_min(2., 2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_weibull_min2(self): + br.seed() + a = bm.random.weibull_min(2., 2.) + self.assertTupleEqual(a.shape, ()) + + def test_weibull_min3(self): + br.seed() + a = bm.random.weibull_min([2., 3.], 2.) 
+ self.assertTupleEqual(a.shape, (2,)) + + def test_zipf(self): + br.seed() + a = bm.random.zipf(2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_zipf2(self): + br.seed() + a = np.random.zipf([1.1, 2.]) + b = bm.random.zipf([1.1, 2.]) + self.assertTupleEqual(a.shape, b.shape) + self.assertTupleEqual(b.shape, (2,)) + + def test_maxwell(self): + br.seed() + a = bm.random.maxwell(10) + self.assertTupleEqual(a.shape, (10,)) + + def test_maxwell2(self): + br.seed() + a = bm.random.maxwell() + self.assertTupleEqual(a.shape, ()) + + def test_t(self): + br.seed() + a = bm.random.t(1., size=10) + self.assertTupleEqual(a.shape, (10,)) + + def test_t2(self): + br.seed() + a = bm.random.t([1., 2.], size=None) + self.assertTupleEqual(a.shape, (2,)) class TestRandomKey(unittest.TestCase): - def test_clear_memory(self): - bm.random.split_key() - - print(bm.random.DEFAULT.value) - # Accept both numpy arrays and JAX arrays - import jax - self.assertTrue(isinstance(bm.random.DEFAULT.value, (np.ndarray, jax.Array))) + def test_clear_memory(self): + bm.random.split_key() + + print(bm.random.DEFAULT.value) + # Accept both numpy arrays and JAX arrays + import jax + self.assertTrue(isinstance(bm.random.DEFAULT.value, (np.ndarray, jax.Array))) diff --git a/brainpy/_src/math/tests/test_tifunc.py b/brainpy/_src/math/tests/test_tifunc.py index 5bf0a0ad5..247cdaa5c 100644 --- a/brainpy/_src/math/tests/test_tifunc.py +++ b/brainpy/_src/math/tests/test_tifunc.py @@ -9,111 +9,110 @@ import matplotlib.pyplot as plt import os - bm.set_platform('cpu') def test_taichi_random(): - @ti.kernel - def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), - out: ti.types.ndarray(ndim=1, dtype=ti.f32)): - key = bm.tifunc.lfsr88_key(seed[0]) - for i in range(out.shape[0]): - key, result = bm.tifunc.lfsr88_rand(key) - out[i] = result - - @ti.kernel - def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range(out.shape[0]): - out[i] 
= bm.tifunc.taichi_lcg_rand(seed) - - @ti.kernel - def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) - - @ti.kernel - def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) - - @ti.kernel - def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), - mu_sigma: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - mu = mu_sigma[0] - sigma = mu_sigma[1] - - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) - - n = 100000 - seed = jnp.array([1234, ], dtype=jnp.uint32) - low_high = jnp.array([0, 10]) - mu_sigma = jnp.array([0, 1]) - - prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, - gpu_kernel=test_taichi_lfsr88) - - prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, - gpu_kernel=test_taichi_lcg_rand) - prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, - gpu_kernel=test_taichi_uniform_int_distribution) - prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, - gpu_kernel=test_taichi_uniform_real_distribution) - prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, - gpu_kernel=test_taichi_normal_distribution) - - file_path = os.path.dirname(os.path.abspath(__file__)) - - out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LFSR88 
random number generator") - plt.savefig(file_path + "/lfsr88.png") - plt.close() - - out = prim_lcg_rand(seed, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LCG random number generator") - plt.savefig(file_path + "/lcg_rand.png") - plt.close() - - out = prim_uniform_int_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) - # show the distribution of out - plt.hist(out, bins=10) - plt.title("Uniform int distribution (0, 10)") - plt.savefig(file_path + "/uniform_int_distribution.png") - plt.close() - - out = prim_uniform_real_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("Uniform real distribution (0, 10)") - plt.savefig(file_path + "/uniform_real_distribution.png") - plt.close() - - out = prim_normal_distribution(seed, mu_sigma, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.title("Normal distribution mu=0, sigma=1") - plt.hist(out, bins=100) - plt.savefig(file_path + "/normal_distribution.png") + @ti.kernel + def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), + out: ti.types.ndarray(ndim=1, dtype=ti.f32)): + key = bm.tifunc.lfsr88_key(seed[0]) + for i in range(out.shape[0]): + key, result = bm.tifunc.lfsr88_rand(key) + out[i] = result + + @ti.kernel + def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range(out.shape[0]): + out[i] = bm.tifunc.taichi_lcg_rand(seed) + + @ti.kernel + def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) + + @ti.kernel + def test_taichi_uniform_real_distribution(seed: 
ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) + + @ti.kernel + def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), + mu_sigma: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + mu = mu_sigma[0] + sigma = mu_sigma[1] + + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) + + n = 100000 + seed = jnp.array([1234, ], dtype=jnp.uint32) + low_high = jnp.array([0, 10]) + mu_sigma = jnp.array([0, 1]) + + prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, + gpu_kernel=test_taichi_lfsr88) + + prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, + gpu_kernel=test_taichi_lcg_rand) + prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, + gpu_kernel=test_taichi_uniform_int_distribution) + prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, + gpu_kernel=test_taichi_uniform_real_distribution) + prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, + gpu_kernel=test_taichi_normal_distribution) + + file_path = os.path.dirname(os.path.abspath(__file__)) + + out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("LFSR88 random number generator") + plt.savefig(file_path + "/lfsr88.png") + plt.close() + + out = prim_lcg_rand(seed, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("LCG random number generator") + plt.savefig(file_path + "/lcg_rand.png") + plt.close() + + out = prim_uniform_int_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) + # show 
the distribution of out + plt.hist(out, bins=10) + plt.title("Uniform int distribution (0, 10)") + plt.savefig(file_path + "/uniform_int_distribution.png") + plt.close() + + out = prim_uniform_real_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("Uniform real distribution (0, 10)") + plt.savefig(file_path + "/uniform_real_distribution.png") + plt.close() + + out = prim_normal_distribution(seed, mu_sigma, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.title("Normal distribution mu=0, sigma=1") + plt.hist(out, bins=100) + plt.savefig(file_path + "/normal_distribution.png") # TODO; test default types diff --git a/brainpy/_src/measure/__init__.py b/brainpy/_src/measure/__init__.py index 976345282..35e416a2a 100644 --- a/brainpy/_src/measure/__init__.py +++ b/brainpy/_src/measure/__init__.py @@ -10,5 +10,3 @@ from .correlation import * from .firings import * from .lfp import * - - diff --git a/brainpy/_src/measure/correlation.py b/brainpy/_src/measure/correlation.py index aaf6a3dde..9eb93225e 100644 --- a/brainpy/_src/measure/correlation.py +++ b/brainpy/_src/measure/correlation.py @@ -1,309 +1,19 @@ # -*- coding: utf-8 -*- -import numpy as onp -from jax import vmap, lax, numpy as jnp - -from brainpy._src import math as bm -from brainpy.errors import UnsupportedError +import braintools __all__ = [ - 'cross_correlation', - 'voltage_fluctuation', - 'matrix_correlation', - 'weighted_correlation', - 'functional_connectivity', - # 'functional_connectivity_dynamics', + 'cross_correlation', + 'voltage_fluctuation', + 'matrix_correlation', + 'weighted_correlation', + 'functional_connectivity', + # 'functional_connectivity_dynamics', ] - -def cross_correlation(spikes, bin, dt=None, numpy=True, method='loop'): - r"""Calculate cross correlation index between neurons. 
- - The coherence [1]_ between two neurons i and j is measured by their - cross-correlation of spike trains at zero time lag within a time bin - of :math:`\Delta t = \tau`. More specifically, suppose that a long - time interval T is divided into small bins of :math:`\Delta t` and - that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0 - or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`. Thus, we define a coherence - measure for the pair as: - - .. math:: - - \kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)} - {\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}} - - The population coherence measure :math:`\kappa(\tau)` is defined by the - average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the - network. - - .. note:: - To JIT compile this function, users should make ``bin``, ``dt``, ``numpy`` static. - For example, ``partial(brainpy.measure.cross_correlation, bin=10, numpy=False)``. - - Parameters:: - - spikes : ndarray - The history of spike states of the neuron group. - bin : float, int - The time bin to normalize spike states. - dt : float, optional - The time precision. - numpy: bool - Whether we use numpy array as the functional output. - If ``False``, this function can be JIT compiled. - method: str - The method to calculate all pairs of cross correlation. - Supports two kinds of methods: `loop` and `vmap`. - `vmap` method needs much more memory. - - .. versionadded:: 2.2.3.4 - - Returns:: - - cc_index : float - The cross correlation value which represents the synchronization index. - - References:: - - .. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic - inhibition in a hippocampal interneuronal network model." Journal of - neuroscience 16.20 (1996): 6402-6413. 
- """ - spikes = bm.as_numpy(spikes) if numpy else bm.as_jax(spikes) - np = onp if numpy else jnp - dt = bm.get_dt() if dt is None else dt - bin_size = int(bin / dt) - num_hist, num_neu = spikes.shape - num_bin = int(onp.ceil(num_hist / bin_size)) - if num_bin * bin_size != num_hist: - spikes = np.append(spikes, np.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0) - states = spikes.T.reshape((num_neu, num_bin, bin_size)) - states = jnp.asarray(np.sum(states, axis=2) > 0., dtype=jnp.float_) - indices = jnp.tril_indices(num_neu, k=-1) - - if method == 'loop': - def _f(i, j): - sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j])) - return lax.cond(sqrt_ij == 0., - lambda _: 0., - lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij, - None) - res = bm.for_loop(_f, operands=indices) - - elif method == 'vmap': - @vmap - def _cc(i, j): - sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j])) - return lax.cond(sqrt_ij == 0., - lambda _: 0., - lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij, - None) - - res = _cc(*indices) - else: - raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".') - - return np.mean(np.asarray(res)) - - -def _f_signal(signal): - return jnp.mean(signal * signal) - jnp.mean(signal) ** 2 - - -def voltage_fluctuation(potentials, numpy=True, method='loop'): - r"""Calculate neuronal synchronization via voltage variance. - - The method comes from [1]_ [2]_ [3]_. - - First, average over the membrane potential :math:`V` - - .. math:: - - V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t) - - The variance of the time fluctuations of :math:`V(t)` is - - .. math:: - - \sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t - - \left[ \left\langle V(t) \right\rangle_t \right]^2 - - where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots` - denotes time-averaging over a large time, :math:`\tau_m`. 
After normalization - of :math:`\sigma_V` to the average over the population of the single cell - membrane potentials - - .. math:: - - \sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t - - \left[ \left\langle V_i(t) \right\rangle_t \right]^2 - - one defines a synchrony measure, :math:`\chi (N)`, for the activity of a system - of :math:`N` neurons by: - - .. math:: - - \chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N - \sigma_{V_i}^2} - - .. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled - inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814. - .. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled - inhibitory neurons. Physica D 72:259-282. - .. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347. - - Args: - potentials: The membrane potential matrix of the neuron group. - numpy: Whether we use numpy array as the functional output. If ``False``, this function can be JIT compiled. - method: The method to calculate all pairs of cross correlation. - Supports two kinds of methods: `loop` and `vmap`. - `vmap` method will consume much more memory. - - .. versionadded:: 2.2.3.4 - - Returns: - sync_index: The synchronization index. - """ - - potentials = bm.as_jax(potentials) - avg = jnp.mean(potentials, axis=1) - avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2 - - if method == 'loop': - _var = bm.for_loop(_f_signal, operands=jnp.moveaxis(potentials, 0, 1)) - - elif method == 'vmap': - _var = vmap(_f_signal, in_axes=1)(potentials) - else: - raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".') - - var_mean = jnp.mean(_var) - r = jnp.where(var_mean == 0., 1., avg_var / var_mean) - return bm.as_numpy(r) if numpy else r - - -def matrix_correlation(x, y, numpy=True): - """Pearson correlation of the lower triagonal of two matrices. 
- - The triangular matrix is offset by k = 1 in order to ignore the diagonal line - - Parameters:: - - x: ndarray - First matrix. - y: ndarray - Second matrix - numpy: bool - Whether we use numpy array as the functional output. - If ``False``, this function can be JIT compiled. - - Returns:: - - coef: ndarray - Correlation coefficient - """ - - x = bm.as_numpy(x) if numpy else bm.as_device_array(x) - y = bm.as_numpy(y) if numpy else bm.as_device_array(y) - np = onp if numpy else jnp - if x.ndim != 2: - raise ValueError(f'Only support 2d array, but we got a array ' - f'with the shape of {x.shape}') - if y.ndim != 2: - raise ValueError(f'Only support 2d array, but we got a array ' - f'with the shape of {y.shape}') - x = x[np.triu_indices_from(x, k=1)] - y = y[np.triu_indices_from(y, k=1)] - cc = np.corrcoef(x, y)[0, 1] - return cc - - -def functional_connectivity(activities, numpy=True): - """Functional connectivity matrix of timeseries activities. - - Parameters:: - - activities: ndarray - The multidimensional array with the shape of ``(num_time, num_sample)``. - numpy: bool - Whether we use numpy array as the functional output. - If ``False``, this function can be JIT compiled. - - Returns:: - - connectivity_matrix: ndarray - ``num_sample x num_sample`` functional connectivity matrix. - """ - activities = bm.as_numpy(activities) if numpy else bm.as_device_array(activities) - np = onp if numpy else jnp - if activities.ndim != 2: - raise ValueError('Only support 2d array with shape of "(num_time, num_sample)". ' - f'But we got a array with the shape of {activities.shape}') - fc = np.corrcoef(activities.T) - return np.nan_to_num(fc) - - -def functional_connectivity_dynamics(activities, window_size=30, step_size=5): - """Computes functional connectivity dynamics (FCD) matrix. - - Parameters:: - - activities: ndarray - The time series with shape of ``(num_time, num_sample)``. - window_size: int - Size of each rolling window in time steps, defaults to 30. 
- step_size: int - Step size between each rolling window, defaults to 5. - - Returns:: - - fcd_matrix: ndarray - FCD matrix. - """ - pass - - -def weighted_correlation(x, y, w, numpy=True): - """Weighted Pearson correlation of two data series. - - Parameters:: - - x: ndarray - The data series 1. - y: ndarray - The data series 2. - w: ndarray - Weight vector, must have same length as x and y. - numpy: bool - Whether we use numpy array as the functional output. - If ``False``, this function can be JIT compiled. - - Returns:: - - corr: ndarray - Weighted correlation coefficient. - """ - x = bm.as_numpy(x) if numpy else bm.as_device_array(x) - y = bm.as_numpy(y) if numpy else bm.as_device_array(y) - w = bm.as_numpy(w) if numpy else bm.as_device_array(w) - np = onp if numpy else jnp - - def _weighted_mean(x, w): - """Weighted Mean""" - return np.sum(x * w) / np.sum(w) - - def _weighted_cov(x, y, w): - """Weighted Covariance""" - return np.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / np.sum(w) - - if x.ndim != 1: - raise ValueError(f'Only support 1d array, but we got a array ' - f'with the shape of {x.shape}') - if y.ndim != 1: - raise ValueError(f'Only support 1d array, but we got a array ' - f'with the shape of {y.shape}') - if w.ndim != 1: - raise ValueError(f'Only support 1d array, but we got a array ' - f'with the shape of {w.shape}') - return _weighted_cov(x, y, w) / np.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w)) +cross_correlation = braintools.metric.cross_correlation +voltage_fluctuation = braintools.metric.voltage_fluctuation +matrix_correlation = braintools.metric.matrix_correlation +functional_connectivity = braintools.metric.functional_connectivity +weighted_correlation = braintools.metric.weighted_correlation diff --git a/brainpy/_src/measure/firings.py b/brainpy/_src/measure/firings.py index b89fffd9b..7f98c58c8 100644 --- a/brainpy/_src/measure/firings.py +++ b/brainpy/_src/measure/firings.py @@ -1,72 +1,72 @@ # -*- coding: 
utf-8 -*- -import numpy as onp import jax.numpy as jnp +import numpy as onp from brainpy._src import math as bm __all__ = [ - 'raster_plot', - 'firing_rate', + 'raster_plot', + 'firing_rate', ] def raster_plot(sp_matrix, times): - """Get spike raster plot which displays the spiking activity - of a group of neurons over time. + """Get spike raster plot which displays the spiking activity + of a group of neurons over time. - Parameters:: + Parameters:: - sp_matrix : bnp.ndarray - The matrix which record spiking activities. - times : bnp.ndarray - The time steps. + sp_matrix : bnp.ndarray + The matrix which record spiking activities. + times : bnp.ndarray + The time steps. - Returns:: + Returns:: - raster_plot : tuple - Include (neuron index, spike time). - """ - sp_matrix = bm.as_numpy(sp_matrix) - times = onp.asarray(times) - elements = onp.where(sp_matrix > 0.) - index = elements[1] - time = times[elements[0]] - return index, time + raster_plot : tuple + Include (neuron index, spike time). + """ + sp_matrix = bm.as_numpy(sp_matrix) + times = onp.asarray(times) + elements = onp.where(sp_matrix > 0.) + index = elements[1] + time = times[elements[0]] + return index, time def firing_rate(spikes, width, dt=None, numpy=True): - r"""Calculate the mean firing rate over in a neuron group. + r"""Calculate the mean firing rate over in a neuron group. - This method is adopted from Brian2. + This method is adopted from Brian2. - The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}` - in an interval of duration :math:`T` divided by :math:`T`: + The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}` + in an interval of duration :math:`T` divided by :math:`T`: - .. math:: + .. math:: - v_k = {n_k^{sp} \over T} + v_k = {n_k^{sp} \over T} - Parameters:: + Parameters:: - spikes : ndarray - The spike matrix which record spiking activities. - width : int, float - The width of the ``window`` in millisecond. - dt : float, optional - The sample rate. 
- numpy: bool - Whether we use numpy array as the functional output. - If ``False``, this function can be JIT compiled. + spikes : ndarray + The spike matrix which record spiking activities. + width : int, float + The width of the ``window`` in millisecond. + dt : float, optional + The sample rate. + numpy: bool + Whether we use numpy array as the functional output. + If ``False``, this function can be JIT compiled. - Returns:: + Returns:: - rate : ndarray - The population rate in Hz, smoothed with the given window. - """ - spikes = bm.as_numpy(spikes) if numpy else bm.as_jax(spikes) - np = onp if numpy else jnp - dt = bm.get_dt() if (dt is None) else dt - width1 = int(width / 2 / dt) * 2 + 1 - window = np.ones(width1) * 1000 / width - return np.convolve(np.mean(spikes, axis=1), window, mode='same') + rate : ndarray + The population rate in Hz, smoothed with the given window. + """ + spikes = bm.as_numpy(spikes) if numpy else bm.as_jax(spikes) + np = onp if numpy else jnp + dt = bm.get_dt() if (dt is None) else dt + width1 = int(width / 2 / dt) * 2 + 1 + window = np.ones(width1) * 1000 / width + return np.convolve(np.mean(spikes, axis=1), window, mode='same') diff --git a/brainpy/_src/measure/lfp.py b/brainpy/_src/measure/lfp.py index fba05c24b..518f2d247 100644 --- a/brainpy/_src/measure/lfp.py +++ b/brainpy/_src/measure/lfp.py @@ -1,113 +1,8 @@ # -*- coding: utf-8 -*- - - -from jax import numpy as jnp - -import brainpy._src.math as bm +import braintools.metric __all__ = [ - 'unitary_LFP', + 'unitary_LFP', ] - -def unitary_LFP(times, spikes, spike_type, - xmax=0.2, ymax=0.2, va=200., lambda_=0.2, - sig_i=2.1, sig_e=2.1 * 1.5, location='soma layer', seed=None): - """A kernel-based method to calculate unitary local field potentials (uLFP) - from a network of spiking neurons [1]_. - - .. note:: - This method calculates LFP only from the neuronal spikes. It does not consider - the subthreshold synaptic events, or the dendritic voltage-dependent ion channels. 
- - Examples:: - - If you have spike data of excitatory and inhibtiory neurons, you can get the LFP - by the following methods: - - >>> import brainpy as bp - >>> n_time = 1000 - >>> n_exc = 100 - >>> n_inh = 25 - >>> times = bm.arange(n_time) * 0.1 - >>> exc_sps = bp.math.random.random((n_time, n_exc)) < 0.3 - >>> inh_sps = bp.math.random.random((n_time, n_inh)) < 0.4 - >>> lfp = bp.measure.unitary_LFP(times, exc_sps, 'exc') - >>> lfp += bp.measure.unitary_LFP(times, inh_sps, 'inh') - - Parameters:: - - times: ndarray - The times of the recording points. - spikes: ndarray - The spikes of excitatory neurons recorded by brainpy monitors. - spike_type: str - The neuron type of the spike trains. It can be "exc" or "inh". - location: str - The location of the spikes recorded. It can be "soma layer", "deep layer", - "superficial layer" and "surface". - xmax: float - Size of the array (in mm). - ymax: float - Size of the array (in mm). - va: int, float - The axon velocity (mm/sec). - lambda_: float - The space constant (mm). - sig_i: float - The std-dev of inhibition (in ms) - sig_e: float - The std-dev for excitation (in ms). - seed: int - The random seed. - - References:: - - .. [1] Telenczuk, Bartosz, Maria Telenczuk, and Alain Destexhe. "A kernel-based - method to calculate local field potentials from networks of spiking - neurons." Journal of Neuroscience Methods 344 (2020): 108871. - - """ - times = bm.as_jax(times) - spikes = bm.as_jax(spikes) - if spike_type not in ['exc', 'inh']: - raise ValueError('"spike_type" should be "exc or ""inh". ') - if spikes.ndim != 2: - raise ValueError('"E_spikes" should be a matrix with shape of (num_time, num_neuron). ' - f'But we got {spikes.shape}') - if times.shape[0] != spikes.shape[0]: - raise ValueError('times and spikes should be consistent at the firs axis. 
' - f'Bug we got {times.shape[0]} != {spikes.shape}.') - - # Distributing cells in a 2D grid - rng = bm.random.RandomState(seed) - num_neuron = spikes.shape[1] - pos_xs, pos_ys = rng.rand(2, num_neuron).value * jnp.array([[xmax], [ymax]]) - pos_xs, pos_ys = jnp.asarray(pos_xs), jnp.asarray(pos_ys) - - # distance/coordinates - xe, ye = xmax / 2, ymax / 2 # coordinates of electrode - dist = jnp.sqrt((pos_xs - xe) ** 2 + (pos_ys - ye) ** 2) # distance to electrode in mm - - # amplitude - if location == 'soma layer': - amp_e, amp_i = 0.48, 3. # exc/inh uLFP amplitude (soma layer) - elif location == 'deep layer': - amp_e, amp_i = -0.16, -0.2 # exc/inh uLFP amplitude (deep layer) - elif location == 'superficial layer': - amp_e, amp_i = 0.24, -1.2 # exc/inh uLFP amplitude (superficial layer) - elif location == 'surface layer': - amp_e, amp_i = -0.08, 0.3 # exc/inh uLFP amplitude (surface) - else: - raise NotImplementedError - A = jnp.exp(-dist / lambda_) * (amp_e if spike_type == 'exc' else amp_i) - - # delay - delay = 10.4 + dist / va # delay to peak (in ms) - - # LFP Calculation - iis, ids = jnp.where(spikes) - tts = times[iis] + delay[ids] - exc_amp = A[ids] - tau = (2 * sig_e * sig_e) if spike_type == 'exc' else (2 * sig_i * sig_i) - return bm.for_loop(lambda t: jnp.sum(exc_amp * jnp.exp(-(t - tts) ** 2 / tau)), times) +unitary_LFP = braintools.metric.unitary_LFP diff --git a/brainpy/_src/measure/tests/test_correlation.py b/brainpy/_src/measure/tests/test_correlation.py index f630112cb..405f1169a 100644 --- a/brainpy/_src/measure/tests/test_correlation.py +++ b/brainpy/_src/measure/tests/test_correlation.py @@ -11,101 +11,90 @@ bm.set_platform('cpu') + class TestCrossCorrelation(unittest.TestCase): - def test_c(self): - bm.random.seed() - spikes = bm.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T - cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.) 
- f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.)) - cc2 = f_cc(spikes) - print(cc1, cc2) - self.assertTrue(cc1 == cc2) - - - def test_cc(self): - bm.random.seed() - spikes = bm.ones((1000, 10)) - cc1 = bp.measure.cross_correlation(spikes, 1.) - self.assertTrue(cc1 == 1.) - - spikes = bm.zeros((1000, 10)) - cc2 = bp.measure.cross_correlation(spikes, 1.) - self.assertTrue(cc2 == 0.) - - - - def test_cc2(self): - bm.random.seed() - spikes = bm.random.randint(0, 2, (1000, 10)) - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - - - def test_cc3(self): - bm.random.seed() - spikes = bm.random.random((1000, 100)) < 0.8 - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - - - def test_cc4(self): - bm.random.seed() - spikes = bm.random.random((1000, 100)) < 0.2 - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - - - def test_cc5(self): - bm.random.seed() - spikes = bm.random.random((1000, 100)) < 0.05 - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - + def test_c(self): + bm.random.seed() + spikes = bm.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T + cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.) + f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.)) + cc2 = f_cc(spikes) + print(cc1, cc2) + self.assertTrue(cc1 == cc2) + + def test_cc(self): + bm.random.seed() + spikes = bm.ones((1000, 10)) + cc1 = bp.measure.cross_correlation(spikes, 1.) + self.assertTrue(cc1 == 1.) + + spikes = bm.zeros((1000, 10)) + cc2 = bp.measure.cross_correlation(spikes, 1.) + self.assertTrue(cc2 == 0.) 
+ + def test_cc2(self): + bm.random.seed() + spikes = bm.random.randint(0, 2, (1000, 10)) + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) + + def test_cc3(self): + bm.random.seed() + spikes = bm.random.random((1000, 100)) < 0.8 + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) + + def test_cc4(self): + bm.random.seed() + spikes = bm.random.random((1000, 100)) < 0.2 + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) + + def test_cc5(self): + bm.random.seed() + spikes = bm.random.random((1000, 100)) < 0.05 + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) class TestVoltageFluctuation(unittest.TestCase): - def test_vf1(self): - bm.random.seed() - voltages = bm.random.normal(0, 10, size=(100, 10)) - print(bp.measure.voltage_fluctuation(voltages)) + def test_vf1(self): + bm.random.seed() + voltages = bm.random.normal(0, 10, size=(100, 10)) + print(bp.measure.voltage_fluctuation(voltages)) - bm.enable_x64() - voltages = bm.ones((100, 10)) - r1 = bp.measure.voltage_fluctuation(voltages) + bm.enable_x64() + voltages = bm.ones((100, 10)) + r1 = bp.measure.voltage_fluctuation(voltages) - jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False)) - jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False)) - r2 = jit_f(voltages) - print(r1, r2) # TODO: JIT results are different? - # self.assertTrue(r1 == r2) + jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False)) + jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False)) + r2 = jit_f(voltages) + print(r1, r2) # TODO: JIT results are different? 
+ # self.assertTrue(r1 == r2) - bm.disable_x64() - + bm.disable_x64() class TestFunctionalConnectivity(unittest.TestCase): - def test_cf1(self): - bm.random.seed() - act = bm.random.random((10000, 3)) - r1 = bp.measure.functional_connectivity(act) + def test_cf1(self): + bm.random.seed() + act = bm.random.random((10000, 3)) + r1 = bp.measure.functional_connectivity(act) - jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False)) - r2 = jit_f(act) + jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False)) + r2 = jit_f(act) - self.assertTrue(bm.allclose(r1, r2)) - + self.assertTrue(bm.allclose(r1, r2)) class TestMatrixCorrelation(unittest.TestCase): - def test_mc(self): - bm.random.seed() - A = bm.random.random((100, 100)) - B = bm.random.random((100, 100)) - r1 = (bp.measure.matrix_correlation(A, B)) - - jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False)) - r2 = jit_f(A, B) - self.assertTrue(bm.allclose(r1, r2)) - - - + def test_mc(self): + bm.random.seed() + A = bm.random.random((100, 100)) + B = bm.random.random((100, 100)) + r1 = (bp.measure.matrix_correlation(A, B)) + + jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False)) + r2 = jit_f(A, B) + self.assertTrue(bm.allclose(r1, r2)) diff --git a/brainpy/_src/measure/tests/test_firings.py b/brainpy/_src/measure/tests/test_firings.py index 98d6b280b..c08e33801 100644 --- a/brainpy/_src/measure/tests/test_firings.py +++ b/brainpy/_src/measure/tests/test_firings.py @@ -2,23 +2,23 @@ import unittest + import brainpy as bp class TestFiringRate(unittest.TestCase): - def test_fr1(self): - spikes = bp.math.ones((1000, 10)) - print(bp.measure.firing_rate(spikes, 1.)) - - def test_fr2(self): - bp.math.random.seed() - spikes = bp.math.random.random((1000, 10)) < 0.2 - print(bp.measure.firing_rate(spikes, 1.)) - print(bp.measure.firing_rate(spikes, 10.)) + def test_fr1(self): + spikes = bp.math.ones((1000, 10)) + print(bp.measure.firing_rate(spikes, 1.)) - def test_fr3(self): - 
bp.math.random.seed() - spikes = bp.math.random.random((1000, 10)) < 0.02 - print(bp.measure.firing_rate(spikes, 1.)) - print(bp.measure.firing_rate(spikes, 5.)) + def test_fr2(self): + bp.math.random.seed() + spikes = bp.math.random.random((1000, 10)) < 0.2 + print(bp.measure.firing_rate(spikes, 1.)) + print(bp.measure.firing_rate(spikes, 10.)) + def test_fr3(self): + bp.math.random.seed() + spikes = bp.math.random.random((1000, 10)) < 0.02 + print(bp.measure.firing_rate(spikes, 1.)) + print(bp.measure.firing_rate(spikes, 5.)) diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 800e5a8ff..deaf21a82 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -4,9 +4,10 @@ import sys import warnings from dataclasses import dataclass -from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any -from typing import (_SpecialForm, _type_check, _remove_dups_flatten) +from typing import Union, Dict, Callable, Sequence, Optional, Any +from typing import (_type_check, _remove_dups_flatten) +import brainstate.mixin import jax from brainpy import math as bm, tools @@ -14,673 +15,443 @@ from brainpy.types import ArrayType if sys.version_info.minor > 8: - from typing import (_UnionGenericAlias) + from typing import (_UnionGenericAlias) else: - from typing import (_GenericAlias, _tp_cache) + from typing import (_GenericAlias, _tp_cache) DynamicalSystem = None delay_identifier, init_delay_by_return = None, None __all__ = [ - 'MixIn', - 'ParamDesc', - 'ParamDescriber', - 'DelayRegister', - 'AlignPost', - 'Container', - 'TreeNode', - 'BindCondData', - 'JointType', - 'SupportSTDP', - 'SupportAutoDelay', - 'SupportInputProj', - 'SupportOnline', - 'SupportOffline', + 'MixIn', + 'ParamDesc', + 'ParamDescriber', + 'DelayRegister', + 'AlignPost', + 'Container', + 'TreeNode', + 'BindCondData', + 'JointType', + 'SupportSTDP', + 'SupportAutoDelay', + 'SupportInputProj', + 'SupportOnline', + 'SupportOffline', ] def _get_delay_tool(): - global 
delay_identifier, init_delay_by_return - if init_delay_by_return is None: from brainpy._src.delay import init_delay_by_return - if delay_identifier is None: from brainpy._src.delay import delay_identifier - return delay_identifier, init_delay_by_return + global delay_identifier, init_delay_by_return + if init_delay_by_return is None: from brainpy._src.delay import init_delay_by_return + if delay_identifier is None: from brainpy._src.delay import delay_identifier + return delay_identifier, init_delay_by_return def _get_dynsys(): - global DynamicalSystem - if DynamicalSystem is None: from brainpy._src.dynsys import DynamicalSystem - return DynamicalSystem - - -class MixIn(object): - """Base MixIn object. - - The key for a :py:class:`~.MixIn` is that: no initialization function, only behavioral functions. - """ - pass - - -class ParamDesc(MixIn): - """:py:class:`~.MixIn` indicates the function for describing initialization parameters. - - This mixin enables the subclass has a classmethod ``desc``, which - produces an instance of :py:class:`~.ParamDescInit`. - - Note this MixIn can be applied in any Python object. - """ - - not_desc_params: Optional[Sequence[str]] = None - - @classmethod - def desc(cls, *args, **kwargs) -> 'ParamDescriber': - return ParamDescriber(cls, *args, **kwargs) - - -class ParamDescriber(object): - """Delayed initialization for parameter describers. 
- """ - - def __init__(self, cls: type, *desc_tuple, **desc_dict): - self.cls = cls - - # arguments - self.args = desc_tuple - self.kwargs = desc_dict - - # identifier - if isinstance(cls, _JointGenericAlias): - name = str(cls) - repr_kwargs = {k: v for k, v in desc_dict.items()} - else: - assert isinstance(cls, type) - if issubclass(cls, ParamDesc) and (cls.not_desc_params is not None): - repr_kwargs = {k: v for k, v in desc_dict.items() if k not in cls.not_desc_params} - else: - repr_kwargs = {k: v for k, v in desc_dict.items()} - name = cls.__name__ - for k in tuple(repr_kwargs.keys()): - if isinstance(repr_kwargs[k], bm.Variable): - repr_kwargs[k] = id(repr_kwargs[k]) - repr_args = tools.repr_dict(repr_kwargs) - if len(desc_tuple): - repr_args = f"{', '.join([repr(arg) for arg in desc_tuple])}, {repr_args}" - self._identifier = f'{name}({repr_args})' - - def __call__(self, *args, **kwargs): - return self.cls(*self.args, *args, **self.kwargs, **kwargs) - - def init(self, *args, **kwargs): - return self.__call__(*args, **kwargs) - - def __instancecheck__(self, instance): - if not isinstance(instance, ParamDescriber): - return False - if not issubclass(instance.cls, self.cls): - return False - return True - - @classmethod - def __class_getitem__(cls, item: type): - return ParamDescriber(item) - - @property - def identifier(self): - return self._identifier - - @identifier.setter - def identifier(self, value): - self._identifier = value - - -class AlignPost(MixIn): - """Align post MixIn. + global DynamicalSystem + if DynamicalSystem is None: from brainpy._src.dynsys import DynamicalSystem + return DynamicalSystem - This class provides a ``add_current()`` function for - add external currents. 
- """ - def add_current(self, *args, **kwargs): - raise NotImplementedError +MixIn = brainstate.mixin.Mixin +ParamDesc = brainstate.mixin.ParamDesc +ParamDescriber = brainstate.mixin.ParamDescriber +AlignPost = brainstate.mixin.AlignPost +BindCondData = brainstate.mixin.BindCondData +JointType = brainstate.mixin.JointTypes @dataclass class ReturnInfo: - size: Sequence[int] - axis_names: Optional[Sequence[str]] = None - batch_or_mode: Optional[Union[int, bm.Mode]] = None - data: Union[Callable, bm.Array, jax.Array] = bm.zeros - - def get_data(self): - if isinstance(self.data, Callable): - if isinstance(self.batch_or_mode, int): - size = (self.batch_or_mode,) + tuple(self.size) - elif isinstance(self.batch_or_mode, bm.NonBatchingMode): - size = tuple(self.size) - elif isinstance(self.batch_or_mode, bm.BatchingMode): - size = (self.batch_or_mode.batch_size,) + tuple(self.size) - else: - size = tuple(self.size) - init = self.data(size) - elif isinstance(self.data, (bm.BaseArray, jax.Array)): - init = self.data - else: - raise ValueError - return init + size: Sequence[int] + axis_names: Optional[Sequence[str]] = None + batch_or_mode: Optional[Union[int, bm.Mode]] = None + data: Union[Callable, bm.Array, jax.Array] = bm.zeros + + def get_data(self): + if isinstance(self.data, Callable): + if isinstance(self.batch_or_mode, int): + size = (self.batch_or_mode,) + tuple(self.size) + elif isinstance(self.batch_or_mode, bm.NonBatchingMode): + size = tuple(self.size) + elif isinstance(self.batch_or_mode, bm.BatchingMode): + size = (self.batch_or_mode.batch_size,) + tuple(self.size) + else: + size = tuple(self.size) + init = self.data(size) + elif isinstance(self.data, (bm.BaseArray, jax.Array)): + init = self.data + else: + raise ValueError + return init class Container(MixIn): - """Container :py:class:`~.MixIn` which wrap a group of objects. - """ - children: bm.node_dict - - def __getitem__(self, item): - """Overwrite the slice access (`self['']`). 
""" - if item in self.children: - return self.children[item] - else: - raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}') - - def __getattr__(self, item): - """Overwrite the dot access (`self.`). """ - if item == 'children': - return super().__getattribute__('children') - else: - children = super().__getattribute__('children') - if item in children: - return children[item] - else: - return super().__getattribute__(item) - - def __repr__(self): - cls_name = self.__class__.__name__ - indent = ' ' * len(cls_name) - child_str = [tools.repr_context(repr(val), indent) for val in self.children.values()] - string = ", \n".join(child_str) - return f'{cls_name}({string})' - - def __get_elem_name(self, elem): - if isinstance(elem, bm.BrainPyObject): - return elem.name - else: - return get_unique_name('ContainerElem') - - def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict): - res = dict() - - # add tuple-typed components - for module in children_as_tuple: - if isinstance(module, child_type): - res[self.__get_elem_name(module)] = module - elif isinstance(module, (list, tuple)): - for m in module: - if not isinstance(m, child_type): - raise ValueError(f'Should be instance of {child_type.__name__}. ' - f'But we got {type(m)}') - res[self.__get_elem_name(m)] = m - elif isinstance(module, dict): - for k, v in module.items(): - if not isinstance(v, child_type): - raise ValueError(f'Should be instance of {child_type.__name__}. ' - f'But we got {type(v)}') - res[k] = v - else: - raise ValueError(f'Cannot parse sub-systems. They should be {child_type.__name__} ' - f'or a list/tuple/dict of {child_type.__name__}.') - # add dict-typed components - for k, v in children_as_dict.items(): - if not isinstance(v, child_type): - raise ValueError(f'Should be instance of {child_type.__name__}. ' - f'But we got {type(v)}') - res[k] = v - return res - - def add_elem(self, *elems, **elements): - """Add new elements. 
- - >>> obj = Container() - >>> obj.add_elem(a=1.) - - Args: - elements: children objects. + """Container :py:class:`~.MixIn` which wrap a group of objects. """ - self.children.update(self.format_elements(object, *elems, **elements)) + children: bm.node_dict + + def __getitem__(self, item): + """Overwrite the slice access (`self['']`). """ + if item in self.children: + return self.children[item] + else: + raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}') + + def __getattr__(self, item): + """Overwrite the dot access (`self.`). """ + if item == 'children': + return super().__getattribute__('children') + else: + children = super().__getattribute__('children') + if item in children: + return children[item] + else: + return super().__getattribute__(item) + + def __repr__(self): + cls_name = self.__class__.__name__ + indent = ' ' * len(cls_name) + child_str = [tools.repr_context(repr(val), indent) for val in self.children.values()] + string = ", \n".join(child_str) + return f'{cls_name}({string})' + + def __get_elem_name(self, elem): + if isinstance(elem, bm.BrainPyObject): + return elem.name + else: + return get_unique_name('ContainerElem') + + def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict): + res = dict() + + # add tuple-typed components + for module in children_as_tuple: + if isinstance(module, child_type): + res[self.__get_elem_name(module)] = module + elif isinstance(module, (list, tuple)): + for m in module: + if not isinstance(m, child_type): + raise ValueError(f'Should be instance of {child_type.__name__}. ' + f'But we got {type(m)}') + res[self.__get_elem_name(m)] = m + elif isinstance(module, dict): + for k, v in module.items(): + if not isinstance(v, child_type): + raise ValueError(f'Should be instance of {child_type.__name__}. ' + f'But we got {type(v)}') + res[k] = v + else: + raise ValueError(f'Cannot parse sub-systems. 
They should be {child_type.__name__} ' + f'or a list/tuple/dict of {child_type.__name__}.') + # add dict-typed components + for k, v in children_as_dict.items(): + if not isinstance(v, child_type): + raise ValueError(f'Should be instance of {child_type.__name__}. ' + f'But we got {type(v)}') + res[k] = v + return res + + def add_elem(self, *elems, **elements): + """Add new elements. + + >>> obj = Container() + >>> obj.add_elem(a=1.) + + Args: + elements: children objects. + """ + self.children.update(self.format_elements(object, *elems, **elements)) class TreeNode(MixIn): - """Tree node. """ - - master_type: type - - def check_hierarchies(self, root, *leaves, **named_leaves): - global DynamicalSystem - if DynamicalSystem is None: - from brainpy._src.dynsys import DynamicalSystem - - for leaf in leaves: - if isinstance(leaf, DynamicalSystem): - self.check_hierarchy(root, leaf) - elif isinstance(leaf, (list, tuple)): - self.check_hierarchies(root, *leaf) - elif isinstance(leaf, dict): - self.check_hierarchies(root, **leaf) - else: - raise ValueError(f'Do not support {type(leaf)}.') - for leaf in named_leaves.values(): - if not isinstance(leaf, DynamicalSystem): - raise ValueError(f'Do not support {type(leaf)}. Must be instance of {DynamicalSystem.__name__}') - self.check_hierarchy(root, leaf) - - def check_hierarchy(self, root, leaf): - if hasattr(leaf, 'master_type'): - master_type = leaf.master_type - else: - raise ValueError('Child class should define "master_type" to ' - 'specify the type of the root node. ' - f'But we did not found it in {leaf}') - if not issubclass(root, master_type): - raise TypeError(f'Type does not match. {leaf} requires a master with type ' - f'of {leaf.master_type}, but the master now is {root}.') + """Tree node. 
""" + + master_type: type + + def check_hierarchies(self, root, *leaves, **named_leaves): + global DynamicalSystem + if DynamicalSystem is None: + from brainpy._src.dynsys import DynamicalSystem + + for leaf in leaves: + if isinstance(leaf, DynamicalSystem): + self.check_hierarchy(root, leaf) + elif isinstance(leaf, (list, tuple)): + self.check_hierarchies(root, *leaf) + elif isinstance(leaf, dict): + self.check_hierarchies(root, **leaf) + else: + raise ValueError(f'Do not support {type(leaf)}.') + for leaf in named_leaves.values(): + if not isinstance(leaf, DynamicalSystem): + raise ValueError(f'Do not support {type(leaf)}. Must be instance of {DynamicalSystem.__name__}') + self.check_hierarchy(root, leaf) + + def check_hierarchy(self, root, leaf): + if hasattr(leaf, 'master_type'): + master_type = leaf.master_type + else: + raise ValueError('Child class should define "master_type" to ' + 'specify the type of the root node. ' + f'But we did not found it in {leaf}') + if not issubclass(root, master_type): + raise TypeError(f'Type does not match. {leaf} requires a master with type ' + f'of {leaf.master_type}, but the master now is {root}.') class DelayRegister(MixIn): - def register_delay( - self, - identifier: str, - delay_step: Optional[Union[int, ArrayType, Callable]], - delay_target: bm.Variable, - initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None, - ): - """Register delay variable. - - Args: - identifier: str. The delay access name. - delay_target: The target variable for delay. - delay_step: The delay time step. - initial_delay_data: The initializer for the delay data. - - Returns: - delay_pos: The position of the delay. 
- """ - _delay_identifier, _init_delay_by_return = _get_delay_tool() - DynamicalSystem = _get_dynsys() - assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' - _delay_identifier = _delay_identifier + identifier - if not self.has_aft_update(_delay_identifier): - self.add_aft_update(_delay_identifier, _init_delay_by_return(delay_target, initial_delay_data)) - delay_cls = self.get_aft_update(_delay_identifier) - name = get_unique_name('delay') - delay_cls.register_entry(name, delay_step) - return name - - def get_delay_data( - self, - identifier: str, - delay_pos: str, - *indices: Union[int, slice, bm.Array, jax.Array], - ): - """Get delay data according to the provided delay steps. - - Parameters:: - - identifier: str - The delay variable name. - delay_pos: str - The delay length. - indices: optional, int, slice, ArrayType - The indices of the delay. - - Returns:: - - delay_data: ArrayType - The delay data at the given time. - """ - _delay_identifier, _init_delay_by_return = _get_delay_tool() - _delay_identifier = _delay_identifier + identifier - delay_cls = self.get_aft_update(_delay_identifier) - return delay_cls.at(delay_pos, *indices) - - def update_local_delays(self, nodes: Union[Sequence, Dict] = None): - """Update local delay variables. - - This function should be called after updating neuron groups or delay sources. - For example, in a network model, - - - Parameters:: - - nodes: sequence, dict - The nodes to update their delay variables. - """ - warnings.warn('.update_local_delays() has been removed since brainpy>=2.4.6', - DeprecationWarning) - - def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): - """Reset local delay variables. - - Parameters:: - - nodes: sequence, dict - The nodes to Reset their delay variables. 
- """ - warnings.warn('.reset_local_delays() has been removed since brainpy>=2.4.6', - DeprecationWarning) - - def get_delay_var(self, name): - _delay_identifier, _init_delay_by_return = _get_delay_tool() - _delay_identifier = _delay_identifier + name - delay_cls = self.get_aft_update(_delay_identifier) - return delay_cls + def register_delay( + self, + identifier: str, + delay_step: Optional[Union[int, ArrayType, Callable]], + delay_target: bm.Variable, + initial_delay_data: Union[Callable, ArrayType, numbers.Number] = None, + ): + """Register delay variable. + + Args: + identifier: str. The delay access name. + delay_target: The target variable for delay. + delay_step: The delay time step. + initial_delay_data: The initializer for the delay data. + + Returns: + delay_pos: The position of the delay. + """ + _delay_identifier, _init_delay_by_return = _get_delay_tool() + DynamicalSystem = _get_dynsys() + assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' + _delay_identifier = _delay_identifier + identifier + if not self.has_aft_update(_delay_identifier): + self.add_aft_update(_delay_identifier, _init_delay_by_return(delay_target, initial_delay_data)) + delay_cls = self.get_aft_update(_delay_identifier) + name = get_unique_name('delay') + delay_cls.register_entry(name, delay_step) + return name + + def get_delay_data( + self, + identifier: str, + delay_pos: str, + *indices: Union[int, slice, bm.Array, jax.Array], + ): + """Get delay data according to the provided delay steps. + + Parameters:: + + identifier: str + The delay variable name. + delay_pos: str + The delay length. + indices: optional, int, slice, ArrayType + The indices of the delay. + + Returns:: + + delay_data: ArrayType + The delay data at the given time. 
+ """ + _delay_identifier, _init_delay_by_return = _get_delay_tool() + _delay_identifier = _delay_identifier + identifier + delay_cls = self.get_aft_update(_delay_identifier) + return delay_cls.at(delay_pos, *indices) + + def update_local_delays(self, nodes: Union[Sequence, Dict] = None): + """Update local delay variables. + + This function should be called after updating neuron groups or delay sources. + For example, in a network model, + + + Parameters:: + + nodes: sequence, dict + The nodes to update their delay variables. + """ + warnings.warn('.update_local_delays() has been removed since brainpy>=2.4.6', + DeprecationWarning) + + def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): + """Reset local delay variables. + + Parameters:: + + nodes: sequence, dict + The nodes to Reset their delay variables. + """ + warnings.warn('.reset_local_delays() has been removed since brainpy>=2.4.6', + DeprecationWarning) + + def get_delay_var(self, name): + _delay_identifier, _init_delay_by_return = _get_delay_tool() + _delay_identifier = _delay_identifier + name + delay_cls = self.get_aft_update(_delay_identifier) + return delay_cls class SupportInputProj(MixIn): - """The :py:class:`~.MixIn` that receives the input projections. - - Note that the subclass should define a ``cur_inputs`` attribute. Otherwise, - the input function utilities cannot be used. - - """ - current_inputs: bm.node_dict - delta_inputs: bm.node_dict + """The :py:class:`~.MixIn` that receives the input projections. - def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'): - """Add an input function. + Note that the subclass should define a ``cur_inputs`` attribute. Otherwise, + the input function utilities cannot be used. - Args: - key: str. The dict key. - fun: Callable. The function to generate inputs. - label: str. The input label. - category: str. 
The input category, should be ``current`` (the current) or - ``delta`` (the delta synapse, indicating the delta function). """ - if not callable(fun): - raise TypeError('Must be a function.') - - key = self._input_label_repr(key, label) - if category == 'current': - if key in self.current_inputs: - raise ValueError(f'Key "{key}" has been defined and used.') - self.current_inputs[key] = fun - elif category == 'delta': - if key in self.delta_inputs: - raise ValueError(f'Key "{key}" has been defined and used.') - self.delta_inputs[key] = fun - else: - raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".') - - def get_inp_fun(self, key: str): - """Get the input function. - - Args: - key: str. The key. - - Returns: - The input function which generates currents. - """ - if key in self.current_inputs: - return self.current_inputs[key] - elif key in self.delta_inputs: - return self.delta_inputs[key] - else: - raise ValueError(f'Unknown key: {key}') - - def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): - """Summarize all current inputs by the defined input functions ``.current_inputs``. - - Args: - *args: The arguments for input functions. - init: The initial input data. - label: str. The input label. - **kwargs: The arguments for input functions. - - Returns: - The total currents. - """ - if label is None: - for key, out in self.current_inputs.items(): - init = init + out(*args, **kwargs) - else: - label_repr = self._input_label_start(label) - for key, out in self.current_inputs.items(): - if key.startswith(label_repr): - init = init + out(*args, **kwargs) - return init - - def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): - """Summarize all delta inputs by the defined input functions ``.delta_inputs``. - - Args: - *args: The arguments for input functions. - init: The initial input data. - label: str. The input label. 
- **kwargs: The arguments for input functions. - - Returns: - The total currents. - """ - if label is None: - for key, out in self.delta_inputs.items(): - init = init + out(*args, **kwargs) - else: - label_repr = self._input_label_start(label) - for key, out in self.delta_inputs.items(): - if key.startswith(label_repr): - init = init + out(*args, **kwargs) - return init - - @classmethod - def _input_label_start(cls, label: str): - # unify the input label repr. - return f'{label} // ' - - @classmethod - def _input_label_repr(cls, name: str, label: Optional[str] = None): - # unify the input label repr. - return name if label is None else (cls._input_label_start(label) + str(name)) - - # deprecated # - # ---------- # - - @property - def cur_inputs(self): - return self.current_inputs - - def sum_inputs(self, *args, **kwargs): - warnings.warn('Please use ".sum_current_inputs()" instead. ".sum_inputs()" will be removed.', UserWarning) - return self.sum_current_inputs(*args, **kwargs) + current_inputs: bm.node_dict + delta_inputs: bm.node_dict + + def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'): + """Add an input function. + + Args: + key: str. The dict key. + fun: Callable. The function to generate inputs. + label: str. The input label. + category: str. The input category, should be ``current`` (the current) or + ``delta`` (the delta synapse, indicating the delta function). + """ + if not callable(fun): + raise TypeError('Must be a function.') + + key = self._input_label_repr(key, label) + if category == 'current': + if key in self.current_inputs: + raise ValueError(f'Key "{key}" has been defined and used.') + self.current_inputs[key] = fun + elif category == 'delta': + if key in self.delta_inputs: + raise ValueError(f'Key "{key}" has been defined and used.') + self.delta_inputs[key] = fun + else: + raise NotImplementedError(f'Unknown category: {category}. 
Only support "current" and "delta".') + + def get_inp_fun(self, key: str): + """Get the input function. + + Args: + key: str. The key. + + Returns: + The input function which generates currents. + """ + if key in self.current_inputs: + return self.current_inputs[key] + elif key in self.delta_inputs: + return self.delta_inputs[key] + else: + raise ValueError(f'Unknown key: {key}') + + def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): + """Summarize all current inputs by the defined input functions ``.current_inputs``. + + Args: + *args: The arguments for input functions. + init: The initial input data. + label: str. The input label. + **kwargs: The arguments for input functions. + + Returns: + The total currents. + """ + if label is None: + for key, out in self.current_inputs.items(): + init = init + out(*args, **kwargs) + else: + label_repr = self._input_label_start(label) + for key, out in self.current_inputs.items(): + if key.startswith(label_repr): + init = init + out(*args, **kwargs) + return init + + def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): + """Summarize all delta inputs by the defined input functions ``.delta_inputs``. + + Args: + *args: The arguments for input functions. + init: The initial input data. + label: str. The input label. + **kwargs: The arguments for input functions. + + Returns: + The total currents. + """ + if label is None: + for key, out in self.delta_inputs.items(): + init = init + out(*args, **kwargs) + else: + label_repr = self._input_label_start(label) + for key, out in self.delta_inputs.items(): + if key.startswith(label_repr): + init = init + out(*args, **kwargs) + return init + + @classmethod + def _input_label_start(cls, label: str): + # unify the input label repr. + return f'{label} // ' + + @classmethod + def _input_label_repr(cls, name: str, label: Optional[str] = None): + # unify the input label repr. 
+ return name if label is None else (cls._input_label_start(label) + str(name)) + + # deprecated # + # ---------- # + + @property + def cur_inputs(self): + return self.current_inputs + + def sum_inputs(self, *args, **kwargs): + warnings.warn('Please use ".sum_current_inputs()" instead. ".sum_inputs()" will be removed.', UserWarning) + return self.sum_current_inputs(*args, **kwargs) class SupportReturnInfo(MixIn): - """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" + """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" - def return_info(self) -> Union[bm.Variable, ReturnInfo]: - raise NotImplementedError('Must implement the "return_info()" function.') + def return_info(self) -> Union[bm.Variable, ReturnInfo]: + raise NotImplementedError('Must implement the "return_info()" function.') class SupportAutoDelay(SupportReturnInfo): - pass + pass class SupportOnline(MixIn): - """:py:class:`~.MixIn` to support the online training methods. + """:py:class:`~.MixIn` to support the online training methods. - .. versionadded:: 2.4.5 - """ + .. versionadded:: 2.4.5 + """ - online_fit_by: Optional # methods for online fitting + online_fit_by: Optional # methods for online fitting - def online_init(self, *args, **kwargs): - raise NotImplementedError + def online_init(self, *args, **kwargs): + raise NotImplementedError - def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): - raise NotImplementedError + def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): + raise NotImplementedError class SupportOffline(MixIn): - """:py:class:`~.MixIn` to support the offline training methods. - - .. 
versionadded:: 2.4.5 - """ - - offline_fit_by: Optional # methods for offline fitting - - def offline_init(self, *args, **kwargs): - pass - - def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): - raise NotImplementedError - - -class BindCondData(MixIn): - """Bind temporary conductance data. + """:py:class:`~.MixIn` to support the offline training methods. + .. versionadded:: 2.4.5 + """ - """ - _conductance: Optional + offline_fit_by: Optional # methods for offline fitting - def bind_cond(self, conductance): - self._conductance = conductance + def offline_init(self, *args, **kwargs): + pass - def unbind_cond(self): - self._conductance = None + def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]): + raise NotImplementedError class SupportSTDP(MixIn): - """Support synaptic plasticity by modifying the weights. - """ - - def stdp_update(self, *args, on_pre=None, onn_post=None, **kwargs): - raise NotImplementedError - - -T = TypeVar('T') - - -def get_type(types): - class NewType(type): - def __instancecheck__(self, other): - cls_of_other = other.__class__ - return all([issubclass(cls_of_other, cls) for cls in types]) - - return NewType - - -class _MetaUnionType(type): - def __new__(cls, name, bases, dct): - if isinstance(bases, type): - bases = (bases,) - elif isinstance(bases, (list, tuple)): - bases = tuple(bases) - for base in bases: - assert isinstance(base, type), f'Must be type. But got {base}' - else: - raise TypeError(f'Must be type. 
But got {bases}') - return super().__new__(cls, name, bases, dct) - - def __instancecheck__(self, other): - cls_of_other = other.__class__ - return all([issubclass(cls_of_other, cls) for cls in self.__bases__]) - - def __subclasscheck__(self, subclass): - return all([issubclass(subclass, cls) for cls in self.__bases__]) - - -if sys.version_info.minor > 8: - class _JointGenericAlias(_UnionGenericAlias, _root=True): - def __subclasscheck__(self, subclass): - return all([issubclass(subclass, cls) for cls in set(self.__args__)]) - - - @_SpecialForm - def JointType(self, parameters): - """Joint type; JointType[X, Y] means both X and Y. - - To define a union, use e.g. Union[int, str]. - - Details: - - The arguments must be types and there must be at least one. - - None as an argument is a special case and is replaced by `type(None)`. - - Unions of unions are flattened, e.g.:: - - JointType[JointType[int, str], float] == JointType[int, str, float] - - - Unions of a single argument vanish, e.g.:: - - JointType[int] == int # The constructor actually returns int - - - Redundant arguments are skipped, e.g.:: - - JointType[int, str, int] == JointType[int, str] - - - When comparing unions, the argument order is ignored, e.g.:: - - JointType[int, str] == JointType[str, int] - - - You cannot subclass or instantiate a union. - - You can use Optional[X] as a shorthand for JointType[X, None]. + """Support synaptic plasticity by modifying the weights. """ - if parameters == (): - raise TypeError("Cannot take a Joint of no types.") - if not isinstance(parameters, tuple): - parameters = (parameters,) - msg = "JointType[arg, ...]: each arg must be a type." 
- parameters = tuple(_type_check(p, msg) for p in parameters) - parameters = _remove_dups_flatten(parameters) - if len(parameters) == 1: - return parameters[0] - return _JointGenericAlias(self, parameters) -else: - class _JointGenericAlias(_GenericAlias, _root=True): - def __subclasscheck__(self, subclass): - return all([issubclass(subclass, cls) for cls in set(self.__args__)]) - - - class _SpecialForm2(_SpecialForm, _root=True): - @_tp_cache - def __getitem__(self, parameters): - if self._name == 'JointType': - if parameters == (): - raise TypeError("Cannot take a Joint of no types.") - if not isinstance(parameters, tuple): - parameters = (parameters,) - msg = "JointType[arg, ...]: each arg must be a type." - parameters = tuple(_type_check(p, msg) for p in parameters) - parameters = _remove_dups_flatten(parameters) - if len(parameters) == 1: - return parameters[0] - return _JointGenericAlias(self, parameters) - else: - return super().__getitem__(parameters) - - - JointType = _SpecialForm2( - 'JointType', - doc="""Joint type; JointType[X, Y] means both X and Y. - - To define a joint, use e.g. JointType[int, str]. - - Details: - - - The arguments must be types and there must be at least one. - - None as an argument is a special case and is replaced by `type(None)`. - - Unions of unions are flattened, e.g.:: - - JointType[JointType[int, str], float] == JointType[int, str, float] - - - Unions of a single argument vanish, e.g.:: - - JointType[int] == int # The constructor actually returns int - - - Redundant arguments are skipped, e.g.:: - - JointType[int, str, int] == JointType[int, str] - - - When comparing unions, the argument order is ignored, e.g.:: - - JointType[int, str] == JointType[str, int] - - - You cannot subclass or instantiate a union. - - You can use Optional[X] as a shorthand for JointType[X, None]. 
- """ - ) + def stdp_update(self, *args, on_pre=None, onn_post=None, **kwargs): + raise NotImplementedError diff --git a/brainpy/_src/modes.py b/brainpy/_src/modes.py index 150e182a0..e84e2d2a6 100644 --- a/brainpy/_src/modes.py +++ b/brainpy/_src/modes.py @@ -5,23 +5,19 @@ Please use ``brainpy.math.*`` instead. """ -from brainpy._src.math import modes from brainpy import check from brainpy._src.deprecations import deprecation_getattr2 +from brainpy._src.math import modes __deprecations = { - 'Mode': ('brainpy.modes.Mode', 'brainpy.math.Mode', modes.Mode), - 'NormalMode': ('brainpy.modes.NormalMode', 'brainpy.math.NonBatchingMode', modes.NonBatchingMode), - 'BatchingMode': ('brainpy.modes.BatchingMode', 'brainpy.math.BatchingMode', modes.BatchingMode), - 'TrainingMode': ('brainpy.modes.TrainingMode', 'brainpy.math.TrainingMode', modes.TrainingMode), - 'normal': ('brainpy.modes.normal', 'brainpy.math.nonbatching_mode', modes.nonbatching_mode), - 'batching': ('brainpy.modes.batching', 'brainpy.math.batching_mode', modes.batching_mode), - 'training': ('brainpy.modes.training', 'brainpy.math.training_mode', modes.training_mode), - 'check_mode': ('brainpy.modes.check_mode', 'brainpy.check.is_subclass', check.is_subclass), + 'Mode': ('brainpy.modes.Mode', 'brainpy.math.Mode', modes.Mode), + 'NormalMode': ('brainpy.modes.NormalMode', 'brainpy.math.NonBatchingMode', modes.NonBatchingMode), + 'BatchingMode': ('brainpy.modes.BatchingMode', 'brainpy.math.BatchingMode', modes.BatchingMode), + 'TrainingMode': ('brainpy.modes.TrainingMode', 'brainpy.math.TrainingMode', modes.TrainingMode), + 'normal': ('brainpy.modes.normal', 'brainpy.math.nonbatching_mode', modes.nonbatching_mode), + 'batching': ('brainpy.modes.batching', 'brainpy.math.batching_mode', modes.batching_mode), + 'training': ('brainpy.modes.training', 'brainpy.math.training_mode', modes.training_mode), + 'check_mode': ('brainpy.modes.check_mode', 'brainpy.check.is_subclass', check.is_subclass), } __getattr__ = 
deprecation_getattr2('brainpy.modes', __deprecations) del deprecation_getattr2 - - - - diff --git a/brainpy/_src/optimizers/optimizer.py b/brainpy/_src/optimizers/optimizer.py index 09c25b4db..11e45fd57 100644 --- a/brainpy/_src/optimizers/optimizer.py +++ b/brainpy/_src/optimizers/optimizer.py @@ -5,7 +5,6 @@ import jax.numpy as jnp from jax.lax import cond -import brainpy as bp import brainpy.math as bm from brainpy import check @@ -14,1087 +13,1087 @@ from .scheduler import make_schedule, Scheduler __all__ = [ - 'Optimizer', - 'SGD', - 'Momentum', - 'MomentumNesterov', - 'Adagrad', - 'Adadelta', - 'RMSProp', - 'Adam', - 'LARS', - 'Adan', - 'AdamW', + 'Optimizer', + 'SGD', + 'Momentum', + 'MomentumNesterov', + 'Adagrad', + 'Adadelta', + 'RMSProp', + 'Adam', + 'LARS', + 'Adan', + 'AdamW', ] class Optimizer(BrainPyObject): - """Base Optimizer Class. + """Base Optimizer Class. - Parameters:: + Parameters:: - lr: float, Scheduler - learning rate. - """ + lr: float, Scheduler + learning rate. 
+ """ - lr: Scheduler # learning rate - '''Learning rate''' + lr: Scheduler # learning rate + '''Learning rate''' - vars_to_train: ArrayCollector # variables to train - '''Variables to train.''' + vars_to_train: ArrayCollector # variables to train + '''Variables to train.''' - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, - name: Optional[str] = None - ): - super(Optimizer, self).__init__(name=name) - self.lr: Scheduler = make_schedule(lr) - self.vars_to_train = ArrayCollector() - self.register_train_vars(train_vars) + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, + name: Optional[str] = None + ): + super(Optimizer, self).__init__(name=name) + self.lr: Scheduler = make_schedule(lr) + self.vars_to_train = ArrayCollector() + self.register_train_vars(train_vars) - def register_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - warnings.warn('Using "register_train_vars()" instead.', UserWarning) - self.register_train_vars(train_vars) + def register_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + warnings.warn('Using "register_train_vars()" instead.', UserWarning) + self.register_train_vars(train_vars) - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - raise NotImplementedError + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + raise NotImplementedError - def check_grads(self, grads): - if len(grads) != len(self.vars_to_train): - raise MathError(f'The length of "grads" must be equal to "self.vars_to_train", ' - f'while we got {len(grads)} != {len(self.vars_to_train)}!') + def check_grads(self, grads): + if len(grads) != len(self.vars_to_train): + raise MathError(f'The length of "grads" must be equal to "self.vars_to_train", ' + f'while we got {len(grads)} != 
{len(self.vars_to_train)}!') - def __repr__(self): - return f"{self.__class__.__name__}(lr={self.lr})" + def __repr__(self): + return f"{self.__class__.__name__}(lr={self.lr})" - def update(self, grads: dict): - raise NotImplementedError + def update(self, grads: dict): + raise NotImplementedError class CommonOpt(Optimizer): - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, - weight_decay: Optional[float] = None, - name: Optional[str] = None - ): - super(Optimizer, self).__init__(name=name) - self.lr: Scheduler = make_schedule(lr) - self.vars_to_train = ArrayCollector() - self.register_train_vars(train_vars) - self.weight_decay = check.is_float(weight_decay, min_bound=0., max_bound=1., allow_none=True) + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None, + weight_decay: Optional[float] = None, + name: Optional[str] = None + ): + super(Optimizer, self).__init__(name=name) + self.lr: Scheduler = make_schedule(lr) + self.vars_to_train = ArrayCollector() + self.register_train_vars(train_vars) + self.weight_decay = check.is_float(weight_decay, min_bound=0., max_bound=1., allow_none=True) class SGD(CommonOpt): - r"""Stochastic gradient descent optimizer. - - SGD performs a parameter update for training examples :math:`x` and label - :math:`y`: - - .. math:: + r"""Stochastic gradient descent optimizer. - \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y) + SGD performs a parameter update for training examples :math:`x` and label + :math:`y`: + .. math:: - Parameters:: - - lr: float, Scheduler - learning rate. 
- - """ - - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - weight_decay: Optional[float] = None, - name: Optional[str] = None - ): - super(SGD, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) - - def __repr__(self): - return f'{self.__class__.__name__}(lr={self.lr})' - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - for key, p in self.vars_to_train.items(): - if self.weight_decay is None: - p.value -= lr * grads[key] - else: - p.value = (1 - self.weight_decay) * p - lr * grads[key] - self.lr.step_call() - + \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y) -class Momentum(CommonOpt): - r"""Momentum optimizer. - Momentum [1]_ is a method that helps accelerate SGD in the relevant direction - and dampens oscillations. It does this by adding a fraction :math:`\gamma` - of the update vector of the past time step to the current update vector: + Parameters:: - .. math:: + lr: float, Scheduler + learning rate. - \begin{align} - \begin{split} - v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\ - \theta &= \theta - v_t - \end{split} - \end{align} + """ - Parameters:: - - lr: float, Scheduler - learning rate. - - References:: + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + weight_decay: Optional[float] = None, + name: Optional[str] = None + ): + super(SGD, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) - .. [1] Qian, N. (1999). On the momentum term in gradient descent learning - algorithms. 
Neural Networks : The Official Journal of the International - Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6 + def __repr__(self): + return f'{self.__class__.__name__}(lr={self.lr})' - """ + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - momentum: float = 0.9, - weight_decay: Optional[float] = None, - name: Optional[str] = None - ): - super(Momentum, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + for key, p in self.vars_to_train.items(): + if self.weight_decay is None: + p.value -= lr * grads[key] + else: + p.value = (1 - self.weight_decay) * p - lr * grads[key] + self.lr.step_call() - self.momentum = momentum - - def __repr__(self): - return f'{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})' - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - vs = dict((key + '_v', bm.Variable(bm.zeros_like(x))) - for key, x in train_vars.items()) - self.register_implicit_vars(vs) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - for key, p in self.vars_to_train.items(): - g = grads[key] - v = self.implicit_vars[key + '_v'] - v.value = self.momentum * v.value - lr * g - if self.weight_decay is None: - p.value += v.value - else: - p.value = (1 - self.weight_decay) * p + v - self.lr.step_call() 
+class Momentum(CommonOpt): + r"""Momentum optimizer. -class MomentumNesterov(CommonOpt): - r"""Nesterov accelerated gradient optimizer [2]_. + Momentum [1]_ is a method that helps accelerate SGD in the relevant direction + and dampens oscillations. It does this by adding a fraction :math:`\gamma` + of the update vector of the past time step to the current update vector: - .. math:: + .. math:: \begin{align} \begin{split} - v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\ + v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta) \\ \theta &= \theta - v_t \end{split} \end{align} - Parameters:: + Parameters:: - lr: float, Scheduler - learning rate. + lr: float, Scheduler + learning rate. - References:: - - .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547. - - """ - - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - weight_decay: Optional[float] = None, - momentum: float = 0.9, - name: Optional[str] = None - ): - super(MomentumNesterov, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) - - self.momentum = momentum - - def __repr__(self): - return f'{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})' - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - vs = dict((key + '_v', bm.Variable(bm.zeros_like(x))) - for key, x in train_vars.items()) - self.register_implicit_vars(vs) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - for key, p in self.vars_to_train.items(): - g = grads[key] - v = self.implicit_vars[key + '_v'] - 
v.value = self.momentum * v.value - lr * g - if self.weight_decay is None: - p.value += v - else: - p.value = (1 - self.weight_decay) * p + v - self.lr.step_call() + References:: + .. [1] Qian, N. (1999). On the momentum term in gradient descent learning + algorithms. Neural Networks : The Official Journal of the International + Neural Network Society, 12(1), 145–151. http://doi.org/10.1016/S0893-6080(98)00116-6 -class Adagrad(CommonOpt): - r"""Optimizer that implements the Adagrad algorithm. + """ - Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are - adapted relative to how frequently a parameter gets updated during training. - The more updates a parameter receives, the smaller the updates. + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + momentum: float = 0.9, + weight_decay: Optional[float] = None, + name: Optional[str] = None + ): + super(Momentum, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) - .. 
math:: + self.momentum = momentum - \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t} + def __repr__(self): + return f'{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})' - where :math:`G(t)` contains the sum of the squares of the past gradients + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + vs = dict((key + '_v', bm.Variable(bm.zeros_like(x))) + for key, x in train_vars.items()) + self.register_implicit_vars(vs) + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + for key, p in self.vars_to_train.items(): + g = grads[key] + v = self.implicit_vars[key + '_v'] + v.value = self.momentum * v.value - lr * g + if self.weight_decay is None: + p.value += v.value + else: + p.value = (1 - self.weight_decay) * p + v + self.lr.step_call() - One of Adagrad's main benefits is that it eliminates the need to manually tune - the learning rate. Most implementations use a default value of 0.01 and leave it at that. - Adagrad's main weakness is its accumulation of the squared gradients in the denominator: - Since every added term is positive, the accumulated sum keeps growing during training. - This in turn causes the learning rate to shrink and eventually become infinitesimally - small, at which point the algorithm is no longer able to acquire additional knowledge. - Parameters:: +class MomentumNesterov(CommonOpt): + r"""Nesterov accelerated gradient optimizer [2]_. - lr: float, Scheduler - learning rate. + .. math:: - References:: + \begin{align} + \begin{split} + v_t &= \gamma v_{t-1} + \eta \nabla_\theta J( \theta - \gamma v_{t-1} ) \\ + \theta &= \theta - v_t + \end{split} + \end{align} - .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). 
Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html + Parameters:: - """ + lr: float, Scheduler + learning rate. - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - weight_decay: Optional[float] = None, - epsilon: float = 1e-6, - name: Optional[str] = None - ): - super(Adagrad, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) - self.epsilon = epsilon - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - caches = dict((key + '_cache', bm.Variable(bm.zeros_like(x))) - for key, x in train_vars.items()) - self.register_implicit_vars(caches) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - for key, p in self.vars_to_train.items(): - g = grads[key] - c = self.implicit_vars[key + '_cache'] - c.value += g ** 2 - update = lr * g / jnp.sqrt(c.value + self.epsilon) - if self.weight_decay is None: - p.value -= update - else: - p.value = (1 - self.weight_decay) * p - update - self.lr.step_call() - - def __repr__(self): - return f"{self.__class__.__name__}(lr={self.lr}, epsilon={self.epsilon})" + References:: + .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547. -class Adadelta(CommonOpt): - r"""Optimizer that implements the Adadelta algorithm. 
+ """ - Adadelta [4]_ optimization is a stochastic gradient descent method that is based - on adaptive learning rate per dimension to address two drawbacks: + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + weight_decay: Optional[float] = None, + momentum: float = 0.9, + name: Optional[str] = None + ): + super(MomentumNesterov, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) - - The continual decay of learning rates throughout training. - - The need for a manually selected global learning rate. + self.momentum = momentum - Adadelta is a more robust extension of Adagrad that adapts learning rates based on - a moving window of gradient updates, instead of accumulating all past gradients. - This way, Adadelta continues learning even when many updates have been done. Compared - to Adagrad, in the original version of Adadelta you don't have to set an initial - learning rate. + def __repr__(self): + return f'{self.__class__.__name__}(lr={self.lr}, momentum={self.momentum})' - .. 
math:: + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + vs = dict((key + '_v', bm.Variable(bm.zeros_like(x))) + for key, x in train_vars.items()) + self.register_implicit_vars(vs) + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + for key, p in self.vars_to_train.items(): + g = grads[key] + v = self.implicit_vars[key + '_v'] + v.value = self.momentum * v.value - lr * g + if self.weight_decay is None: + p.value += v + else: + p.value = (1 - self.weight_decay) * p + v + self.lr.step_call() - \boldsymbol{s}_t \leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\ - \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\ - \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\ - \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t. - :math:`\rho` should be between 0 and 1. A value of rho close to 1 will decay the - moving average slowly and a value close to 0 will decay the moving average fast. +class Adagrad(CommonOpt): + r"""Optimizer that implements the Adagrad algorithm. - :math:`\rho` = 0.95 and :math:`\epsilon`=1e-6 are suggested in the paper and reported - to work for multiple datasets (MNIST, speech). + Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are + adapted relative to how frequently a parameter gets updated during training. + The more updates a parameter receives, the smaller the updates. - In the paper, no learning rate is considered (so learning_rate=1.0). Probably best to - keep it at this value. 
epsilon is important for the very first update (so the - numerator does not become 0). + .. math:: - Parameters:: + \theta_{t+1} = \theta_{t} - \dfrac{\eta}{\sqrt{G_{t} + \epsilon}} \odot g_{t} + + where :math:`G(t)` contains the sum of the squares of the past gradients + + One of Adagrad's main benefits is that it eliminates the need to manually tune + the learning rate. Most implementations use a default value of 0.01 and leave it at that. + Adagrad's main weakness is its accumulation of the squared gradients in the denominator: + Since every added term is positive, the accumulated sum keeps growing during training. + This in turn causes the learning rate to shrink and eventually become infinitesimally + small, at which point the algorithm is no longer able to acquire additional knowledge. + + Parameters:: + + lr: float, Scheduler + learning rate. + + References:: + + .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. 
Retrieved from http://jmlr.org/papers/v12/duchi11a.html + + """ + + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + weight_decay: Optional[float] = None, + epsilon: float = 1e-6, + name: Optional[str] = None + ): + super(Adagrad, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) + self.epsilon = epsilon + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + caches = dict((key + '_cache', bm.Variable(bm.zeros_like(x))) + for key, x in train_vars.items()) + self.register_implicit_vars(caches) + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + for key, p in self.vars_to_train.items(): + g = grads[key] + c = self.implicit_vars[key + '_cache'] + c.value += g ** 2 + update = lr * g / jnp.sqrt(c.value + self.epsilon) + if self.weight_decay is None: + p.value -= update + else: + p.value = (1 - self.weight_decay) * p - update + self.lr.step_call() + + def __repr__(self): + return f"{self.__class__.__name__}(lr={self.lr}, epsilon={self.epsilon})" - lr: float, Scheduler - learning rate. - References:: +class Adadelta(CommonOpt): + r"""Optimizer that implements the Adadelta algorithm. - .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701 + Adadelta [4]_ optimization is a stochastic gradient descent method that is based + on adaptive learning rate per dimension to address two drawbacks: - """ + - The continual decay of learning rates throughout training. + - The need for a manually selected global learning rate. 
- def __init__( - self, - lr: Union[float, Scheduler, bm.Variable] = 0.01, - train_vars: Dict[str, bm.Variable] = None, - weight_decay: Optional[float] = None, - epsilon: float = 1e-6, - rho: float = 0.95, - name: Optional[str] = None - ): - super(Adadelta, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) + Adadelta is a more robust extension of Adagrad that adapts learning rates based on + a moving window of gradient updates, instead of accumulating all past gradients. + This way, Adadelta continues learning even when many updates have been done. Compared + to Adagrad, in the original version of Adadelta you don't have to set an initial + learning rate. - self.epsilon = epsilon - self.rho = rho + .. math:: - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - caches = dict((key + '_cache', bm.Variable(bm.zeros_like(x))) - for key, x in train_vars.items()) - deltas = dict((key + '_delta', bm.Variable(bm.zeros_like(x))) - for key, x in train_vars.items()) - self.register_implicit_vars(caches) - self.register_implicit_vars(deltas) - - def update(self, grads: dict): - self.check_grads(grads) - for key, p in self.vars_to_train.items(): - g = grads[key] - c = self.implicit_vars[key + '_cache'] - d = self.implicit_vars[key + '_delta'] - c.value = self.rho * c.value + (1 - self.rho) * g ** 2 - update = g * jnp.sqrt(d.value + self.epsilon) / jnp.sqrt(c + self.epsilon) - d.value = self.rho * d.value + (1 - self.rho) * update ** 2 - if self.weight_decay is None: - p.value -= update - else: - p.value = (1 - self.weight_decay) * p - update - self.lr.step_call() - - def __repr__(self): - return (f"{self.__class__.__name__}(lr={self.lr}, " - f"epsilon={self.epsilon}, rho={self.rho})") + \boldsymbol{s}_t 
\leftarrow \rho \boldsymbol{s}_{t-1} + (1 - \rho) \boldsymbol{g}_t \odot \boldsymbol{g}_t, \\ + \boldsymbol{g}_t' \leftarrow \sqrt{\frac{\Delta\boldsymbol{x}_{t-1} + \epsilon}{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, \\ + \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \boldsymbol{g}'_t, \\ + \Delta\boldsymbol{x}_t \leftarrow \rho \Delta\boldsymbol{x}_{t-1} + (1 - \rho) \boldsymbol{g}'_t \odot \boldsymbol{g}'_t. + + :math:`\rho` should be between 0 and 1. A value of rho close to 1 will decay the + moving average slowly and a value close to 0 will decay the moving average fast. + + :math:`\rho` = 0.95 and :math:`\epsilon`=1e-6 are suggested in the paper and reported + to work for multiple datasets (MNIST, speech). + + In the paper, no learning rate is considered (so learning_rate=1.0). Probably best to + keep it at this value. epsilon is important for the very first update (so the + numerator does not become 0). + + Parameters:: + + lr: float, Scheduler + learning rate. + + References:: + + .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. 
Retrieved from http://arxiv.org/abs/1212.5701 + + """ + + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable] = 0.01, + train_vars: Dict[str, bm.Variable] = None, + weight_decay: Optional[float] = None, + epsilon: float = 1e-6, + rho: float = 0.95, + name: Optional[str] = None + ): + super(Adadelta, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) + + self.epsilon = epsilon + self.rho = rho + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + caches = dict((key + '_cache', bm.Variable(bm.zeros_like(x))) + for key, x in train_vars.items()) + deltas = dict((key + '_delta', bm.Variable(bm.zeros_like(x))) + for key, x in train_vars.items()) + self.register_implicit_vars(caches) + self.register_implicit_vars(deltas) + + def update(self, grads: dict): + self.check_grads(grads) + for key, p in self.vars_to_train.items(): + g = grads[key] + c = self.implicit_vars[key + '_cache'] + d = self.implicit_vars[key + '_delta'] + c.value = self.rho * c.value + (1 - self.rho) * g ** 2 + update = g * jnp.sqrt(d.value + self.epsilon) / jnp.sqrt(c + self.epsilon) + d.value = self.rho * d.value + (1 - self.rho) * update ** 2 + if self.weight_decay is None: + p.value -= update + else: + p.value = (1 - self.weight_decay) * p - update + self.lr.step_call() + + def __repr__(self): + return (f"{self.__class__.__name__}(lr={self.lr}, " + f"epsilon={self.epsilon}, rho={self.rho})") class RMSProp(CommonOpt): - r"""Optimizer that implements the RMSprop algorithm. - - RMSprop [5]_ and Adadelta have both been developed independently around the same time - stemming from the need to resolve Adagrad's radically diminishing learning rates. 
- - The gist of RMSprop is to: - - - Maintain a moving (discounted) average of the square of gradients - - Divide the gradient by the root of this average + r"""Optimizer that implements the RMSprop algorithm. - .. math:: + RMSprop [5]_ and Adadelta have both been developed independently around the same time + stemming from the need to resolve Adagrad's radically diminishing learning rates. - \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\ - p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split} + The gist of RMSprop is to: - The centered version additionally maintains a moving average of the gradients, - and uses that average to estimate the variance. + - Maintain a moving (discounted) average of the square of gradients + - Divide the gradient by the root of this average - Parameters:: - - lr: float, Scheduler - learning rate. - - References:: - - .. [5] Tieleman, T. and Hinton, G. (2012): - Neural Networks for Machine Learning, Lecture 6.5 - rmsprop. - Coursera. http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20) - """ - - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - weight_decay: Optional[float] = None, - epsilon: float = 1e-6, - rho: float = 0.9, - name: Optional[str] = None - ): - super(RMSProp, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) + .. 
math:: - self.epsilon = epsilon - self.rho = rho - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - caches = dict((key + '_cache', bm.Variable(bm.zeros_like(x))) - for key, x in train_vars.items()) - self.register_implicit_vars(caches) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - for k, p in self.vars_to_train.items(): - g = grads[k] - c = self.implicit_vars[k + '_cache'] - c.value = self.rho * c.value + (1 - self.rho) * g ** 2 - update = (lr * g / jnp.sqrt(c.value + self.epsilon)) - if self.weight_decay is None: - p.value -= update - else: - p.value = (1 - self.weight_decay) * p - update - self.lr.step_call() - - def __repr__(self): - return (f"{self.__class__.__name__}(lr={self.lr}, " - f"epsilon={self.epsilon}, rho={self.rho})") + \begin{split}c_t &= \rho c_{t-1} + (1-\rho)*g^2\\ + p_t &= \frac{\eta}{\sqrt{c_t + \epsilon}} * g \end{split} + + The centered version additionally maintains a moving average of the gradients, + and uses that average to estimate the variance. + + Parameters:: + + lr: float, Scheduler + learning rate. + + References:: + + .. [5] Tieleman, T. and Hinton, G. (2012): + Neural Networks for Machine Learning, Lecture 6.5 - rmsprop. + Coursera. 
http://www.youtube.com/watch?v=O3sxAc4hxZU (formula @5:20) + """ + + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + weight_decay: Optional[float] = None, + epsilon: float = 1e-6, + rho: float = 0.9, + name: Optional[str] = None + ): + super(RMSProp, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) + + self.epsilon = epsilon + self.rho = rho + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + caches = dict((key + '_cache', bm.Variable(bm.zeros_like(x))) + for key, x in train_vars.items()) + self.register_implicit_vars(caches) + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + for k, p in self.vars_to_train.items(): + g = grads[k] + c = self.implicit_vars[k + '_cache'] + c.value = self.rho * c.value + (1 - self.rho) * g ** 2 + update = (lr * g / jnp.sqrt(c.value + self.epsilon)) + if self.weight_decay is None: + p.value -= update + else: + p.value = (1 - self.weight_decay) * p - update + self.lr.step_call() + + def __repr__(self): + return (f"{self.__class__.__name__}(lr={self.lr}, " + f"epsilon={self.epsilon}, rho={self.rho})") class Adam(CommonOpt): - """Optimizer that implements the Adam algorithm. - - Adam [6]_ - a stochastic gradient descent method (SGD) that computes - individual adaptive learning rates for different parameters from estimates of - first- and second-order moments of the gradients. - - Parameters:: + """Optimizer that implements the Adam algorithm. + + Adam [6]_ - a stochastic gradient descent method (SGD) that computes + individual adaptive learning rates for different parameters from estimates of + first- and second-order moments of the gradients. 
+ + Parameters:: + + lr: float, Scheduler + learning rate. + beta1: optional, float + A positive scalar value for beta_1, the exponential decay rate + for the first moment estimates (default 0.9). + beta2: optional, float + A positive scalar value for beta_2, the exponential decay rate + for the second moment estimates (default 0.999). + eps: optional, float + A positive scalar value for epsilon, a small constant for + numerical stability (default 1e-8). + name : optional, str + The optimizer name. + + References:: + + .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980. + """ + + def __init__( + self, + lr: Union[float, bm.Variable, Scheduler], + train_vars: Dict[str, bm.Variable] = None, + beta1: float = 0.9, + beta2: float = 0.999, + eps: float = 1e-8, + weight_decay: Optional[float] = None, + name: Optional[str] = None + ): + super(Adam, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) - lr: float, Scheduler - learning rate. - beta1: optional, float - A positive scalar value for beta_1, the exponential decay rate - for the first moment estimates (default 0.9). - beta2: optional, float - A positive scalar value for beta_2, the exponential decay rate - for the second moment estimates (default 0.999). - eps: optional, float - A positive scalar value for epsilon, a small constant for - numerical stability (default 1e-8). - name : optional, str - The optimizer name. - - References:: - - .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980. 
- """ - - def __init__( - self, - lr: Union[float, bm.Variable, Scheduler], - train_vars: Dict[str, bm.Variable] = None, - beta1: float = 0.9, - beta2: float = 0.999, - eps: float = 1e-8, - weight_decay: Optional[float] = None, - name: Optional[str] = None - ): - super(Adam, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) - - self.beta1 = beta1 - self.beta2 = beta2 - self.eps = eps - - def __repr__(self): - return (f"{self.__class__.__name__}(lr={str(self.lr)}, " - f"beta1={self.beta1}, beta2={self.beta2}, eps={self.eps})") - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - ms = dict((k + '_m', bm.Variable(bm.zeros_like(x))) - for k, x in train_vars.items()) - self.register_implicit_vars(ms) - vs = dict((k + '_v', bm.Variable(bm.zeros_like(x))) - for k, x in train_vars.items()) - self.register_implicit_vars(vs) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - lr /= (1 - self.beta1 ** (self.lr.last_epoch.value + 2)) - lr *= jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2)) - for key, p in self.vars_to_train.items(): - m = self.implicit_vars[key + '_m'] - v = self.implicit_vars[key + '_v'] - g = grads[key] - # First moment estimate. - m.value = self.beta1 * m.value + (1 - self.beta1) * g - # Second moment estimate. - v.value = self.beta2 * v.value + (1 - self.beta2) * g ** 2 - # Bias correction. 
- update = lr * m.value / (jnp.sqrt(v.value) + self.eps) - if self.weight_decay is None: - p.value -= update - else: - p.value = (1 - self.weight_decay) * p - update - self.lr.step_call() + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + + def __repr__(self): + return (f"{self.__class__.__name__}(lr={str(self.lr)}, " + f"beta1={self.beta1}, beta2={self.beta2}, eps={self.eps})") + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + ms = dict((k + '_m', bm.Variable(bm.zeros_like(x))) + for k, x in train_vars.items()) + self.register_implicit_vars(ms) + vs = dict((k + '_v', bm.Variable(bm.zeros_like(x))) + for k, x in train_vars.items()) + self.register_implicit_vars(vs) + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + lr /= (1 - self.beta1 ** (self.lr.last_epoch.value + 2)) + lr *= jnp.sqrt(1 - self.beta2 ** (self.lr.last_epoch.value + 2)) + for key, p in self.vars_to_train.items(): + m = self.implicit_vars[key + '_m'] + v = self.implicit_vars[key + '_v'] + g = grads[key] + # First moment estimate. + m.value = self.beta1 * m.value + (1 - self.beta1) * g + # Second moment estimate. + v.value = self.beta2 * v.value + (1 - self.beta2) * g ** 2 + # Bias correction. + update = lr * m.value / (jnp.sqrt(v.value) + self.eps) + if self.weight_decay is None: + p.value -= update + else: + p.value = (1 - self.weight_decay) * p - update + self.lr.step_call() class LARS(CommonOpt): - r"""Layer-wise adaptive rate scaling (LARS) optimizer [1]_. + r"""Layer-wise adaptive rate scaling (LARS) optimizer [1]_. - Layer-wise Adaptive Rate Scaling, or LARS, is a large batch - optimization technique. 
There are two notable differences - between LARS and other adaptive algorithms such as `Adam` or `RMSProp`: - first, LARS uses a separate learning rate for each layer and not for - each weight. And second, the magnitude of the update is controlled - with respect to the weight norm for better control of training speed. + Layer-wise Adaptive Rate Scaling, or LARS, is a large batch + optimization technique. There are two notable differences + between LARS and other adaptive algorithms such as `Adam` or `RMSProp`: + first, LARS uses a separate learning rate for each layer and not for + each weight. And second, the magnitude of the update is controlled + with respect to the weight norm for better control of training speed. - .. math:: - - m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\ - x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)} + .. math:: - Parameters:: + m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\ + x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)} + + Parameters:: + + lr: float, Scheduler + learning rate. + momentum: float + coefficient used for the moving average of the gradient. + weight_decay: float + weight decay coefficient. + tc: float + trust coefficient eta ( < 1) for trust ratio computation. + eps: float + epsilon used for trust ratio computation. + + References:: + + .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag. 
+ """ + + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + momentum: float = 0.9, + weight_decay: float = 1e-4, + tc: float = 1e-3, + eps: float = 1e-5, + name: Optional[str] = None + ): + super(LARS, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) - lr: float, Scheduler - learning rate. - momentum: float - coefficient used for the moving average of the gradient. - weight_decay: float - weight decay coefficient. - tc: float - trust coefficient eta ( < 1) for trust ratio computation. - eps: float - epsilon used for trust ratio computation. - - References:: - - .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag. - """ - - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - momentum: float = 0.9, - weight_decay: float = 1e-4, - tc: float = 1e-3, - eps: float = 1e-5, - name: Optional[str] = None - ): - super(LARS, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) - - self.momentum = momentum - self.tc = tc - self.eps = eps - - def __repr__(self): - return (f"{self.__class__.__name__}(lr={self.lr}, " - f"momentum={self.momentum}, weight_decay={self.weight_decay}, " - f"tc={self.tc}, eps={self.eps})") - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - self.register_implicit_vars({k + '_m': bm.Variable(bm.zeros_like(x)) - for k, x in train_vars.items()}) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - for k, p in self.vars_to_train.items(): - g = bm.as_jax(grads[k]) - m = self.implicit_vars[k + 
'_m'] - p_norm = jnp.linalg.norm(p.value) - g_norm = jnp.linalg.norm(g) - trust_ratio = self.tc * p_norm / (g_norm + self.weight_decay * p_norm + self.eps) - local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio) - m.value = self.momentum * m.value + local_lr * (g + self.weight_decay * p.value) - p.value -= m.value - self.lr.step_call() + self.momentum = momentum + self.tc = tc + self.eps = eps + + def __repr__(self): + return (f"{self.__class__.__name__}(lr={self.lr}, " + f"momentum={self.momentum}, weight_decay={self.weight_decay}, " + f"tc={self.tc}, eps={self.eps})") + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + self.register_implicit_vars({k + '_m': bm.Variable(bm.zeros_like(x)) + for k, x in train_vars.items()}) + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + for k, p in self.vars_to_train.items(): + g = bm.as_jax(grads[k]) + m = self.implicit_vars[k + '_m'] + p_norm = jnp.linalg.norm(p.value) + g_norm = jnp.linalg.norm(g) + trust_ratio = self.tc * p_norm / (g_norm + self.weight_decay * p_norm + self.eps) + local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio) + m.value = self.momentum * m.value + local_lr * (g + self.weight_decay * p.value) + p.value -= m.value + self.lr.step_call() class Adan(CommonOpt): - r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_. - - .. 
math:: - - \begin{equation} - \begin{aligned} - & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\ - & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\ - & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\ - & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\ - & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\ - \end{aligned} - \end{equation} - - Parameters:: - - lr: float, Scheduler - learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3) - betas : tuple - Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01)) - eps : float - The term added to the denominator to improve numerical stability. (default: 1e-8) - weight_decay : float - decoupled weight decay (L2 penalty) (default: 0) - no_prox: bool - how to perform the decoupled weight decay (default: False). - It determines the update rule of parameters with weight decay. - By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper: + r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_. .. 
math:: - \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}k)\right], - But one also can update the parameter like Adamw: + \begin{equation} + \begin{aligned} + & \mathbf{m}_k=\left(1-\beta_1\right) \mathbf{m}_{k-1}+\beta_1 \mathbf{g}_k \\ + & \mathbf{v}_k=\left(1-\beta_2\right) \mathbf{v}_{k-1}+\beta_2\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right) \\ + & \mathbf{n}_k=\left(1-\beta_3\right) \mathbf{n}_{k-1}+\beta_3\left[\mathbf{g}_k+\left(1-\beta_2\right)\left(\mathbf{g}_k-\mathbf{g}_{k-1}\right)\right]^2 \\ + & \boldsymbol{\eta}_k=\eta /\left(\sqrt{\mathbf{n}_k+\varepsilon}\right) \\ + & \boldsymbol{\theta}_{k+1}=\left(1+\lambda_k \eta\right)^{-1}\left[\boldsymbol{\theta}_k-\boldsymbol{\eta}_k \circ\left(\mathbf{m}_k+\left(1-\beta_2\right) \mathbf{v}_k\right)\right] \\ + \end{aligned} + \end{equation} + + Parameters:: + + lr: float, Scheduler + learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3) + betas : tuple + Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01)) + eps : float + The term added to the denominator to improve numerical stability. (default: 1e-8) + weight_decay : float + decoupled weight decay (L2 penalty) (default: 0) + no_prox: bool + how to perform the decoupled weight decay (default: False). + It determines the update rule of parameters with weight decay. + By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper: + + .. math:: + \boldsymbol{\theta}_{k+1} = ( 1+\lambda \eta)^{-1}\left[\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}k)\right], + + But one also can update the parameter like Adamw: + + .. math:: + \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k). + + References:: + + .. 
[1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan. + “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing + Deep Models.” ArXiv abs/2208.06677 (2022): n. pag. + """ + + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable] = 1e-3, + train_vars: Dict[str, bm.Variable] = None, + betas: Tuple[float, float, float] = (0.02, 0.08, 0.01), + eps: float = 1e-8, + weight_decay: float = 0.02, + no_prox: bool = False, + name: Optional[str] = None, + ): + super(Adan, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) - .. math:: - \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k). - - References:: - - .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan. - “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing - Deep Models.” ArXiv abs/2208.06677 (2022): n. pag. - """ - - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable] = 1e-3, - train_vars: Dict[str, bm.Variable] = None, - betas: Tuple[float, float, float] = (0.02, 0.08, 0.01), - eps: float = 1e-8, - weight_decay: float = 0.02, - no_prox: bool = False, - name: Optional[str] = None, - ): - super(Adan, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) - - assert len(betas) == 3 - if eps < 0.: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0.0 <= betas[2] < 1.0: - raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) - - self.betas = betas - self.eps = eps - self.weight_decay = weight_decay - self.no_prox = no_prox - - def __repr__(self): - return (f"{self.__class__.__name__}(lr={self.lr}, " - f"betas={self.betas}, " - 
f"weight_decay={self.weight_decay}, " - f"no_prox={self.no_prox}, " - f"eps={self.eps}") - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - # Exponential moving average of gradient values - exp_avg = {k + '_m': bm.Variable(bm.zeros_like(x)) for k, x in train_vars.items()} - # Exponential moving average of squared gradient values - exp_avg_sq = {k + '_v': bm.Variable(bm.zeros_like(x)) for k, x in train_vars.items()} - # Exponential moving average of gradient difference - exp_avg_diff = {k + '_n': bm.Variable(bm.zeros_like(x)) for k, x in train_vars.items()} - # previous gradient - pre_grad = {k + '_prev_grad': bm.Variable(bm.zeros_like(x)) for k, x in train_vars.items()} - self.register_implicit_vars(exp_avg, exp_avg_sq, exp_avg_diff, pre_grad) - - def _update_moments(self, m, n, v, pre_g, g): - m = m * (1 - self.betas[0]) + self.betas[0] * g - gd = g - pre_g - v = v * (1 - self.betas[1]) + self.betas[1] * gd - n = n * (1 - self.betas[2]) + self.betas[2] * (g + (1 - self.betas[1]) * gd) ** 2 - return m, n, v - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - step = self.lr.last_epoch.value + 1 - correct_m = 1 / (1 - (1 - self.betas[0]) ** (step + 1)) - correct_v = 1 / (1 - (1 - self.betas[1]) ** (step + 1)) - correct_n = 1 / (1 - (1 - self.betas[2]) ** (step + 1)) - for key, p_var in self.vars_to_train.items(): - m_var = self.implicit_vars[key + '_m'] - n_var = self.implicit_vars[key + '_n'] - v_var = self.implicit_vars[key + '_v'] - prev_g_var = self.implicit_vars[key + '_prev_grad'] - g = grads[key] - pre_g = cond(step == 0, lambda pg, g: g, lambda pg, g: pg, (prev_g_var.value, g)) - diff = g - pre_g - m = m_var.value * (1 - self.betas[0]) + self.betas[0] * g - v = v_var.value * (1 - self.betas[1]) + 
self.betas[1] * diff - n = n_var.value * (1 - self.betas[2]) + self.betas[2] * (g + (1 - self.betas[1]) * diff) ** 2 - weighted_step_size = lr / (jnp.sqrt(n * correct_n) + self.eps) - if self.no_prox: - p = (p_var.value * (1 - self.weight_decay * lr) - - weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v)) - else: - p = ( - (p_var.value - weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v)) - / (1 + self.weight_decay * lr) - ) - m_var.value = m - n_var.value = n - v_var.value = v - prev_g_var.value = g - p_var.value = p - self.lr.step_call() + assert len(betas) == 3 + if eps < 0.: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + + self.betas = betas + self.eps = eps + self.weight_decay = weight_decay + self.no_prox = no_prox + + def __repr__(self): + return (f"{self.__class__.__name__}(lr={self.lr}, " + f"betas={self.betas}, " + f"weight_decay={self.weight_decay}, " + f"no_prox={self.no_prox}, " + f"eps={self.eps}") + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + # Exponential moving average of gradient values + exp_avg = {k + '_m': bm.Variable(bm.zeros_like(x)) for k, x in train_vars.items()} + # Exponential moving average of squared gradient values + exp_avg_sq = {k + '_v': bm.Variable(bm.zeros_like(x)) for k, x in train_vars.items()} + # Exponential moving average of gradient difference + exp_avg_diff = {k + '_n': bm.Variable(bm.zeros_like(x)) 
for k, x in train_vars.items()} + # previous gradient + pre_grad = {k + '_prev_grad': bm.Variable(bm.zeros_like(x)) for k, x in train_vars.items()} + self.register_implicit_vars(exp_avg, exp_avg_sq, exp_avg_diff, pre_grad) + + def _update_moments(self, m, n, v, pre_g, g): + m = m * (1 - self.betas[0]) + self.betas[0] * g + gd = g - pre_g + v = v * (1 - self.betas[1]) + self.betas[1] * gd + n = n * (1 - self.betas[2]) + self.betas[2] * (g + (1 - self.betas[1]) * gd) ** 2 + return m, n, v + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + step = self.lr.last_epoch.value + 1 + correct_m = 1 / (1 - (1 - self.betas[0]) ** (step + 1)) + correct_v = 1 / (1 - (1 - self.betas[1]) ** (step + 1)) + correct_n = 1 / (1 - (1 - self.betas[2]) ** (step + 1)) + for key, p_var in self.vars_to_train.items(): + m_var = self.implicit_vars[key + '_m'] + n_var = self.implicit_vars[key + '_n'] + v_var = self.implicit_vars[key + '_v'] + prev_g_var = self.implicit_vars[key + '_prev_grad'] + g = grads[key] + pre_g = cond(step == 0, lambda pg, g: g, lambda pg, g: pg, (prev_g_var.value, g)) + diff = g - pre_g + m = m_var.value * (1 - self.betas[0]) + self.betas[0] * g + v = v_var.value * (1 - self.betas[1]) + self.betas[1] * diff + n = n_var.value * (1 - self.betas[2]) + self.betas[2] * (g + (1 - self.betas[1]) * diff) ** 2 + weighted_step_size = lr / (jnp.sqrt(n * correct_n) + self.eps) + if self.no_prox: + p = (p_var.value * (1 - self.weight_decay * lr) - + weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v)) + else: + p = ( + (p_var.value - weighted_step_size * (m * correct_m + (1 - self.betas[1]) * v * correct_v)) + / (1 + self.weight_decay * lr) + ) + m_var.value = m + n_var.value = n + v_var.value = v + prev_g_var.value = g + p_var.value = p + self.lr.step_call() class AdamW(CommonOpt): - r"""Adam with weight decay regularization [1]_. 
- - AdamW uses weight decay to regularize learning towards small weights, as - this leads to better generalization. In SGD you can also use L2 regularization - to implement this as an additive loss term, however L2 regularization - does not behave as intended for adaptive gradient algorithms such as Adam. - - .. math:: - - \begin{aligned} - &\rule{110mm}{0.4pt} \\ - &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 - \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, - \: \epsilon \text{ (epsilon)} \\ - &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, - \: \textit{maximize} \\ - &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 - \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] - &\rule{110mm}{0.4pt} \\ - &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ - - &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ - &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ - &\hspace{5mm}\textbf{else} \\ - &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ - &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ - &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ - &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ - &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ - &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ - &\hspace{5mm}\textbf{if} \: amsgrad \\ - &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, - \widehat{v_t}) \\ - &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ - \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ - &\hspace{5mm}\textbf{else} \\ - &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ - \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ - &\rule{110mm}{0.4pt} \\[-1.ex] - &\bf{return} \: \theta_t \\[-1.ex] - &\rule{110mm}{0.4pt} \\[-1.ex] - 
\end{aligned} - - - Parameters:: - - lr: float, Scheduler - learning rate. - beta1: optional, float - A positive scalar value for beta_1, the exponential decay rate - for the first moment estimates. Generally close to 1. - beta2: optional, float - A positive scalar value for beta_2, the exponential decay rate - for the second moment estimates. Generally close to 1. - eps: optional, float - A positive scalar value for epsilon, a small constant for - numerical stability. - weight_decay: float - Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. - amsgrad: bool - whether to use the AMSGrad variant of this algorithm - from the paper `On the Convergence of Adam and Beyond`. - name : optional, str - The optimizer name. - - References:: - - .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019). - - """ - - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - beta1: float = 0.9, - beta2: float = 0.999, - eps: float = 1e-8, - weight_decay: float = 1e-2, - amsgrad: bool = False, - name: Optional[str] = None, - ): - self.amsgrad = amsgrad - super(AdamW, self).__init__(lr=lr, - train_vars=train_vars, - weight_decay=weight_decay, - name=name) - - if eps < 0.: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= beta1 < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(beta1)) - if not 0.0 <= beta2 < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(beta2)) - if weight_decay < 0.: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - - self.beta1 = beta1 - self.beta2 = beta2 - self.eps = eps - self.weight_decay = weight_decay - - def __repr__(self): - return (f"{self.__class__.__name__}(lr={self.lr}, " - f"beta1={self.beta1}, " - f"beta2={self.beta2}, " - 
f"weight_decay={self.weight_decay}, " - f"eps={self.eps}, " - f"amsgrad={self.amsgrad})") - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - # Exponential moving average of gradient values - ms = dict((k + '_m', bm.Variable(bm.zeros_like(x))) - for k, x in train_vars.items()) - # Exponential moving average of squared gradient values - vs = dict((k + '_v', bm.Variable(bm.zeros_like(x))) - for k, x in train_vars.items()) - self.register_implicit_vars(ms, vs) - # Maintains max of all exp. moving avg. of sq. grad. values - if self.amsgrad: - gs = {k + '_vmax': bm.Variable(bm.zeros_like(x)) - for k, x in train_vars.items()} - self.register_implicit_vars(gs) - - def update(self, grads: dict): - self.check_grads(grads) - lr_old = self.lr() - step = self.lr.last_epoch.value + 2 - bias_correction1 = 1 - self.beta1 ** step - bias_correction2 = 1 - self.beta2 ** step - lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1 - for key, p in self.vars_to_train.items(): - m = self.implicit_vars[key + '_m'] - v = self.implicit_vars[key + '_v'] - g = grads[key] - if self.weight_decay != 0: - p *= (1 - lr_old * self.weight_decay) - # First moment estimate. - m.value = self.beta1 * m.value + (1 - self.beta1) * g - # Second moment estimate. - v.value = self.beta2 * v.value + (1 - self.beta2) * g ** 2 - if self.amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - vmax = self.implicit_vars[key + '_vmax'] - vmax.value = jnp.maximum(vmax.value, v) - # Use the max. for normalizing running avg. of gradient - denom = jnp.sqrt(vmax) + self.eps - else: - denom = jnp.sqrt(v.value) + self.eps - # Bias correction. - p.value -= lr * m / denom - self.lr.step_call() + r"""Adam with weight decay regularization [1]_. 
+ + AdamW uses weight decay to regularize learning towards small weights, as + this leads to better generalization. In SGD you can also use L2 regularization + to implement this as an additive loss term, however L2 regularization + does not behave as intended for adaptive gradient algorithms such as Adam. + + .. math:: + + \begin{aligned} + &\rule{110mm}{0.4pt} \\ + &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 + \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, + \: \epsilon \text{ (epsilon)} \\ + &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, + \: \textit{maximize} \\ + &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 + \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] + &\rule{110mm}{0.4pt} \\ + &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ + + &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ + &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ + &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ + &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ + &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ + &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ + &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ + &\hspace{5mm}\textbf{if} \: amsgrad \\ + &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, + \widehat{v_t}) \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ + &\hspace{5mm}\textbf{else} \\ + &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ + \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ + &\rule{110mm}{0.4pt} \\[-1.ex] + &\bf{return} \: \theta_t \\[-1.ex] + &\rule{110mm}{0.4pt} \\[-1.ex] + 
\end{aligned} + + + Parameters:: + + lr: float, Scheduler + learning rate. + beta1: optional, float + A positive scalar value for beta_1, the exponential decay rate + for the first moment estimates. Generally close to 1. + beta2: optional, float + A positive scalar value for beta_2, the exponential decay rate + for the second moment estimates. Generally close to 1. + eps: optional, float + A positive scalar value for epsilon, a small constant for + numerical stability. + weight_decay: float + Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. + amsgrad: bool + whether to use the AMSGrad variant of this algorithm + from the paper `On the Convergence of Adam and Beyond`. + name : optional, str + The optimizer name. + + References:: + + .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019). + + """ + + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + beta1: float = 0.9, + beta2: float = 0.999, + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + name: Optional[str] = None, + ): + self.amsgrad = amsgrad + super(AdamW, self).__init__(lr=lr, + train_vars=train_vars, + weight_decay=weight_decay, + name=name) + + if eps < 0.: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= beta1 < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(beta1)) + if not 0.0 <= beta2 < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(beta2)) + if weight_decay < 0.: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + self.weight_decay = weight_decay + + def __repr__(self): + return (f"{self.__class__.__name__}(lr={self.lr}, " + f"beta1={self.beta1}, " + f"beta2={self.beta2}, " + 
f"weight_decay={self.weight_decay}, " + f"eps={self.eps}, " + f"amsgrad={self.amsgrad})") + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + # Exponential moving average of gradient values + ms = dict((k + '_m', bm.Variable(bm.zeros_like(x))) + for k, x in train_vars.items()) + # Exponential moving average of squared gradient values + vs = dict((k + '_v', bm.Variable(bm.zeros_like(x))) + for k, x in train_vars.items()) + self.register_implicit_vars(ms, vs) + # Maintains max of all exp. moving avg. of sq. grad. values + if self.amsgrad: + gs = {k + '_vmax': bm.Variable(bm.zeros_like(x)) + for k, x in train_vars.items()} + self.register_implicit_vars(gs) + + def update(self, grads: dict): + self.check_grads(grads) + lr_old = self.lr() + step = self.lr.last_epoch.value + 2 + bias_correction1 = 1 - self.beta1 ** step + bias_correction2 = 1 - self.beta2 ** step + lr = lr_old * jnp.sqrt(bias_correction2) / bias_correction1 + for key, p in self.vars_to_train.items(): + m = self.implicit_vars[key + '_m'] + v = self.implicit_vars[key + '_v'] + g = grads[key] + if self.weight_decay != 0: + p *= (1 - lr_old * self.weight_decay) + # First moment estimate. + m.value = self.beta1 * m.value + (1 - self.beta1) * g + # Second moment estimate. + v.value = self.beta2 * v.value + (1 - self.beta2) * g ** 2 + if self.amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + vmax = self.implicit_vars[key + '_vmax'] + vmax.value = jnp.maximum(vmax.value, v) + # Use the max. for normalizing running avg. of gradient + denom = jnp.sqrt(vmax) + self.eps + else: + denom = jnp.sqrt(v.value) + self.eps + # Bias correction. + p.value -= lr * m / denom + self.lr.step_call() class SM3(CommonOpt): - """SM3 algorithm [1]_. 
- - The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method' - (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar - to Adam and Adagrad with greatly reduced memory usage for history tensors. - For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history - tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor - of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)` - memory for storage tensors, while the optimization using SM3 will use - `O(sum n_i)` memory. Despite storing fewer parameters, this optimization - algorithm manages to be comparably effective. - - This advantage drastically shrinks when `momentum > 0`. The momentum is - tracked using a tensor of the same shape as the tensor being optimized. With - momentum, SM3 will use just over half as much memory as Adam, and a bit more - than Adagrad. - - Parameters:: - - lr: float, Scheduler - learning rate. - momentum: float - coefficient used to scale prior updates - before adding. This drastically increases memory usage if - `momentum > 0.0`. (default: 0.0) - beta: float - coefficient used for exponential moving averages (default: 0.0) - eps: float - Term added to square-root in denominator to - improve numerical stability (default: 1e-30). - - References:: - - .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019). 
- - """ - - def __init__( - self, - lr: Union[float, Scheduler, bm.Variable], - train_vars: Dict[str, bm.Variable] = None, - beta: float = 0., - momentum: float = 0., - eps: float = 1e-30, - weight_decay: Optional[float] = None, - name: Optional[str] = None, - ): - super(SM3, self).__init__(lr=lr, - weight_decay=weight_decay, - train_vars=train_vars, - name=name) - - if not 0.0 <= momentum < 1.0: - raise ValueError("Invalid momentum: {0}".format(momentum)) - if not 0.0 <= beta < 1.0: - raise ValueError("Invalid beta: {0}".format(beta)) - if not 0.0 <= eps: - raise ValueError("Invalid eps: {0}".format(eps)) - - self.eps = eps - self.beta = beta - self.momentum = momentum - - def __repr__(self): - return (f"{self.__class__.__name__}(lr={self.lr}, " - f"beta={self.beta}, eps={self.eps}, momentum={self.momentum})") - - def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): - train_vars = dict() if train_vars is None else train_vars - if not isinstance(train_vars, dict): - raise MathError('"train_vars" must be a dict of Variable.') - self.vars_to_train.update(train_vars) - vs = dict() - for k, v in train_vars.items(): - rank, ndim = v.shape, v.ndim - for i in range(ndim): - shape = [1] * ndim - shape[i] = rank[i] - vs[f'{k}_m{i}'] = bm.Variable(bm.zeros(shape, dtype=v.dtype)) - self.register_implicit_vars(vs) - if self.momentum > 0.: - ms = {k + '_mbuffer': bm.Variable(bm.zeros_like(v)) - for k, v in train_vars.items()} - self.register_implicit_vars(ms) - - def update(self, grads: dict): - self.check_grads(grads) - lr = self.lr() - - for k, p in self.vars_to_train.items(): - g = grads[k] - ndim = p.ndim - update = self.implicit_vars[f'{k}_m0'] - for i in range(1, ndim): - update = bm.minimum(update, self.implicit_vars[f'{k}_m{i}']) - if self.beta > 0.: - update *= self.beta - update += g * g * (1 - self.beta) - # Computes max along all dimensions except the given dim. - # If tensor is a scalar, it returns tensor. 
- for i in range(ndim): - result = update - for j in range(ndim): - if i != j: - result = result.max(axis=j, keepdim=True) - acc = self.implicit_vars[f'{k}_m{i}'] - if self.beta > 0.: - acc.value = bm.maximum(acc, result) - else: - # No need to compare - nu_max is bigger because of grad ** 2 - acc.value = result - update = g / bm.sqrt(update + self.eps) - if self.momentum > 0.: - m_buffer = self.implicit_vars[f'{k}_mbuffer'] - update = update * (1. - self.momentum) + m_buffer * self.momentum - m_buffer.value = update - if self.weight_decay is None: - p -= lr * update - else: - p.value = (1 - self.weight_decay) * p - lr * update - self.lr.step_call() + """SM3 algorithm [1]_. + + The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method' + (SM3) algorithm is a memory-efficient adaptive optimization algorithm similar + to Adam and Adagrad with greatly reduced memory usage for history tensors. + For an `n x m` matrix, Adam and Adagrad use `O(nm)` memory for history + tensors, while SM3 uses `O(n+m)` due to the chosen cover. In general, a tensor + of shape `(n_1, n_2, ..., n_k)` optimized using Adam will use `O(prod n_i)` + memory for storage tensors, while the optimization using SM3 will use + `O(sum n_i)` memory. Despite storing fewer parameters, this optimization + algorithm manages to be comparably effective. + + This advantage drastically shrinks when `momentum > 0`. The momentum is + tracked using a tensor of the same shape as the tensor being optimized. With + momentum, SM3 will use just over half as much memory as Adam, and a bit more + than Adagrad. + + Parameters:: + + lr: float, Scheduler + learning rate. + momentum: float + coefficient used to scale prior updates + before adding. This drastically increases memory usage if + `momentum > 0.0`. (default: 0.0) + beta: float + coefficient used for exponential moving averages (default: 0.0) + eps: float + Term added to square-root in denominator to + improve numerical stability (default: 1e-30). 
+ + References:: + + .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019). + + """ + + def __init__( + self, + lr: Union[float, Scheduler, bm.Variable], + train_vars: Dict[str, bm.Variable] = None, + beta: float = 0., + momentum: float = 0., + eps: float = 1e-30, + weight_decay: Optional[float] = None, + name: Optional[str] = None, + ): + super(SM3, self).__init__(lr=lr, + weight_decay=weight_decay, + train_vars=train_vars, + name=name) + + if not 0.0 <= momentum < 1.0: + raise ValueError("Invalid momentum: {0}".format(momentum)) + if not 0.0 <= beta < 1.0: + raise ValueError("Invalid beta: {0}".format(beta)) + if not 0.0 <= eps: + raise ValueError("Invalid eps: {0}".format(eps)) + + self.eps = eps + self.beta = beta + self.momentum = momentum + + def __repr__(self): + return (f"{self.__class__.__name__}(lr={self.lr}, " + f"beta={self.beta}, eps={self.eps}, momentum={self.momentum})") + + def register_train_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None): + train_vars = dict() if train_vars is None else train_vars + if not isinstance(train_vars, dict): + raise MathError('"train_vars" must be a dict of Variable.') + self.vars_to_train.update(train_vars) + vs = dict() + for k, v in train_vars.items(): + rank, ndim = v.shape, v.ndim + for i in range(ndim): + shape = [1] * ndim + shape[i] = rank[i] + vs[f'{k}_m{i}'] = bm.Variable(bm.zeros(shape, dtype=v.dtype)) + self.register_implicit_vars(vs) + if self.momentum > 0.: + ms = {k + '_mbuffer': bm.Variable(bm.zeros_like(v)) + for k, v in train_vars.items()} + self.register_implicit_vars(ms) + + def update(self, grads: dict): + self.check_grads(grads) + lr = self.lr() + + for k, p in self.vars_to_train.items(): + g = grads[k] + ndim = p.ndim + update = self.implicit_vars[f'{k}_m0'] + for i in range(1, ndim): + update = bm.minimum(update, self.implicit_vars[f'{k}_m{i}']) + if self.beta > 0.: + update *= 
self.beta + update += g * g * (1 - self.beta) + # Computes max along all dimensions except the given dim. + # If tensor is a scalar, it returns tensor. + for i in range(ndim): + result = update + for j in range(ndim): + if i != j: + result = result.max(axis=j, keepdim=True) + acc = self.implicit_vars[f'{k}_m{i}'] + if self.beta > 0.: + acc.value = bm.maximum(acc, result) + else: + # No need to compare - nu_max is bigger because of grad ** 2 + acc.value = result + update = g / bm.sqrt(update + self.eps) + if self.momentum > 0.: + m_buffer = self.implicit_vars[f'{k}_mbuffer'] + update = update * (1. - self.momentum) + m_buffer * self.momentum + m_buffer.value = update + if self.weight_decay is None: + p -= lr * update + else: + p.value = (1 - self.weight_decay) * p - lr * update + self.lr.step_call() diff --git a/brainpy/_src/optimizers/scheduler.py b/brainpy/_src/optimizers/scheduler.py index 0126127e8..43c92375d 100644 --- a/brainpy/_src/optimizers/scheduler.py +++ b/brainpy/_src/optimizers/scheduler.py @@ -2,11 +2,11 @@ import warnings from typing import Sequence, Union +import brainstate import jax import jax.numpy as jnp import brainpy.math as bm -import brainstate from brainpy import check from brainpy._src.math.object_transform.base import BrainPyObject from brainpy.errors import MathError diff --git a/brainpy/_src/optimizers/tests/test_ModifyLr.py b/brainpy/_src/optimizers/tests/test_ModifyLr.py index cc8eb5101..c73b6fcd2 100644 --- a/brainpy/_src/optimizers/tests/test_ModifyLr.py +++ b/brainpy/_src/optimizers/tests/test_ModifyLr.py @@ -11,69 +11,65 @@ @bm.jit def build_inputs_and_targets(mean=0.025, scale=0.01): - sample = bm.random.normal(size=(num_batch, 1, 1)) - bias = mean * 2.0 * (sample - 0.5) - samples = bm.random.normal(size=(num_batch, num_step, 1)) - noise_t = scale / dt ** 0.5 * samples - inputs = bias + noise_t - targets = bm.cumsum(inputs, axis=1) - return inputs, targets + sample = bm.random.normal(size=(num_batch, 1, 1)) + bias = mean * 2.0 * 
(sample - 0.5) + samples = bm.random.normal(size=(num_batch, num_step, 1)) + noise_t = scale / dt ** 0.5 * samples + inputs = bias + noise_t + targets = bm.cumsum(inputs, axis=1) + return inputs, targets def train_data(): - for _ in range(100): - yield build_inputs_and_targets() + for _ in range(100): + yield build_inputs_and_targets() class RNN(bp.DynamicalSystem): - def __init__(self, num_in, num_hidden): - super(RNN, self).__init__() - self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) - self.out = bp.dnn.Dense(num_hidden, 1) + def __init__(self, num_in, num_hidden): + super(RNN, self).__init__() + self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) + self.out = bp.dnn.Dense(num_hidden, 1) - def update(self, x): - return self.out(self.rnn(x)) + def update(self, x): + return self.out(self.rnn(x)) - def loss(self, predictions, targets, l2_reg=2e-4): - mse = bp.losses.mean_squared_error(predictions, targets) - l2 = l2_reg * bp.losses.l2_norm(self.train_vars().unique().dict()) ** 2 - return mse + l2 + def loss(self, predictions, targets, l2_reg=2e-4): + mse = bp.losses.mean_squared_error(predictions, targets) + l2 = l2_reg * bp.losses.l2_norm(self.train_vars().unique().dict()) ** 2 + return mse + l2 class test_ModifyLr(parameterized.TestCase): - @parameterized.product( - LearningRate=[ - bp.optim.ExponentialDecayLR(lr=bm.Variable(bm.as_jax(0.025)), decay_steps=1, decay_rate=0.99975), - bp.optim.InverseTimeDecayLR(lr=bm.Variable(bm.as_jax(0.025)), decay_steps=1, decay_rate=0.99975), - bp.optim.PolynomialDecayLR(lr=bm.Variable(bm.as_jax(0.1)), decay_steps=1, final_lr=0.025), - bp.optim.PiecewiseConstantLR(boundaries=(2, 2), values=(2, 2, 2)) - ] - ) - def test_NewScheduler(self, LearningRate): - with bm.training_environment(): - model = RNN(1, 100) + @parameterized.product( + LearningRate=[ + bp.optim.ExponentialDecayLR(lr=bm.Variable(bm.as_jax(0.025)), decay_steps=1, decay_rate=0.99975), + 
bp.optim.InverseTimeDecayLR(lr=bm.Variable(bm.as_jax(0.025)), decay_steps=1, decay_rate=0.99975), + bp.optim.PolynomialDecayLR(lr=bm.Variable(bm.as_jax(0.1)), decay_steps=1, final_lr=0.025), + bp.optim.PiecewiseConstantLR(boundaries=(2, 2), values=(2, 2, 2)) + ] + ) + def test_NewScheduler(self, LearningRate): + with bm.training_environment(): + model = RNN(1, 100) - opt = bp.optim.Adam(lr=LearningRate, eps=1e-1) - trainer = bp.BPTT(model, loss_fun=model.loss, optimizer=opt) + opt = bp.optim.Adam(lr=LearningRate, eps=1e-1) + trainer = bp.BPTT(model, loss_fun=model.loss, optimizer=opt) - + def test_modifylr(self): + with bm.training_environment(): + model = RNN(1, 100) - def test_modifylr(self): - with bm.training_environment(): - model = RNN(1, 100) + Scheduler_lr = bp.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975) - Scheduler_lr = bp.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975) + opt1 = bp.optim.Adam(lr=Scheduler_lr, eps=1e-1) + opt1.lr.lr = 0.01 + trainer1 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt1) - opt1 = bp.optim.Adam(lr=Scheduler_lr, eps=1e-1) - opt1.lr.lr = 0.01 - trainer1 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt1) - - - opt2 = bp.optim.SGD(lr=Scheduler_lr) - opt2.lr.set_value(0.01) - trainer2 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt2) - + opt2 = bp.optim.SGD(lr=Scheduler_lr) + opt2.lr.set_value(0.01) + trainer2 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt2) if __name__ == '__main__': - absltest.main() + absltest.main() diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py index adebbbd72..7740f12c4 100644 --- a/brainpy/_src/optimizers/tests/test_scheduler.py +++ b/brainpy/_src/optimizers/tests/test_scheduler.py @@ -16,93 +16,89 @@ class TestMultiStepLR(parameterized.TestCase): - @parameterized.product( - last_epoch=[-1, 0, 5, 10] - ) - def test2(self, last_epoch): - bm.random.seed() - scheduler1 = 
scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) - scheduler2 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) - - for i in range(1, 25): - lr1 = scheduler1(i + last_epoch) - lr2 = scheduler2() - scheduler2.step_epoch() - print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}') - self.assertTrue(lr1 == lr2) - + @parameterized.product( + last_epoch=[-1, 0, 5, 10] + ) + def test2(self, last_epoch): + bm.random.seed() + scheduler1 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) + scheduler2 = scheduler.MultiStepLR(0.1, [10, 20], gamma=0.1, last_epoch=last_epoch) + + for i in range(1, 25): + lr1 = scheduler1(i + last_epoch) + lr2 = scheduler2() + scheduler2.step_epoch() + print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}') + self.assertTrue(lr1 == lr2) class TestStepLR(parameterized.TestCase): - @parameterized.named_parameters( - {'testcase_name': f'last_epoch={last_epoch}', - 'last_epoch': last_epoch} - for last_epoch in [-1, 0, 5, 10] - ) - def test1(self, last_epoch): - bm.random.seed() - scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) - scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) - for i in range(1, 25): - lr1 = scheduler1(i + last_epoch) - lr2 = scheduler2() - scheduler2.step_epoch() - print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}') - self.assertTrue(lr1 == lr2) - + @parameterized.named_parameters( + {'testcase_name': f'last_epoch={last_epoch}', + 'last_epoch': last_epoch} + for last_epoch in [-1, 0, 5, 10] + ) + def test1(self, last_epoch): + bm.random.seed() + scheduler1 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) + scheduler2 = scheduler.StepLR(0.1, 10, gamma=0.1, last_epoch=last_epoch) + for i in range(1, 25): + lr1 = scheduler1(i + last_epoch) + lr2 = scheduler2() + scheduler2.step_epoch() + print(f'{scheduler2.last_epoch}, {lr1:.4f}, {lr2:.4f}') + self.assertTrue(lr1 == lr2) class 
TestCosineAnnealingLR(unittest.TestCase): - def test1(self): - bm.random.seed() - max_epoch = 50 - iters = 200 - sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1) - all_lr1 = [[], []] - all_lr2 = [[], []] - for epoch in range(max_epoch): - for batch in range(iters): - all_lr1[0].append(epoch + batch / iters) - all_lr1[1].append(sch()) - sch.step_epoch() - all_lr2[0].append(epoch) - all_lr2[1].append(sch()) - sch.step_epoch() - - if show: - plt.subplot(211) - plt.plot(jax.numpy.asarray(all_lr1[0]), jax.numpy.asarray(all_lr1[1])) - plt.subplot(212) - plt.plot(jax.numpy.asarray(all_lr2[0]), jax.numpy.asarray(all_lr2[1])) - plt.show() - plt.close() - + def test1(self): + bm.random.seed() + max_epoch = 50 + iters = 200 + sch = scheduler.CosineAnnealingLR(0.1, T_max=5, eta_min=0, last_epoch=-1) + all_lr1 = [[], []] + all_lr2 = [[], []] + for epoch in range(max_epoch): + for batch in range(iters): + all_lr1[0].append(epoch + batch / iters) + all_lr1[1].append(sch()) + sch.step_epoch() + all_lr2[0].append(epoch) + all_lr2[1].append(sch()) + sch.step_epoch() + + if show: + plt.subplot(211) + plt.plot(jax.numpy.asarray(all_lr1[0]), jax.numpy.asarray(all_lr1[1])) + plt.subplot(212) + plt.plot(jax.numpy.asarray(all_lr2[0]), jax.numpy.asarray(all_lr2[1])) + plt.show() + plt.close() class TestCosineAnnealingWarmRestarts(unittest.TestCase): - def test1(self): - bm.random.seed() - max_epoch = 50 - iters = 200 - sch = scheduler.CosineAnnealingWarmRestarts(0.1, - iters, - T_0=5, - T_mult=1, - last_call=-1) - all_lr1 = [] - all_lr2 = [] - for epoch in range(max_epoch): - for batch in range(iters): - all_lr1.append(sch()) - sch.step_call() - all_lr2.append(sch()) - sch.step_epoch() - - if show: - plt.subplot(211) - plt.plot(jax.numpy.asarray(all_lr1)) - plt.subplot(212) - plt.plot(jax.numpy.asarray(all_lr2)) - plt.show() - plt.close() - + def test1(self): + bm.random.seed() + max_epoch = 50 + iters = 200 + sch = scheduler.CosineAnnealingWarmRestarts(0.1, + 
iters, + T_0=5, + T_mult=1, + last_call=-1) + all_lr1 = [] + all_lr2 = [] + for epoch in range(max_epoch): + for batch in range(iters): + all_lr1.append(sch()) + sch.step_call() + all_lr2.append(sch()) + sch.step_epoch() + + if show: + plt.subplot(211) + plt.plot(jax.numpy.asarray(all_lr1)) + plt.subplot(212) + plt.plot(jax.numpy.asarray(all_lr2)) + plt.show() + plt.close() diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index 0f3a7ee65..5ffa2e4d6 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -1,17 +1,14 @@ # -*- coding: utf-8 -*- -import functools import inspect import time import warnings from collections.abc import Iterable from typing import Dict, Union, Sequence, Callable, Tuple, Optional, Any -import brainstate.compile import jax import jax.numpy as jnp import numpy as np -import tqdm.auto from jax.tree_util import tree_map, tree_flatten from brainpy import math as bm, tools @@ -24,7 +21,7 @@ from brainpy.types import Output, Monitor __all__ = [ - 'DSRunner', + 'DSRunner', ] SUPPORTED_INPUT_OPS = ['-', '+', '*', '/', '='] @@ -32,630 +29,628 @@ def _call_fun_with_share(f, *args, **kwargs): - try: - sha = share.get_shargs() - inspect.signature(f).bind(sha, *args, **kwargs) - warnings.warn(_input_deprecate_msg, UserWarning) - return f(sha, *args, **kwargs) - except TypeError: - return f(*args, **kwargs) + try: + sha = share.get_shargs() + inspect.signature(f).bind(sha, *args, **kwargs) + warnings.warn(_input_deprecate_msg, UserWarning) + return f(sha, *args, **kwargs) + except TypeError: + return f(*args, **kwargs) def _is_brainpy_array(x): - return isinstance(x, bm.BaseArray) + return isinstance(x, bm.BaseArray) def check_and_format_inputs(host, inputs): - """Check inputs and get the formatted inputs for the given population. - - Parameters:: - - host : DynamicalSystem - The host which contains all data. - inputs : tuple, list - The inputs of the population. 
- - Returns:: - - formatted_inputs : tuple, list - The formatted inputs of the population. - """ - - # 1. check inputs - # --------- - if inputs is None: - inputs = [] - if not isinstance(inputs, (tuple, list)): - raise RunningError('"inputs" must be a tuple/list.') - if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)): - if isinstance(inputs[0], (str, bm.Variable)): - inputs = [inputs] - else: - raise RunningError('Unknown input structure, only support inputs ' - 'with format of "(target, value, [type, operation])".') - for one_input in inputs: - if not 2 <= len(one_input) <= 4: - raise RunningError('For each target, you must specify ' - '"(target, value, [type, operation])".') - if len(one_input) == 3 and one_input[2] not in SUPPORTED_INPUT_TYPE: - raise RunningError(f'Input type only supports ' - f'"{SUPPORTED_INPUT_TYPE}", ' - f'not "{one_input[2]}".') - if len(one_input) == 4 and one_input[3] not in SUPPORTED_INPUT_OPS: - raise RunningError(f'Input operation only supports ' - f'"{SUPPORTED_INPUT_OPS}", ' - f'not "{one_input[3]}".') - - # 2. get targets and attributes - # --------- - inputs_which_found_target = [] - - # checking 1: absolute access - # Check whether the input target node is accessible, - # and check whether the target node has the attribute - for one_input in inputs: - key = one_input[0] - if isinstance(key, bm.Variable): - real_target = key - elif isinstance(key, str): - splits = key.split('.') - target = host - try: - for split in splits: - target = getattr(target, split) - except AttributeError: - raise AttributeError(f'target {target} does not have "{split}"') - real_target = target - else: - raise RunningError(f'For each input, input[0] must be a string to ' - f'specify variable of the target, but we got {key}.') - inputs_which_found_target.append((real_target,) + tuple(one_input[1:])) - - # 3. 
format inputs - # --------- - formatted_inputs = [] - for one_input in inputs_which_found_target: - # input value - data_value = one_input[1] - - # input type - if len(one_input) >= 3: - if one_input[2] == 'iter': - if not isinstance(data_value, Iterable): - raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n' - f'is set to be "iter" type, however we got the value with ' - f'the type of {type(data_value)}') - elif one_input[2] == 'func': - if not callable(data_value): - raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n' - f'is set to be "func" type, however we got the value with ' - f'the type of {type(data_value)}') - elif one_input[2] != 'fix': - raise RunningError(f'Only support {SUPPORTED_INPUT_TYPE} input type, but ' - f'we got "{one_input[2]}"') - - data_type = one_input[2] - else: - data_type = 'fix' - - # operation - if len(one_input) == 4: - data_op = one_input[3] - else: - data_op = '+' - if data_op not in SUPPORTED_INPUT_OPS: - raise RunningError(f'Only support {SUPPORTED_INPUT_OPS}, while we got ' - f'{data_op} in {one_input}') - - # final - format_inp = (one_input[0], data_value, data_type, data_op) - formatted_inputs.append(format_inp) - - fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} - - for variable, value, type_, op in formatted_inputs: - # variable - if not isinstance(variable, bm.Variable): - raise RunningError(f'{variable}\n is not a dynamically changed Variable, ' - f'its value will not change, we think there is no need to ' - f'give its input.') - - # input data - if type_ == 'iter': - if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)): - array_inputs[op].append([variable, bm.as_jax(value)]) - else: - next_inputs[op].append([variable, iter(value)]) - elif type_ == 'func': - func_inputs[op].append([variable, 
value]) - else: - fix_inputs[op].append([variable, value]) - - return {'fixed': fix_inputs, - 'iterated': next_inputs, - 'functional': func_inputs, - 'array': array_inputs} - - -def _f_ops(ops, var, data): - if ops == '=': - var[:] = data - elif ops == '+': - var += data - elif ops == '-': - var -= data - elif ops == '*': - var *= data - elif ops == '/': - var /= data - else: - raise ValueError(f'Unknown input operation: {ops}') + """Check inputs and get the formatted inputs for the given population. + Parameters:: -class DSRunner(Runner): - """The runner for :py:class:`~.DynamicalSystem`. - - Parameters:: - - target : DynamicalSystem - The target model to run. - - inputs : list, tuple, callable - The inputs for variables in the target model. - - .. note:: - - This argument can be used to set the inputs to the - :py:class:`~.Variable` instances in the ``target``. - If you peruse to give time-dependent inputs, please use - ``DSRunner.predict()`` or ``DSRunner.run()`` function. - - - It can be a list/tuple with the format - of `[(target, value, [type, operation])]`, where `target` is the - input target, `value` is the input value, `type` is the input type - (such as "fix", "iter", "func"), `operation` is the operation for inputs - (such as "+", "-", "*", "/", "="). - - - ``target``: should be a string or :py:class:`~.Variable`. Can be specified by the - *absolute access* or *relative access*. - - ``value``: should be a scalar, vector, matrix. - - ``type``: should be a string. "fix" means the input `value` - is a constant. "iter" means the input `value` can be changed - over time. "func" mean the input is obtained through the functional call. - - ``operation``: should be a string, support `+`, `-`, `*`, `/`, `=`. - - Also, if you want to specify multiple inputs, just give multiple - ``(target, value, [type, operation])``, - for example ``[(target1, value1), (target2, value2)]``. - - - It can also be a callable function which receives the shared argument. 
- In this functional input, users can manually specify the inputs for the target variables. - This input function should receive one argument ``shared`` which contains the - shared arguments like time ``t``, time step ``dt``, and index ``i``. + host : DynamicalSystem + The host which contains all data. + inputs : tuple, list + The inputs of the population. - .. versionchanged:: 2.3.1 - ``fun_inputs`` are merged into ``inputs``. - fun_inputs: callable - The functional inputs. Manually specify the inputs for the target variables. - This input function should receive one argument ``shared`` which contains the - shared arguments like time ``t``, time step ``dt``, and index ``i``. - - .. deprecated:: 2.3.1 - Will be removed since version 2.4.0. - monitors: Optional, sequence of str, dict, Monitor - Variables to monitor. - - - A list of string. Like ``monitors=['a', 'b', 'c']``. - - A list of string with index specification. Like ``monitors=[('a', 1), ('b', [1,3,5]), 'c']`` - - A dict with the explicit monitor target, like: ``monitors={'a': model.spike, 'b': model.V}`` - - A dict with the index specification, like: ``monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}`` - - A dict with the callable function, like ``monitors={'a': lambda: model.spike[:5]}`` - - .. versionchanged:: 2.3.1 - ``fun_monitors`` are merged into ``monitors``. - fun_monitors: dict - Monitoring variables by a dict of callable functions. - The dict ``key`` should be a string for the later retrieval by ``runner.mon[key]``. - The dict ``value`` should be a callable function which receives two arguments: ``t`` and ``dt``. - .. code-block:: - fun_monitors = {'spike': lambda: model.spike[:10], - 'V10': lambda: model.V[10]} - - .. deprecated:: 2.3.1 - Will be removed since version 2.4.0. - jit: bool, dict - The JIT settings. - Using dict is able to set the jit mode at different phase, - for instance, ``jit={'predict': True, 'fit': False}``. 
- - progress_bar: bool - Use progress bar to report the running progress or not? - - dyn_vars: Optional, dict - The dynamically changed variables. Instance of :py:class:`~.Variable`. - These variables together with variable retrieved from the ``target`` - constitute all dynamical variables in this runner. - - numpy_mon_after_run : bool - When finishing the network running, transform the JAX arrays into numpy ndarray or not? - - data_first_axis: str - Set the default data dimension arrangement. - To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the - time length (``data_first_axis='T'``). - In order to be compatible with previous API, default is set to be ``False``. - - .. versionadded:: 2.3.1 - - memory_efficient: bool - Whether using the memory-efficient way to just-in-time compile the given target. - Default is False. - - .. versionadded:: 2.3.8 - - """ - - target: DynamicalSystem - - def __init__( - self, - target: DynamicalSystem, - - # inputs for target variables - inputs: Union[Sequence, Callable] = (), - - # monitors - monitors: Optional[Union[Sequence, Dict]] = None, - numpy_mon_after_run: bool = True, - - # jit - jit: Union[bool, Dict[str, bool]] = True, - dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None, - memory_efficient: bool = False, - - # extra info - dt: Optional[float] = None, - t0: Union[float, int] = 0., - progress_bar: bool = True, - data_first_axis: Optional[str] = None, - - # deprecated - fun_inputs: Optional[Callable] = None, - fun_monitors: Optional[Dict[str, Callable]] = None, - ): - if not isinstance(target, DynamicalSystem): - raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, ' - f'but we got {type(target)}: {target}') - super().__init__(target=target, - monitors=monitors, - fun_monitors=fun_monitors, - jit=jit, - progress_bar=progress_bar, - dyn_vars=dyn_vars, - numpy_mon_after_run=numpy_mon_after_run) - - # t0 and i0 - self.i0 = 0 
- self.t0 = t0 - if data_first_axis is None: - data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T' - assert data_first_axis in ['B', 'T'] - self.data_first_axis = data_first_axis - - # parameters - dt = bm.get_dt() if dt is None else dt - if not isinstance(dt, float): - raise RunningError(f'"dt" must be float, but got {dt}') - self.dt = dt - - # Build input function - if fun_inputs is not None: - warnings.warn('`fun_inputs` is deprecated since version 2.3.1. ' - 'Define `fun_inputs` as `inputs` instead.', - UserWarning) - self._fun_inputs = fun_inputs - if callable(inputs): - self._inputs = inputs - else: - self._inputs = check_and_format_inputs(host=target, inputs=inputs) - - # run function - self._jit_step_func_predict = bm.jit(self._step_func_predict) - - # monitors - self._memory_efficient = memory_efficient - if memory_efficient and not numpy_mon_after_run: - raise ValueError('When setting "gpu_memory_efficient=True", "numpy_mon_after_run" can not be False.') - - def __repr__(self): - name = self.__class__.__name__ - indent = " " * len(name) + ' ' - indent2 = indent + " " * len("target") - return (f'{name}(target={tools.repr_context(str(self.target), indent2)}, \n' - f'{indent}jit={self.jit},\n' - f'{indent}dt={self.dt},\n' - f'{indent}data_first_axis={self.data_first_axis})') - - def reset_state(self): - """Reset state of the ``DSRunner``.""" - self.i0 = 0 - - def predict( - self, - duration: float = None, - inputs: Any = None, - reset_state: bool = False, - eval_time: bool = False, - shared_args: Dict = None, - - # deprecated - inputs_are_batching: bool = None, - ) -> Union[Output, Tuple[float, Output]]: - """Running a duration with the given target model. See `.predict()` function - for more details. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - and its output. 
+ Returns:: - Parameters:: + formatted_inputs : tuple, list + The formatted inputs of the population. + """ - duration: float - The simulation time length. - If you have provided ``inputs``, there is no longer need to provide ``duration``. - inputs: ArrayType, dict of ArrayType, sequence of ArrayType - The input data. - - - If the mode of ``target`` is instance of :py:class:`~.BatchingMode`, - ``inputs`` must be a PyTree of data with two dimensions: - ``(batch, time, ...)`` when ``data_first_axis='B'``, - or ``(time, batch, ...)`` when ``data_first_axis='T'``. - - If the mode of ``target`` is instance of :py:class:`~.NonBatchingMode`, - the ``inputs`` should be a PyTree of data with one dimension: - ``(time, ...)``. - inputs_are_batching: bool - Whether the ``inputs`` are batching. If `True`, the batching axis is the - first dimension. + # 1. check inputs + # --------- + if inputs is None: + inputs = [] + if not isinstance(inputs, (tuple, list)): + raise RunningError('"inputs" must be a tuple/list.') + if len(inputs) > 0 and not isinstance(inputs[0], (list, tuple)): + if isinstance(inputs[0], (str, bm.Variable)): + inputs = [inputs] + else: + raise RunningError('Unknown input structure, only support inputs ' + 'with format of "(target, value, [type, operation])".') + for one_input in inputs: + if not 2 <= len(one_input) <= 4: + raise RunningError('For each target, you must specify ' + '"(target, value, [type, operation])".') + if len(one_input) == 3 and one_input[2] not in SUPPORTED_INPUT_TYPE: + raise RunningError(f'Input type only supports ' + f'"{SUPPORTED_INPUT_TYPE}", ' + f'not "{one_input[2]}".') + if len(one_input) == 4 and one_input[3] not in SUPPORTED_INPUT_OPS: + raise RunningError(f'Input operation only supports ' + f'"{SUPPORTED_INPUT_OPS}", ' + f'not "{one_input[3]}".') + + # 2. 
get targets and attributes + # --------- + inputs_which_found_target = [] + + # checking 1: absolute access + # Check whether the input target node is accessible, + # and check whether the target node has the attribute + for one_input in inputs: + key = one_input[0] + if isinstance(key, bm.Variable): + real_target = key + elif isinstance(key, str): + splits = key.split('.') + target = host + try: + for split in splits: + target = getattr(target, split) + except AttributeError: + raise AttributeError(f'target {target} does not have "{split}"') + real_target = target + else: + raise RunningError(f'For each input, input[0] must be a string to ' + f'specify variable of the target, but we got {key}.') + inputs_which_found_target.append((real_target,) + tuple(one_input[1:])) + + # 3. format inputs + # --------- + formatted_inputs = [] + for one_input in inputs_which_found_target: + # input value + data_value = one_input[1] + + # input type + if len(one_input) >= 3: + if one_input[2] == 'iter': + if not isinstance(data_value, Iterable): + raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n' + f'is set to be "iter" type, however we got the value with ' + f'the type of {type(data_value)}') + elif one_input[2] == 'func': + if not callable(data_value): + raise ValueError(f'Input "{data_value}" for "{one_input[0]}" \n' + f'is set to be "func" type, however we got the value with ' + f'the type of {type(data_value)}') + elif one_input[2] != 'fix': + raise RunningError(f'Only support {SUPPORTED_INPUT_TYPE} input type, but ' + f'we got "{one_input[2]}"') + + data_type = one_input[2] + else: + data_type = 'fix' - .. deprecated:: 2.3.1 - Will be removed after version 2.4.0. - reset_state: bool - Whether reset the model states. - eval_time: bool - Whether ro evaluate the running time. - shared_args: optional, dict - The shared arguments across different layers. 
+ # operation + if len(one_input) == 4: + data_op = one_input[3] + else: + data_op = '+' + if data_op not in SUPPORTED_INPUT_OPS: + raise RunningError(f'Only support {SUPPORTED_INPUT_OPS}, while we got ' + f'{data_op} in {one_input}') + + # final + format_inp = (one_input[0], data_value, data_type, data_op) + formatted_inputs.append(format_inp) + + fix_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + next_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + func_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + array_inputs = {'=': [], '+': [], '-': [], '*': [], '/': []} + + for variable, value, type_, op in formatted_inputs: + # variable + if not isinstance(variable, bm.Variable): + raise RunningError(f'{variable}\n is not a dynamically changed Variable, ' + f'its value will not change, we think there is no need to ' + f'give its input.') + + # input data + if type_ == 'iter': + if isinstance(value, (bm.ndarray, np.ndarray, jnp.ndarray)): + array_inputs[op].append([variable, bm.as_jax(value)]) + else: + next_inputs[op].append([variable, iter(value)]) + elif type_ == 'func': + func_inputs[op].append([variable, value]) + else: + fix_inputs[op].append([variable, value]) - Returns:: + return {'fixed': fix_inputs, + 'iterated': next_inputs, + 'functional': func_inputs, + 'array': array_inputs} - output: ArrayType, dict, sequence - The model output. - """ - if inputs_are_batching is not None: - raise warnings.warn( - f''' - `inputs_are_batching` is no longer supported. - The target mode of {self.target.mode} has already indicated the input should be batching. 
- ''', - UserWarning - ) - if duration is None: - if inputs is None: - raise ValueError('Please provide "duration" or "inputs".') +def _f_ops(ops, var, data): + if ops == '=': + var[:] = data + elif ops == '+': + var += data + elif ops == '-': + var -= data + elif ops == '*': + var *= data + elif ops == '/': + var /= data else: - if inputs is not None: - warnings.warn('"inputs" has already has the time information. ' - 'Therefore there no longer need to provide "duration".', - UserWarning) - duration = None + raise ValueError(f'Unknown input operation: {ops}') - num_step = self._get_input_time_step(duration, inputs) - description = f'Predict {num_step} steps: ' - # reset the states of the model and the runner - if reset_state: - self.target.reset(self._get_input_batch_size(inputs)) - self.reset_state() +class DSRunner(Runner): + """The runner for :py:class:`~.DynamicalSystem`. - # shared arguments and inputs - indices = np.arange(self.i0, self.i0 + num_step, dtype=bm.int_) + Parameters:: - if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B': - inputs = tree_map(lambda x: jnp.moveaxis(x, 0, 1), inputs) + target : DynamicalSystem + The target model to run. + + inputs : list, tuple, callable + The inputs for variables in the target model. + + .. note:: + + This argument can be used to set the inputs to the + :py:class:`~.Variable` instances in the ``target``. + If you peruse to give time-dependent inputs, please use + ``DSRunner.predict()`` or ``DSRunner.run()`` function. + + - It can be a list/tuple with the format + of `[(target, value, [type, operation])]`, where `target` is the + input target, `value` is the input value, `type` is the input type + (such as "fix", "iter", "func"), `operation` is the operation for inputs + (such as "+", "-", "*", "/", "="). + + - ``target``: should be a string or :py:class:`~.Variable`. Can be specified by the + *absolute access* or *relative access*. + - ``value``: should be a scalar, vector, matrix. 
+ - ``type``: should be a string. "fix" means the input `value` + is a constant. "iter" means the input `value` can be changed + over time. "func" mean the input is obtained through the functional call. + - ``operation``: should be a string, support `+`, `-`, `*`, `/`, `=`. + - Also, if you want to specify multiple inputs, just give multiple + ``(target, value, [type, operation])``, + for example ``[(target1, value1), (target2, value2)]``. + + - It can also be a callable function which receives the shared argument. + In this functional input, users can manually specify the inputs for the target variables. + This input function should receive one argument ``shared`` which contains the + shared arguments like time ``t``, time step ``dt``, and index ``i``. + + .. versionchanged:: 2.3.1 + ``fun_inputs`` are merged into ``inputs``. + fun_inputs: callable + The functional inputs. Manually specify the inputs for the target variables. + This input function should receive one argument ``shared`` which contains the + shared arguments like time ``t``, time step ``dt``, and index ``i``. - # build monitor - for key in self._monitors.keys(): - self.mon[key] = [] # reshape the monitor items + .. deprecated:: 2.3.1 + Will be removed since version 2.4.0. + monitors: Optional, sequence of str, dict, Monitor + Variables to monitor. + - A list of string. Like ``monitors=['a', 'b', 'c']``. + - A list of string with index specification. 
Like ``monitors=[('a', 1), ('b', [1,3,5]), 'c']`` + - A dict with the explicit monitor target, like: ``monitors={'a': model.spike, 'b': model.V}`` + - A dict with the index specification, like: ``monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}`` + - A dict with the callable function, like ``monitors={'a': lambda: model.spike[:5]}`` - # running - if eval_time: - t0 = time.time() - if inputs is None: - inputs = tuple() - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - outputs, hists = self._predict(indices, *inputs, shared_args=shared_args) - if eval_time: - running_time = time.time() - t0 + .. versionchanged:: 2.3.1 + ``fun_monitors`` are merged into ``monitors``. + fun_monitors: dict + Monitoring variables by a dict of callable functions. + The dict ``key`` should be a string for the later retrieval by ``runner.mon[key]``. + The dict ``value`` should be a callable function which receives two arguments: ``t`` and ``dt``. + .. code-block:: + fun_monitors = {'spike': lambda: model.spike[:10], + 'V10': lambda: model.V[10]} + .. deprecated:: 2.3.1 + Will be removed since version 2.4.0. + jit: bool, dict + The JIT settings. + Using dict is able to set the jit mode at different phase, + for instance, ``jit={'predict': True, 'fit': False}``. - # post-running for monitors - if self._memory_efficient: - self.mon['ts'] = indices * self.dt + self.t0 - for key in self._monitors.keys(): - self.mon[key] = np.asarray(self.mon[key]) - else: - hists['ts'] = indices * self.dt + self.t0 - if self.numpy_mon_after_run: - hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.BaseArray)) - else: - hists['ts'] = bm.as_jax(hists['ts']) - for key in hists.keys(): - self.mon[key] = hists[key] - self.i0 += num_step - return outputs if not eval_time else (running_time, outputs) - - def run(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]: - """Same as :py:func:`~.DSRunner.predict`. 
- """ - return self.predict(*args, **kwargs) + progress_bar: bool + Use progress bar to report the running progress or not? - def __call__(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]: - """Same as :py:func:`~.DSRunner.predict`. - """ - return self.predict(*args, **kwargs) + dyn_vars: Optional, dict + The dynamically changed variables. Instance of :py:class:`~.Variable`. + These variables together with variable retrieved from the ``target`` + constitute all dynamical variables in this runner. - def _predict(self, indices, *xs, shared_args=None) -> Union[Output, Monitor]: - """Predict the output according to the inputs. + numpy_mon_after_run : bool + When finishing the network running, transform the JAX arrays into numpy ndarray or not? - Parameters:: + data_first_axis: str + Set the default data dimension arrangement. + To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the + time length (``data_first_axis='T'``). + In order to be compatible with previous API, default is set to be ``False``. - xs: sequence - If `inputs` is not None, it should be a tensor with the shape of - :math:`(num_time, ...)`. - shared_args: optional, dict - The shared keyword arguments. + .. versionadded:: 2.3.1 - Returns:: + memory_efficient: bool + Whether using the memory-efficient way to just-in-time compile the given target. + Default is False. + + .. versionadded:: 2.3.8 - outputs, hists - A tuple of pair of (outputs, hists). 
""" - if shared_args is None: - shared_args = dict() - shared_args = tools.DotDict(shared_args) - - outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args) - if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B': - outs_and_mons = tree_map(lambda x: jnp.moveaxis(x, 0, 1) if x.ndim >= 2 else x, - outs_and_mons) - return outs_and_mons - - def _step_func_monitor(self): - res = dict() - for key, val in self._monitors.items(): - if callable(val): - res[key] = _call_fun_with_share(val) - else: - (variable, idx) = val - if idx is None: - res[key] = variable.value + + target: DynamicalSystem + + def __init__( + self, + target: DynamicalSystem, + + # inputs for target variables + inputs: Union[Sequence, Callable] = (), + + # monitors + monitors: Optional[Union[Sequence, Dict]] = None, + numpy_mon_after_run: bool = True, + + # jit + jit: Union[bool, Dict[str, bool]] = True, + dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None, + memory_efficient: bool = False, + + # extra info + dt: Optional[float] = None, + t0: Union[float, int] = 0., + progress_bar: bool = True, + data_first_axis: Optional[str] = None, + + # deprecated + fun_inputs: Optional[Callable] = None, + fun_monitors: Optional[Dict[str, Callable]] = None, + ): + if not isinstance(target, DynamicalSystem): + raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, ' + f'but we got {type(target)}: {target}') + super().__init__(target=target, + monitors=monitors, + fun_monitors=fun_monitors, + jit=jit, + progress_bar=progress_bar, + dyn_vars=dyn_vars, + numpy_mon_after_run=numpy_mon_after_run) + + # t0 and i0 + self.i0 = 0 + self.t0 = t0 + if data_first_axis is None: + data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T' + assert data_first_axis in ['B', 'T'] + self.data_first_axis = data_first_axis + + # parameters + dt = bm.get_dt() if dt is None else dt + if not isinstance(dt, 
float): + raise RunningError(f'"dt" must be float, but got {dt}') + self.dt = dt + + # Build input function + if fun_inputs is not None: + warnings.warn('`fun_inputs` is deprecated since version 2.3.1. ' + 'Define `fun_inputs` as `inputs` instead.', + UserWarning) + self._fun_inputs = fun_inputs + if callable(inputs): + self._inputs = inputs else: - res[key] = variable[bm.as_jax(idx)] - return res - - def _step_func_input(self): - if self._fun_inputs is not None: - self._fun_inputs(share.get_shargs()) - if callable(self._inputs): - _call_fun_with_share(self._inputs) - else: - for ops, values in self._inputs['fixed'].items(): - for var, data in values: - _f_ops(ops, var, data) - for ops, values in self._inputs['array'].items(): - for var, data in values: - _f_ops(ops, var, data[share['i']]) - for ops, values in self._inputs['functional'].items(): - for var, data in values: - _f_ops(ops, var, _call_fun_with_share(data)) - for ops, values in self._inputs['iterated'].items(): - for var, data in values: - _f_ops(ops, var, next(data)) - - def _get_input_batch_size(self, xs=None) -> Optional[int]: - """Get the batch size in the given input data.""" - if xs is None: - return None - if isinstance(self.target.mode, bm.NonBatchingMode): - return None - if isinstance(xs, (bm.BaseArray, jax.Array, np.ndarray)): - return xs.shape[1] if self.data_first_axis == 'T' else xs.shape[0] - leaves, _ = tree_flatten(xs, is_leaf=_is_brainpy_array) - if self.data_first_axis == 'T': - num_batch_sizes = [x.shape[1] for x in leaves] - else: - num_batch_sizes = [x.shape[0] for x in leaves] - if len(set(num_batch_sizes)) != 1: - raise ValueError(f'Number of batch size is different across arrays in ' - f'the provided "xs". 
We got {set(num_batch_sizes)}.') - return num_batch_sizes[0] - - def _get_input_time_step(self, duration=None, xs=None) -> int: - """Get the length of time step in the given ``duration`` and ``xs``.""" - if duration is not None: - return int(duration / self.dt) - if xs is not None: - if isinstance(xs, (bm.BaseArray, jax.Array, np.ndarray)): - return xs.shape[0] if self.data_first_axis == 'T' else xs.shape[1] - else: - leaves, _ = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.BaseArray)) + self._inputs = check_and_format_inputs(host=target, inputs=inputs) + + # run function + self._jit_step_func_predict = bm.jit(self._step_func_predict) + + # monitors + self._memory_efficient = memory_efficient + if memory_efficient and not numpy_mon_after_run: + raise ValueError('When setting "gpu_memory_efficient=True", "numpy_mon_after_run" can not be False.') + + def __repr__(self): + name = self.__class__.__name__ + indent = " " * len(name) + ' ' + indent2 = indent + " " * len("target") + return (f'{name}(target={tools.repr_context(str(self.target), indent2)}, \n' + f'{indent}jit={self.jit},\n' + f'{indent}dt={self.dt},\n' + f'{indent}data_first_axis={self.data_first_axis})') + + def reset_state(self): + """Reset state of the ``DSRunner``.""" + self.i0 = 0 + + def predict( + self, + duration: float = None, + inputs: Any = None, + reset_state: bool = False, + eval_time: bool = False, + shared_args: Dict = None, + + # deprecated + inputs_are_batching: bool = None, + ) -> Union[Output, Tuple[float, Output]]: + """Running a duration with the given target model. See `.predict()` function + for more details. + + This function use the JIT compilation to accelerate the model simulation. + Moreover, it can automatically monitor the node variables, states, inputs, + and its output. + + Parameters:: + + duration: float + The simulation time length. + If you have provided ``inputs``, there is no longer need to provide ``duration``. 
+ inputs: ArrayType, dict of ArrayType, sequence of ArrayType + The input data. + + - If the mode of ``target`` is instance of :py:class:`~.BatchingMode`, + ``inputs`` must be a PyTree of data with two dimensions: + ``(batch, time, ...)`` when ``data_first_axis='B'``, + or ``(time, batch, ...)`` when ``data_first_axis='T'``. + - If the mode of ``target`` is instance of :py:class:`~.NonBatchingMode`, + the ``inputs`` should be a PyTree of data with one dimension: + ``(time, ...)``. + inputs_are_batching: bool + Whether the ``inputs`` are batching. If `True`, the batching axis is the + first dimension. + + .. deprecated:: 2.3.1 + Will be removed after version 2.4.0. + reset_state: bool + Whether reset the model states. + eval_time: bool + Whether ro evaluate the running time. + shared_args: optional, dict + The shared arguments across different layers. + + Returns:: + + output: ArrayType, dict, sequence + The model output. + """ + + if inputs_are_batching is not None: + raise warnings.warn( + f''' + `inputs_are_batching` is no longer supported. + The target mode of {self.target.mode} has already indicated the input should be batching. + ''', + UserWarning + ) + if duration is None: + if inputs is None: + raise ValueError('Please provide "duration" or "inputs".') + else: + if inputs is not None: + warnings.warn('"inputs" has already has the time information. 
' + 'Therefore there no longer need to provide "duration".', + UserWarning) + duration = None + + num_step = self._get_input_time_step(duration, inputs) + description = f'Predict {num_step} steps: ' + + # reset the states of the model and the runner + if reset_state: + self.target.reset(self._get_input_batch_size(inputs)) + self.reset_state() + + # shared arguments and inputs + indices = np.arange(self.i0, self.i0 + num_step, dtype=bm.int_) + + if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B': + inputs = tree_map(lambda x: jnp.moveaxis(x, 0, 1), inputs) + + # build monitor + for key in self._monitors.keys(): + self.mon[key] = [] # reshape the monitor items + + # running + if eval_time: + t0 = time.time() + if inputs is None: + inputs = tuple() + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + outputs, hists = self._predict(indices, *inputs, shared_args=shared_args) + if eval_time: + running_time = time.time() - t0 + + # post-running for monitors + if self._memory_efficient: + self.mon['ts'] = indices * self.dt + self.t0 + for key in self._monitors.keys(): + self.mon[key] = np.asarray(self.mon[key]) + else: + hists['ts'] = indices * self.dt + self.t0 + if self.numpy_mon_after_run: + hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.BaseArray)) + else: + hists['ts'] = bm.as_jax(hists['ts']) + for key in hists.keys(): + self.mon[key] = hists[key] + self.i0 += num_step + return outputs if not eval_time else (running_time, outputs) + + def run(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]: + """Same as :py:func:`~.DSRunner.predict`. + """ + return self.predict(*args, **kwargs) + + def __call__(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]: + """Same as :py:func:`~.DSRunner.predict`. + """ + return self.predict(*args, **kwargs) + + def _predict(self, indices, *xs, shared_args=None) -> Union[Output, Monitor]: + """Predict the output according to the inputs. 
+ + Parameters:: + + xs: sequence + If `inputs` is not None, it should be a tensor with the shape of + :math:`(num_time, ...)`. + shared_args: optional, dict + The shared keyword arguments. + + Returns:: + + outputs, hists + A tuple of pair of (outputs, hists). + """ + if shared_args is None: + shared_args = dict() + shared_args = tools.DotDict(shared_args) + + outs_and_mons = self._fun_predict(indices, *xs, shared_args=shared_args) + if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B': + outs_and_mons = tree_map(lambda x: jnp.moveaxis(x, 0, 1) if x.ndim >= 2 else x, + outs_and_mons) + return outs_and_mons + + def _step_func_monitor(self): + res = dict() + for key, val in self._monitors.items(): + if callable(val): + res[key] = _call_fun_with_share(val) + else: + (variable, idx) = val + if idx is None: + res[key] = variable.value + else: + res[key] = variable[bm.as_jax(idx)] + return res + + def _step_func_input(self): + if self._fun_inputs is not None: + self._fun_inputs(share.get_shargs()) + if callable(self._inputs): + _call_fun_with_share(self._inputs) + else: + for ops, values in self._inputs['fixed'].items(): + for var, data in values: + _f_ops(ops, var, data) + for ops, values in self._inputs['array'].items(): + for var, data in values: + _f_ops(ops, var, data[share['i']]) + for ops, values in self._inputs['functional'].items(): + for var, data in values: + _f_ops(ops, var, _call_fun_with_share(data)) + for ops, values in self._inputs['iterated'].items(): + for var, data in values: + _f_ops(ops, var, next(data)) + + def _get_input_batch_size(self, xs=None) -> Optional[int]: + """Get the batch size in the given input data.""" + if xs is None: + return None + if isinstance(self.target.mode, bm.NonBatchingMode): + return None + if isinstance(xs, (bm.BaseArray, jax.Array, np.ndarray)): + return xs.shape[1] if self.data_first_axis == 'T' else xs.shape[0] + leaves, _ = tree_flatten(xs, is_leaf=_is_brainpy_array) if self.data_first_axis 
== 'T': - num_steps = [x.shape[0] for x in leaves] + num_batch_sizes = [x.shape[1] for x in leaves] else: - num_steps = [x.shape[1] for x in leaves] - if len(set(num_steps)) != 1: - raise ValueError(f'Number of time step is different across arrays in ' - f'the provided "xs". We got {set(num_steps)}.') - return num_steps[0] - else: - raise ValueError - - def _step_mon_on_cpu(self, args, transforms): - for key, val in args.items(): - self.mon[key].append(val) - - def _step_func_predict(self, i, *x, shared_args=None): - # input step - if shared_args is not None: - assert isinstance(shared_args, dict) - share.save(**shared_args) - share.save(t=self.t0 + i * self.dt, i=i, dt=self.dt) - self._step_func_input() - - # dynamics update step - out = self.target(*x) - - # monitor step - mon = self._step_func_monitor() - - # share.clear_shargs() - clear_input(self.target) - - if self._memory_efficient: - mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype) - result = jax.pure_callback( - self._step_mon_on_cpu, - mon_shape_dtype, - mon, - ) - return out, None - else: - return out, mon - - def _fun_predict(self, indices, *inputs, shared_args=None): - if self._memory_efficient: - if self.jit['predict']: - run_fun = self._jit_step_func_predict - else: - run_fun = self._step_func_predict - - outs = None - for i in range(indices.shape[0]): - out, _ = run_fun(indices[i], *tree_map(lambda a: a[i], inputs)) - if outs is None: - outs = tree_map(lambda a: [], out) - outs = tree_map(lambda a, o: o.append(a), out, outs) - outs = tree_map(lambda a: bm.as_jax(a), outs) - return outs, None + num_batch_sizes = [x.shape[0] for x in leaves] + if len(set(num_batch_sizes)) != 1: + raise ValueError(f'Number of batch size is different across arrays in ' + f'the provided "xs". 
We got {set(num_batch_sizes)}.') + return num_batch_sizes[0] + + def _get_input_time_step(self, duration=None, xs=None) -> int: + """Get the length of time step in the given ``duration`` and ``xs``.""" + if duration is not None: + return int(duration / self.dt) + if xs is not None: + if isinstance(xs, (bm.BaseArray, jax.Array, np.ndarray)): + return xs.shape[0] if self.data_first_axis == 'T' else xs.shape[1] + else: + leaves, _ = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.BaseArray)) + if self.data_first_axis == 'T': + num_steps = [x.shape[0] for x in leaves] + else: + num_steps = [x.shape[1] for x in leaves] + if len(set(num_steps)) != 1: + raise ValueError(f'Number of time step is different across arrays in ' + f'the provided "xs". We got {set(num_steps)}.') + return num_steps[0] + else: + raise ValueError + + def _step_mon_on_cpu(self, args, transforms): + for key, val in args.items(): + self.mon[key].append(val) + + def _step_func_predict(self, i, *x, shared_args=None): + # input step + if shared_args is not None: + assert isinstance(shared_args, dict) + share.save(**shared_args) + share.save(t=self.t0 + i * self.dt, i=i, dt=self.dt) + self._step_func_input() + + # dynamics update step + out = self.target(*x) + + # monitor step + mon = self._step_func_monitor() + + # share.clear_shargs() + clear_input(self.target) + + if self._memory_efficient: + mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype) + result = jax.pure_callback( + self._step_mon_on_cpu, + mon_shape_dtype, + mon, + ) + return out, None + else: + return out, mon + + def _fun_predict(self, indices, *inputs, shared_args=None): + if self._memory_efficient: + if self.jit['predict']: + run_fun = self._jit_step_func_predict + else: + run_fun = self._step_func_predict + + outs = None + for i in range(indices.shape[0]): + out, _ = run_fun(indices[i], *tree_map(lambda a: a[i], inputs)) + if outs is None: + outs = tree_map(lambda a: [], out) + outs = tree_map(lambda a, o: o.append(a), out, 
outs) + outs = tree_map(lambda a: bm.as_jax(a), outs) + return outs, None - else: - return bm.for_loop(self._step_func_predict, - (indices, *inputs), - jit=self.jit['predict'], - unroll_kwargs={'shared_args': shared_args}, - progress_bar=self.progress_bar) + else: + return bm.for_loop(self._step_func_predict, + (indices, *inputs), + jit=self.jit['predict'], + unroll_kwargs={'shared_args': shared_args}, + progress_bar=self.progress_bar) diff --git a/brainpy/_src/running/__init__.py b/brainpy/_src/running/__init__.py index 8aa04b462..e68271758 100644 --- a/brainpy/_src/running/__init__.py +++ b/brainpy/_src/running/__init__.py @@ -5,20 +5,17 @@ This module provides APIs for parallel brain simulations. """ +from . import constants from . import jax_multiprocessing from . import native_multiprocessing from . import pathos_multiprocessing -from . import constants - __all__ = (native_multiprocessing.__all__ + pathos_multiprocessing.__all__ + jax_multiprocessing.__all__ + constants.__all__) - from .jax_multiprocessing import * from .native_multiprocessing import * from .pathos_multiprocessing import * from .constants import * - diff --git a/brainpy/_src/running/constants.py b/brainpy/_src/running/constants.py index 363346f17..8ce71fc71 100644 --- a/brainpy/_src/running/constants.py +++ b/brainpy/_src/running/constants.py @@ -2,14 +2,13 @@ __all__ = [ - 'TRAIN_PHASE', 'FIT_PHASE', - 'TEST_PHASE', - 'PREDICT_PHASE', - 'RUN_PHASE', - 'LOSS_PHASE', + 'TRAIN_PHASE', 'FIT_PHASE', + 'TEST_PHASE', + 'PREDICT_PHASE', + 'RUN_PHASE', + 'LOSS_PHASE', ] - TRAIN_PHASE = 'train' FIT_PHASE = 'fit' @@ -17,5 +16,3 @@ PREDICT_PHASE = 'predict' RUN_PHASE = 'predict' LOSS_PHASE = 'loss' - - diff --git a/brainpy/_src/running/jax_multiprocessing.py b/brainpy/_src/running/jax_multiprocessing.py index b99a1fee7..c3bff1cd1 100644 --- a/brainpy/_src/running/jax_multiprocessing.py +++ b/brainpy/_src/running/jax_multiprocessing.py @@ -10,8 +10,8 @@ from brainpy.types import ArrayType __all__ = [ - 
'jax_vectorize_map', - 'jax_parallelize_map', + 'jax_vectorize_map', + 'jax_parallelize_map', ] @@ -21,63 +21,63 @@ def jax_vectorize_map( num_parallel: int, clear_buffer: bool = False ): - """Perform a vectorized map of a function by using ``jax.vmap``. - - This function can be used in CPU or GPU backends. But it is highly - suitable to be used in GPU backends. This is because ``jax.vmap`` - can parallelize the mapped axis on GPU devices. - - Parameters:: - - func: callable, function - The function to be mapped. - arguments: sequence, dict - The function arguments, used to define tasks. - num_parallel: int - The number of batch size. - clear_buffer: bool - Clear the buffer memory after running each batch data. - - Returns:: - - results: Any - The running results. - """ - if not isinstance(arguments, (dict, tuple, list)): - raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') - elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.BaseArray)) - if clear_buffer: - elements = [np.asarray(ele) for ele in elements] - num_pars = [len(ele) for ele in elements] - if len(np.unique(num_pars)) != 1: - raise ValueError(f'All elements in parameters should have the same length. 
' - f'But we got {tree_unflatten(tree, num_pars)}') - - res_tree = None - results = None - vmap_func = vmap(func) - for i in range(0, num_pars[0], num_parallel): - run_f = vmap(func) if clear_buffer else vmap_func - if isinstance(arguments, dict): - r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) - elif isinstance(arguments, (tuple, list)): - r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) - else: - raise TypeError - res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.BaseArray)) - if results is None: - results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) - else: - for j, val in enumerate(res_values): - results[j].append(np.asarray(val) if clear_buffer else val) + """Perform a vectorized map of a function by using ``jax.vmap``. + + This function can be used in CPU or GPU backends. But it is highly + suitable to be used in GPU backends. This is because ``jax.vmap`` + can parallelize the mapped axis on GPU devices. + + Parameters:: + + func: callable, function + The function to be mapped. + arguments: sequence, dict + The function arguments, used to define tasks. + num_parallel: int + The number of batch size. + clear_buffer: bool + Clear the buffer memory after running each batch data. + + Returns:: + + results: Any + The running results. 
+ """ + if not isinstance(arguments, (dict, tuple, list)): + raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') + elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.BaseArray)) if clear_buffer: - bm.clear_buffer_memory() - if res_tree is None: - return None - results = ([np.concatenate(res, axis=0) for res in results] - if clear_buffer else - [bm.concatenate(res, axis=0) for res in results]) - return tree_unflatten(res_tree, results) + elements = [np.asarray(ele) for ele in elements] + num_pars = [len(ele) for ele in elements] + if len(np.unique(num_pars)) != 1: + raise ValueError(f'All elements in parameters should have the same length. ' + f'But we got {tree_unflatten(tree, num_pars)}') + + res_tree = None + results = None + vmap_func = vmap(func) + for i in range(0, num_pars[0], num_parallel): + run_f = vmap(func) if clear_buffer else vmap_func + if isinstance(arguments, dict): + r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) + elif isinstance(arguments, (tuple, list)): + r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) + else: + raise TypeError + res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.BaseArray)) + if results is None: + results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) + else: + for j, val in enumerate(res_values): + results[j].append(np.asarray(val) if clear_buffer else val) + if clear_buffer: + bm.clear_buffer_memory() + if res_tree is None: + return None + results = ([np.concatenate(res, axis=0) for res in results] + if clear_buffer else + [bm.concatenate(res, axis=0) for res in results]) + return tree_unflatten(res_tree, results) def jax_parallelize_map( @@ -86,58 +86,58 @@ def jax_parallelize_map( num_parallel: int, clear_buffer: bool = False ): - """Perform a parallelized map of a function by using ``jax.pmap``. 
- - This function can be used in multi- CPU or GPU backends. - If you are using it in a single CPU, please set host device count - by ``brainpy.math.set_host_device_count(n)`` before. - - Parameters:: - - func: callable, function - The function to be mapped. - arguments: sequence, dict - The function arguments, used to define tasks. - num_parallel: int - The number of batch size. - clear_buffer: bool - Clear the buffer memory after running each batch data. - - Returns:: - - results: Any - The running results. - """ - if not isinstance(arguments, (dict, tuple, list)): - raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') - elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.BaseArray)) - if clear_buffer: - elements = [np.asarray(ele) for ele in elements] - num_pars = [len(ele) for ele in elements] - if len(np.unique(num_pars)) != 1: - raise ValueError(f'All elements in parameters should have the same length. ' - f'But we got {tree_unflatten(tree, num_pars)}') - - res_tree = None - results = None - vmap_func = pmap(func) - for i in range(0, num_pars[0], num_parallel): - run_f = pmap(func) if clear_buffer else vmap_func - if isinstance(arguments, dict): - r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) - else: - r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) - res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.BaseArray)) - if results is None: - results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) - else: - for j, val in enumerate(res_values): - results[j].append(np.asarray(val) if clear_buffer else val) + """Perform a parallelized map of a function by using ``jax.pmap``. + + This function can be used in multi- CPU or GPU backends. + If you are using it in a single CPU, please set host device count + by ``brainpy.math.set_host_device_count(n)`` before. 
+ + Parameters:: + + func: callable, function + The function to be mapped. + arguments: sequence, dict + The function arguments, used to define tasks. + num_parallel: int + The number of batch size. + clear_buffer: bool + Clear the buffer memory after running each batch data. + + Returns:: + + results: Any + The running results. + """ + if not isinstance(arguments, (dict, tuple, list)): + raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') + elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.BaseArray)) if clear_buffer: - bm.clear_buffer_memory() - if res_tree is None: - return None - results = ([np.concatenate(res, axis=0) for res in results] - if clear_buffer else - [bm.concatenate(res, axis=0) for res in results]) - return tree_unflatten(res_tree, results) + elements = [np.asarray(ele) for ele in elements] + num_pars = [len(ele) for ele in elements] + if len(np.unique(num_pars)) != 1: + raise ValueError(f'All elements in parameters should have the same length. 
' + f'But we got {tree_unflatten(tree, num_pars)}') + + res_tree = None + results = None + vmap_func = pmap(func) + for i in range(0, num_pars[0], num_parallel): + run_f = pmap(func) if clear_buffer else vmap_func + if isinstance(arguments, dict): + r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) + else: + r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements])) + res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.BaseArray)) + if results is None: + results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) + else: + for j, val in enumerate(res_values): + results[j].append(np.asarray(val) if clear_buffer else val) + if clear_buffer: + bm.clear_buffer_memory() + if res_tree is None: + return None + results = ([np.concatenate(res, axis=0) for res in results] + if clear_buffer else + [bm.concatenate(res, axis=0) for res in results]) + return tree_unflatten(res_tree, results) diff --git a/brainpy/_src/running/native_multiprocessing.py b/brainpy/_src/running/native_multiprocessing.py index 88e3e7561..ebff9c498 100644 --- a/brainpy/_src/running/native_multiprocessing.py +++ b/brainpy/_src/running/native_multiprocessing.py @@ -1,103 +1,103 @@ # -*- coding: utf-8 -*- -from typing import Union, Sequence, Dict import multiprocessing +from typing import Union, Sequence, Dict __all__ = [ - 'process_pool', - 'process_pool_lock', + 'process_pool', + 'process_pool_lock', ] def process_pool(func: callable, all_params: Union[Sequence, Dict], num_process: int): - """Run multiple models in multi-processes. - - .. Note:: - This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. - - Parameters:: - - func : callable - The function to run model. - all_params : list, tuple, dict - The parameters of the function arguments. - The parameters for each process can be a tuple, or a dictionary. - num_process : int - The number of the processes. 
- - Returns:: - - results : list - Process results. - """ - print('{} jobs total.'.format(len(all_params))) - pool = multiprocessing.Pool(processes=num_process) - results = [] - for params in all_params: - if isinstance(params, (list, tuple)): - results.append(pool.apply_async(func, args=tuple(params))) - elif isinstance(params, dict): - results.append(pool.apply_async(func, kwds=params)) - else: - raise ValueError('Unknown parameter type: ', type(params)) - pool.close() - pool.join() - return [r.get() for r in results] + """Run multiple models in multi-processes. + + .. Note:: + This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. + + Parameters:: + + func : callable + The function to run model. + all_params : list, tuple, dict + The parameters of the function arguments. + The parameters for each process can be a tuple, or a dictionary. + num_process : int + The number of the processes. + + Returns:: + + results : list + Process results. + """ + print('{} jobs total.'.format(len(all_params))) + pool = multiprocessing.Pool(processes=num_process) + results = [] + for params in all_params: + if isinstance(params, (list, tuple)): + results.append(pool.apply_async(func, args=tuple(params))) + elif isinstance(params, dict): + results.append(pool.apply_async(func, kwds=params)) + else: + raise ValueError('Unknown parameter type: ', type(params)) + pool.close() + pool.join() + return [r.get() for r in results] def process_pool_lock(func: callable, all_params: Union[Sequence, Dict], num_process: int): - """Run multiple models in multi-processes with lock. - - Sometimes, you want to synchronize the processes. For example, - if you want to write something in a document, you cannot let - multiprocess simultaneously open this same file. So, you need - add a `lock` argument in your defined `func`: - - .. code-block:: python - - def some_func(..., lock, ...): - ... do something .. - - lock.acquire() - ... 
something cannot simultaneously do by multi-process .. - lock.release() - - In such case, you can use `process_pool_lock()` to run your model. - - .. Note:: - This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. - - Parameters:: - - func: callable - The function to run model. - all_params : list, tuple, dict - The parameters of the function arguments. - num_process : int - The number of the processes. - - Returns:: - - results : list - Process results. - """ - print('{} jobs total.'.format(len(all_params))) - pool = multiprocessing.Pool(processes=num_process) - m = multiprocessing.Manager() - lock = m.Lock() - results = [] - for net_params in all_params: - if isinstance(net_params, (list, tuple)): - results.append(pool.apply_async(func, args=tuple(net_params) + (lock,))) - elif isinstance(net_params, dict): - net_params.update(lock=lock) - results.append(pool.apply_async(func, kwds=net_params)) - else: - raise ValueError('Unknown parameter type: ', type(net_params)) - pool.close() - pool.join() - return [r.get() for r in results] + """Run multiple models in multi-processes with lock. + + Sometimes, you want to synchronize the processes. For example, + if you want to write something in a document, you cannot let + multiprocess simultaneously open this same file. So, you need + add a `lock` argument in your defined `func`: + + .. code-block:: python + + def some_func(..., lock, ...): + ... do something .. + + lock.acquire() + ... something cannot simultaneously do by multi-process .. + lock.release() + + In such case, you can use `process_pool_lock()` to run your model. + + .. Note:: + This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. + + Parameters:: + + func: callable + The function to run model. + all_params : list, tuple, dict + The parameters of the function arguments. + num_process : int + The number of the processes. + + Returns:: + + results : list + Process results. 
+ """ + print('{} jobs total.'.format(len(all_params))) + pool = multiprocessing.Pool(processes=num_process) + m = multiprocessing.Manager() + lock = m.Lock() + results = [] + for net_params in all_params: + if isinstance(net_params, (list, tuple)): + results.append(pool.apply_async(func, args=tuple(net_params) + (lock,))) + elif isinstance(net_params, dict): + net_params.update(lock=lock) + results.append(pool.apply_async(func, kwds=net_params)) + else: + raise ValueError('Unknown parameter type: ', type(net_params)) + pool.close() + pool.join() + return [r.get() for r in results] diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py index 3e0b59b1c..c7f4db21d 100644 --- a/brainpy/_src/running/pathos_multiprocessing.py +++ b/brainpy/_src/running/pathos_multiprocessing.py @@ -19,17 +19,18 @@ from brainpy.errors import PackageMissingError try: - from pathos.helpers import cpu_count # noqa - from pathos.multiprocessing import ProcessPool # noqa - import multiprocess.context as ctx # noqa - ctx._force_start_method('spawn') + from pathos.helpers import cpu_count # noqa + from pathos.multiprocessing import ProcessPool # noqa + import multiprocess.context as ctx # noqa + + ctx._force_start_method('spawn') except ModuleNotFoundError: - cpu_count = None - ProcessPool = None + cpu_count = None + ProcessPool = None __all__ = [ - 'cpu_ordered_parallel', - 'cpu_unordered_parallel', + 'cpu_ordered_parallel', + 'cpu_unordered_parallel', ] @@ -41,81 +42,81 @@ def _parallel( num_task: int = None, **tqdm_kwargs: Any ) -> Generator: - """Perform a parallel map with a progress bar. - - Parameters:: - - ordered: bool - True for an ordered map, false for an unordered map. - function: callable, function - The function to apply to each element of the given Iterables. - arguments: sequence of Iterable, dict - One or more Iterables containing the data to be mapped. 
- num_process: int, float - Number of threads used for parallel running. If `int`, it is - the number of threads to be used; if `float`, it is the fraction - of total threads to be used for running. - num_task: int - The total number of tasks in this parallel running. - tqdm_kwargs: Any - The setting for the progress bar. - - Returns:: - - results: Iterable - A generator which will apply the function to each element of the given Iterables - in parallel in order with a progress bar. - """ - if sys.platform == 'win32' and sys.version_info.minor >= 11: - raise NotImplementedError('Multiprocessing is not available in Python >=3.11 on Windows. ' - 'Please use Linux or MacOS, or Windows with Python <= 3.10.') - - if ProcessPool is None or cpu_count is None: - raise PackageMissingError( - ''' - Please install "pathos" package first. - - >>> pip install pathos - ''' - ) - - # Determine num_process - if num_process is None: - num_process = cpu_count() - elif isinstance(num_process, int): - pass - elif isinstance(num_process, float): - num_process = int(round(num_process * cpu_count())) - else: - raise ValueError('"num_process" must be an int or a float.') - - # arguments - if isinstance(arguments, dict): - keys = list(arguments.keys()) - arguments = list(arguments.values()) - run_f = lambda *args: function(**{key: arg for key, arg in zip(keys, args)}) - else: - if not isinstance(arguments, (tuple, list)): - raise TypeError('"arguments" must be a sequence of Iterable or a dict of Iterable. 
' - f'But we got {type(arguments)}') - run_f = function - - # Determine length of tqdm - lengths = [len(iterable) for iterable in arguments if isinstance(iterable, Sized)] - num_task = num_task or (min(lengths) if lengths else None) - - # Create parallel generator - pool = ProcessPool(nodes=num_process) - if ordered: - map_func = pool.imap - else: - map_func = pool.uimap - - # Choose tqdm variant - for item in tqdm(map_func(run_f, *arguments), total=num_task, **tqdm_kwargs): - yield item - - pool.clear() + """Perform a parallel map with a progress bar. + + Parameters:: + + ordered: bool + True for an ordered map, false for an unordered map. + function: callable, function + The function to apply to each element of the given Iterables. + arguments: sequence of Iterable, dict + One or more Iterables containing the data to be mapped. + num_process: int, float + Number of threads used for parallel running. If `int`, it is + the number of threads to be used; if `float`, it is the fraction + of total threads to be used for running. + num_task: int + The total number of tasks in this parallel running. + tqdm_kwargs: Any + The setting for the progress bar. + + Returns:: + + results: Iterable + A generator which will apply the function to each element of the given Iterables + in parallel in order with a progress bar. + """ + if sys.platform == 'win32' and sys.version_info.minor >= 11: + raise NotImplementedError('Multiprocessing is not available in Python >=3.11 on Windows. ' + 'Please use Linux or MacOS, or Windows with Python <= 3.10.') + + if ProcessPool is None or cpu_count is None: + raise PackageMissingError( + ''' + Please install "pathos" package first. 
+ + >>> pip install pathos + ''' + ) + + # Determine num_process + if num_process is None: + num_process = cpu_count() + elif isinstance(num_process, int): + pass + elif isinstance(num_process, float): + num_process = int(round(num_process * cpu_count())) + else: + raise ValueError('"num_process" must be an int or a float.') + + # arguments + if isinstance(arguments, dict): + keys = list(arguments.keys()) + arguments = list(arguments.values()) + run_f = lambda *args: function(**{key: arg for key, arg in zip(keys, args)}) + else: + if not isinstance(arguments, (tuple, list)): + raise TypeError('"arguments" must be a sequence of Iterable or a dict of Iterable. ' + f'But we got {type(arguments)}') + run_f = function + + # Determine length of tqdm + lengths = [len(iterable) for iterable in arguments if isinstance(iterable, Sized)] + num_task = num_task or (min(lengths) if lengths else None) + + # Create parallel generator + pool = ProcessPool(nodes=num_process) + if ordered: + map_func = pool.imap + else: + map_func = pool.uimap + + # Choose tqdm variant + for item in tqdm(map_func(run_f, *arguments), total=num_task, **tqdm_kwargs): + yield item + + pool.clear() def cpu_ordered_parallel( @@ -125,55 +126,55 @@ def cpu_ordered_parallel( num_task: Optional[int] = None, **tqdm_kwargs: Any ) -> List[Any]: - """Performs a parallel ordered map with a progress bar. - - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import numpy as np - >>> - >>> def simulate(inp): - >>> inp = bm.as_jax(inp) - >>> hh = bp.dyn.HH(1) - >>> runner = bp.DSRunner(hh, inputs=['input', inp], - >>> monitors=['V', 'spike'], - >>> progress_bar=False) - >>> runner.run(100) - >>> bm.clear_buffer_memory() # clear all cached data and functions - >>> return runner.mon.spike.sum() - >>> - >>> if __name__ == '__main__': # This is important! 
- >>> results = bp.running.cpu_unordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10) - >>> print(results) - - Parameters:: - - func: callable, function - The function to apply to each element of the given Iterables. - arguments: sequence of Iterable, dict - One or more Iterables containing the data to be mapped. - num_process: int, float - Number of threads used for parallel running. If `int`, it is - the number of threads to be used; if `float`, it is the fraction - of total threads to be used for running. - num_task: int - The total number of tasks in this parallel running. - tqdm_kwargs: Any - The setting for the progress bar. - - Returns:: - - results: list - A list which will apply the function to each element of the given tasks. - """ - generator = _parallel(True, - func, - arguments, - num_process=num_process, - num_task=num_task, - **tqdm_kwargs) - return list(generator) + """Performs a parallel ordered map with a progress bar. + + Examples:: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import numpy as np + >>> + >>> def simulate(inp): + >>> inp = bm.as_jax(inp) + >>> hh = bp.dyn.HH(1) + >>> runner = bp.DSRunner(hh, inputs=['input', inp], + >>> monitors=['V', 'spike'], + >>> progress_bar=False) + >>> runner.run(100) + >>> bm.clear_buffer_memory() # clear all cached data and functions + >>> return runner.mon.spike.sum() + >>> + >>> if __name__ == '__main__': # This is important! + >>> results = bp.running.cpu_unordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10) + >>> print(results) + + Parameters:: + + func: callable, function + The function to apply to each element of the given Iterables. + arguments: sequence of Iterable, dict + One or more Iterables containing the data to be mapped. + num_process: int, float + Number of threads used for parallel running. If `int`, it is + the number of threads to be used; if `float`, it is the fraction + of total threads to be used for running. 
+ num_task: int + The total number of tasks in this parallel running. + tqdm_kwargs: Any + The setting for the progress bar. + + Returns:: + + results: list + A list which will apply the function to each element of the given tasks. + """ + generator = _parallel(True, + func, + arguments, + num_process=num_process, + num_task=num_task, + **tqdm_kwargs) + return list(generator) def cpu_unordered_parallel( @@ -183,52 +184,52 @@ def cpu_unordered_parallel( num_task: Optional[int] = None, **tqdm_kwargs: Any ) -> List[Any]: - """Performs a parallel unordered map with a progress bar. - - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import numpy as np - >>> - >>> def simulate(inp): - >>> inp = bm.as_jax(inp) - >>> hh = bp.dyn.HH(1) - >>> runner = bp.DSRunner(hh, inputs=['input', inp], - >>> monitors=['V', 'spike'], - >>> progress_bar=False) - >>> runner.run(100) - >>> bm.clear_buffer_memory() # clear all cached data and functions - >>> return runner.mon.spike.sum() - >>> - >>> if __name__ == '__main__': # This is important! - >>> results = bp.running.cpu_unordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10) - >>> print(results) - - Parameters:: - - func: callable, function - The function to apply to each element of the given Iterables. - arguments: sequence of Iterable, dict - One or more Iterables containing the data to be mapped. - num_process: int, float - Number of threads used for parallel running. If `int`, it is - the number of threads to be used; if `float`, it is the fraction - of total threads to be used for running. - num_task: int - The total number of tasks in this parallel running. - tqdm_kwargs: Any - The setting for the progress bar. - - Returns:: - - results: list - A list which will apply the function to each element of the given tasks. 
- """ - generator = _parallel(False, - func, - arguments, - num_process=num_process, - num_task=num_task, - **tqdm_kwargs) - return list(generator) + """Performs a parallel unordered map with a progress bar. + + Examples:: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import numpy as np + >>> + >>> def simulate(inp): + >>> inp = bm.as_jax(inp) + >>> hh = bp.dyn.HH(1) + >>> runner = bp.DSRunner(hh, inputs=['input', inp], + >>> monitors=['V', 'spike'], + >>> progress_bar=False) + >>> runner.run(100) + >>> bm.clear_buffer_memory() # clear all cached data and functions + >>> return runner.mon.spike.sum() + >>> + >>> if __name__ == '__main__': # This is important! + >>> results = bp.running.cpu_unordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10) + >>> print(results) + + Parameters:: + + func: callable, function + The function to apply to each element of the given Iterables. + arguments: sequence of Iterable, dict + One or more Iterables containing the data to be mapped. + num_process: int, float + Number of threads used for parallel running. If `int`, it is + the number of threads to be used; if `float`, it is the fraction + of total threads to be used for running. + num_task: int + The total number of tasks in this parallel running. + tqdm_kwargs: Any + The setting for the progress bar. + + Returns:: + + results: list + A list which will apply the function to each element of the given tasks. + """ + generator = _parallel(False, + func, + arguments, + num_process=num_process, + num_task=num_task, + **tqdm_kwargs) + return list(generator) diff --git a/brainpy/_src/running/runner.py b/brainpy/_src/running/runner.py index db05e7863..2f0cc80ac 100644 --- a/brainpy/_src/running/runner.py +++ b/brainpy/_src/running/runner.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -import gc import types import warnings from typing import Callable, Dict, Sequence, Union @@ -14,252 +13,252 @@ from . 
import constants as C __all__ = [ - 'Runner', + 'Runner', ] class Runner(BrainPyObject): - """Base Runner. - - Parameters:: - - target: Any - The target model. - - monitors: None, sequence of str, dict, Monitor - Variables to monitor. - - - A list of string. Like ``monitors=['a', 'b', 'c']`` - - A list of string with index specification. Like ``monitors=[('a', 1), ('b', [1,3,5]), 'c']`` - - A dict with the explicit monitor target, like: ``monitors={'a': model.spike, 'b': model.V}`` - - A dict with the index specification, like: ``monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}`` - - A dict with the callable function, like ``monitors={'a': lambda tdi: model.spike[:5]}`` - - .. versionchanged:: 2.3.1 - ``func_monitors`` are merged into ``monitors``. - - fun_monitors: dict - Monitoring variables by a dict of callable functions. - The `key` should be a string for later retrieval by `runner.mon[key]`. - The `value` should be a callable function which receives two arguments: `t` and `dt`. - - .. deprecated:: 2.3.1 - Use ``monitors`` instead. - jit: bool, dict - The JIT settings. - - progress_bar: bool - Use progress bar to report the running progress or not? - - dyn_vars: Optional, Variable, sequence of Variable, dict - The dynamically changed variables. Instance of :py:class:`~.Variable`. - - numpy_mon_after_run : bool - When finishing the network running, transform the JAX arrays into numpy ndarray or not? 
- """ - - mon: DotDict - '''Monitor data.''' - - jit: Dict[str, bool] - '''Flag to denote whether to use JIT.''' - - def __init__( - self, - target: BrainPyObject, - monitors: Union[Sequence, Dict] = None, - fun_monitors: Dict[str, Callable] = None, - jit: Union[bool, Dict[str, bool]] = True, - progress_bar: bool = True, - dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, - numpy_mon_after_run: bool = True - ): - super().__init__() - # target model, while implement __call__ function - self.target = target - - # jit instruction - self._origin_jit = jit - self.jit = dict() - if isinstance(jit, bool): - self.jit = {C.PREDICT_PHASE: jit} - elif isinstance(jit, dict): - for k, v in jit.items(): - self.jit[k] = v - self.jit[C.PREDICT_PHASE] = jit.pop(C.PREDICT_PHASE, True) - else: - raise ValueError(f'Unknown "jit" setting: {jit}') - - # monitor construction - if monitors is None: - monitors = dict() - elif isinstance(monitors, (list, tuple)): - # format string monitors - monitors = self._format_seq_monitors(monitors) - # get monitor targets - monitors = self._find_seq_monitor_targets(monitors) - elif isinstance(monitors, dict): - # format string monitors - monitors = self._format_dict_monitors(monitors) - # get monitor targets - monitors = self._find_dict_monitor_targets(monitors) - else: - raise MonitorError(f'We only supports a format of list/tuple/dict of ' - f'"vars", while we got {type(monitors)}.') - self._monitors: dict = monitors - - # deprecated func_monitors - if fun_monitors is not None: - if isinstance(fun_monitors, dict): - warnings.warn("`fun_monitors` is deprecated since version 2.3.1. " - "Define `func_monitors` in `monitors`") - check.is_dict_data(fun_monitors, key_type=str, val_type=types.FunctionType) - self._monitors.update(fun_monitors) - - # monitor for user access - self.mon = DotDict() - - # progress bar - assert isinstance(progress_bar, bool), 'Must be a boolean variable.' 
- self.progress_bar = progress_bar - self._pbar = None - - # dynamical changed variables - self._dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') - - # numpy mon after run - self.numpy_mon_after_run = numpy_mon_after_run - - def _format_seq_monitors(self, monitors): - if not isinstance(monitors, (tuple, list)): - raise TypeError(f'Must be a tuple/list, but we got {type(monitors)}') - _monitors = [] - for mon in monitors: - if isinstance(mon, str): - _monitors.append((mon, None)) - elif isinstance(mon, (tuple, list)): - if isinstance(mon[0], str): - if len(mon) == 1: - _monitors.append((mon[0], None)) - elif len(mon) == 2: - if isinstance(mon[1], (int, np.integer)): - idx = bm.array([mon[1]]) - else: - idx = None if mon[1] is None else bm.asarray(mon[1]) - _monitors.append((mon[0], idx)) - else: - raise MonitorError(f'We expect the monitor format with (name, index). But we got {mon}') - else: - raise MonitorError(f'We expect the monitor format with (name, index). But we got {mon}') - else: - raise MonitorError(f'We do not support monitor with {type(mon)}: {mon}') - return _monitors - - def _format_dict_monitors(self, monitors): - if not isinstance(monitors, dict): - raise TypeError(f'Must be a dict, but we got {type(monitors)}') - _monitors = dict() - for key, val in monitors.items(): - if not isinstance(key, str): - raise MonitorError('Expect the key of the dict "monitors" must be a string. But got ' - f'{type(key)}: {key}') - if isinstance(val, (bm.Variable, str)): - val = (val, None) - - if isinstance(val, (tuple, list)): - if not isinstance(val[0], (bm.Variable, str)): - raise MonitorError('Expect the format of (variable, index) in the monitor setting. 
' - f'But we got {val}') - if len(val) == 1: - _monitors[key] = (val[0], None) - elif len(val) == 2: - if isinstance(val[1], (int, np.integer)): - idx = bm.array([val[1]]) - else: - idx = None if val[1] is None else bm.asarray(val[1]) - _monitors[key] = (val[0], idx) - else: - raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' - f'But we got {val}') - elif callable(val): - _monitors[key] = val - else: - raise MonitorError('The value of dict monitor expect a sequence with (variable, index) ' - f'or a callable function. But we got {val}') - return _monitors - - def _find_seq_monitor_targets(self, _monitors): - if not isinstance(_monitors, (tuple, list)): - raise TypeError(f'Must be a sequence, but we got {type(_monitors)}') - # get monitor targets - monitors = {} - name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())} - for mon in _monitors: - key, index = mon[0], mon[1] - splits = key.split('.') - if len(splits) == 1: - if not hasattr(self.target, splits[0]): - raise RunningError(f'{self.target} does not has variable {key}.') - monitors[key] = (getattr(self.target, splits[-1]), index) - else: - if not hasattr(self.target, splits[0]): - if splits[0] not in name2node: - raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please check.') - else: - master = name2node[splits[0]] - assert len(splits) == 2 - monitors[key] = (getattr(master, splits[-1]), index) + """Base Runner. + + Parameters:: + + target: Any + The target model. + + monitors: None, sequence of str, dict, Monitor + Variables to monitor. + + - A list of string. Like ``monitors=['a', 'b', 'c']`` + - A list of string with index specification. 
Like ``monitors=[('a', 1), ('b', [1,3,5]), 'c']`` + - A dict with the explicit monitor target, like: ``monitors={'a': model.spike, 'b': model.V}`` + - A dict with the index specification, like: ``monitors={'a': (model.spike, 0), 'b': (model.V, [1,2])}`` + - A dict with the callable function, like ``monitors={'a': lambda tdi: model.spike[:5]}`` + + .. versionchanged:: 2.3.1 + ``func_monitors`` are merged into ``monitors``. + + fun_monitors: dict + Monitoring variables by a dict of callable functions. + The `key` should be a string for later retrieval by `runner.mon[key]`. + The `value` should be a callable function which receives two arguments: `t` and `dt`. + + .. deprecated:: 2.3.1 + Use ``monitors`` instead. + jit: bool, dict + The JIT settings. + + progress_bar: bool + Use progress bar to report the running progress or not? + + dyn_vars: Optional, Variable, sequence of Variable, dict + The dynamically changed variables. Instance of :py:class:`~.Variable`. + + numpy_mon_after_run : bool + When finishing the network running, transform the JAX arrays into numpy ndarray or not? 
+ """ + + mon: DotDict + '''Monitor data.''' + + jit: Dict[str, bool] + '''Flag to denote whether to use JIT.''' + + def __init__( + self, + target: BrainPyObject, + monitors: Union[Sequence, Dict] = None, + fun_monitors: Dict[str, Callable] = None, + jit: Union[bool, Dict[str, bool]] = True, + progress_bar: bool = True, + dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, + numpy_mon_after_run: bool = True + ): + super().__init__() + # target model, while implement __call__ function + self.target = target + + # jit instruction + self._origin_jit = jit + self.jit = dict() + if isinstance(jit, bool): + self.jit = {C.PREDICT_PHASE: jit} + elif isinstance(jit, dict): + for k, v in jit.items(): + self.jit[k] = v + self.jit[C.PREDICT_PHASE] = jit.pop(C.PREDICT_PHASE, True) else: - master = self.target - for s in splits[:-1]: - try: - master = getattr(master, s) - except KeyError: - raise MonitorError(f'Cannot find {key} in {master}, please check.') - monitors[key] = (getattr(master, splits[-1]), index) - return monitors - - def _find_dict_monitor_targets(self, _monitors): - if not isinstance(_monitors, dict): - raise TypeError(f'Must be a dict, but we got {type(_monitors)}') - # get monitor targets - monitors = {} - name2node = None - for _key, _mon in _monitors.items(): - if isinstance(_mon, str): - if name2node is None: - name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())} - - key, index = _mon[0], _mon[1] - splits = key.split('.') - if len(splits) == 1: - if not hasattr(self.target, splits[0]): - raise RunningError(f'{self.target} does not has variable {key}.') - monitors[key] = (getattr(self.target, splits[-1]), index) + raise ValueError(f'Unknown "jit" setting: {jit}') + + # monitor construction + if monitors is None: + monitors = dict() + elif isinstance(monitors, (list, tuple)): + # format string monitors + monitors = self._format_seq_monitors(monitors) + # get monitor targets + monitors = 
self._find_seq_monitor_targets(monitors) + elif isinstance(monitors, dict): + # format string monitors + monitors = self._format_dict_monitors(monitors) + # get monitor targets + monitors = self._find_dict_monitor_targets(monitors) else: - if not hasattr(self.target, splits[0]): - if splits[0] not in name2node: - raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please check.') + raise MonitorError(f'We only supports a format of list/tuple/dict of ' + f'"vars", while we got {type(monitors)}.') + self._monitors: dict = monitors + + # deprecated func_monitors + if fun_monitors is not None: + if isinstance(fun_monitors, dict): + warnings.warn("`fun_monitors` is deprecated since version 2.3.1. " + "Define `func_monitors` in `monitors`") + check.is_dict_data(fun_monitors, key_type=str, val_type=types.FunctionType) + self._monitors.update(fun_monitors) + + # monitor for user access + self.mon = DotDict() + + # progress bar + assert isinstance(progress_bar, bool), 'Must be a boolean variable.' + self.progress_bar = progress_bar + self._pbar = None + + # dynamical changed variables + self._dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') + + # numpy mon after run + self.numpy_mon_after_run = numpy_mon_after_run + + def _format_seq_monitors(self, monitors): + if not isinstance(monitors, (tuple, list)): + raise TypeError(f'Must be a tuple/list, but we got {type(monitors)}') + _monitors = [] + for mon in monitors: + if isinstance(mon, str): + _monitors.append((mon, None)) + elif isinstance(mon, (tuple, list)): + if isinstance(mon[0], str): + if len(mon) == 1: + _monitors.append((mon[0], None)) + elif len(mon) == 2: + if isinstance(mon[1], (int, np.integer)): + idx = bm.array([mon[1]]) + else: + idx = None if mon[1] is None else bm.asarray(mon[1]) + _monitors.append((mon[0], idx)) + else: + raise MonitorError(f'We expect the monitor format with (name, index). 
But we got {mon}') + else: + raise MonitorError(f'We expect the monitor format with (name, index). But we got {mon}') + else: + raise MonitorError(f'We do not support monitor with {type(mon)}: {mon}') + return _monitors + + def _format_dict_monitors(self, monitors): + if not isinstance(monitors, dict): + raise TypeError(f'Must be a dict, but we got {type(monitors)}') + _monitors = dict() + for key, val in monitors.items(): + if not isinstance(key, str): + raise MonitorError('Expect the key of the dict "monitors" must be a string. But got ' + f'{type(key)}: {key}') + if isinstance(val, (bm.Variable, str)): + val = (val, None) + + if isinstance(val, (tuple, list)): + if not isinstance(val[0], (bm.Variable, str)): + raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' + f'But we got {val}') + if len(val) == 1: + _monitors[key] = (val[0], None) + elif len(val) == 2: + if isinstance(val[1], (int, np.integer)): + idx = bm.array([val[1]]) + else: + idx = None if val[1] is None else bm.asarray(val[1]) + _monitors[key] = (val[0], idx) + else: + raise MonitorError('Expect the format of (variable, index) in the monitor setting. ' + f'But we got {val}') + elif callable(val): + _monitors[key] = val + else: + raise MonitorError('The value of dict monitor expect a sequence with (variable, index) ' + f'or a callable function. 
But we got {val}') + return _monitors + + def _find_seq_monitor_targets(self, _monitors): + if not isinstance(_monitors, (tuple, list)): + raise TypeError(f'Must be a sequence, but we got {type(_monitors)}') + # get monitor targets + monitors = {} + name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())} + for mon in _monitors: + key, index = mon[0], mon[1] + splits = key.split('.') + if len(splits) == 1: + if not hasattr(self.target, splits[0]): + raise RunningError(f'{self.target} does not has variable {key}.') + monitors[key] = (getattr(self.target, splits[-1]), index) + else: + if not hasattr(self.target, splits[0]): + if splits[0] not in name2node: + raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please check.') + else: + master = name2node[splits[0]] + assert len(splits) == 2 + monitors[key] = (getattr(master, splits[-1]), index) + else: + master = self.target + for s in splits[:-1]: + try: + master = getattr(master, s) + except KeyError: + raise MonitorError(f'Cannot find {key} in {master}, please check.') + monitors[key] = (getattr(master, splits[-1]), index) + return monitors + + def _find_dict_monitor_targets(self, _monitors): + if not isinstance(_monitors, dict): + raise TypeError(f'Must be a dict, but we got {type(_monitors)}') + # get monitor targets + monitors = {} + name2node = None + for _key, _mon in _monitors.items(): + if isinstance(_mon, str): + if name2node is None: + name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())} + + key, index = _mon[0], _mon[1] + splits = key.split('.') + if len(splits) == 1: + if not hasattr(self.target, splits[0]): + raise RunningError(f'{self.target} does not has variable {key}.') + monitors[key] = (getattr(self.target, splits[-1]), index) + else: + if not hasattr(self.target, splits[0]): + if splits[0] not in name2node: + raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please 
check.') + else: + master = name2node[splits[0]] + assert len(splits) == 2 + monitors[key] = (getattr(master, splits[-1]), index) + else: + master = self.target + for s in splits[:-1]: + try: + master = getattr(master, s) + except KeyError: + raise MonitorError(f'Cannot find {key} in {master}, please check.') + monitors[key] = (getattr(master, splits[-1]), index) else: - master = name2node[splits[0]] - assert len(splits) == 2 - monitors[key] = (getattr(master, splits[-1]), index) - else: - master = self.target - for s in splits[:-1]: - try: - master = getattr(master, s) - except KeyError: - raise MonitorError(f'Cannot find {key} in {master}, please check.') - monitors[key] = (getattr(master, splits[-1]), index) - else: - monitors[_key] = _mon - return monitors - - def __del__(self): - if hasattr(self, 'mon'): - for key in tuple(self.mon.keys()): - del self.mon[key] - for key in tuple(self.__dict__.keys()): - del self.__dict__[key] - # gc.collect() + monitors[_key] = _mon + return monitors + + def __del__(self): + if hasattr(self, 'mon'): + for key in tuple(self.mon.keys()): + del self.mon[key] + for key in tuple(self.__dict__.keys()): + del self.__dict__[key] + # gc.collect() diff --git a/brainpy/_src/running/tests/test_pathos_multiprocessing.py b/brainpy/_src/running/tests/test_pathos_multiprocessing.py index 6f92bda7e..829c4658f 100644 --- a/brainpy/_src/running/tests/test_pathos_multiprocessing.py +++ b/brainpy/_src/running/tests/test_pathos_multiprocessing.py @@ -8,34 +8,34 @@ import brainpy.math as bm if sys.platform == 'win32' and sys.version_info.minor >= 11: - pytest.skip('python 3.11 does not support.', allow_module_level=True) + pytest.skip('python 3.11 does not support.', allow_module_level=True) else: - pytest.skip('Cannot pass tests.', allow_module_level=True) + pytest.skip('Cannot pass tests.', allow_module_level=True) class TestParallel(parameterized.TestCase): - @parameterized.product( - duration=[1e2, 1e3, 1e4, 1e5] - ) - def 
test_cpu_unordered_parallel_v1(self, duration): - @jax.jit - def body(inp): - return bm.for_loop(lambda x: x + 1e-9, inp) - - input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 - - r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) - assert bm.allclose(r[0], r[1]) - - @parameterized.product( - duration=[1e2, 1e3, 1e4, 1e5] - ) - def test_cpu_unordered_parallel_v2(self, duration): - @jax.jit - def body(inp): - return bm.for_loop(lambda x: x + 1e-9, inp) - - input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 - - r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) - assert bm.allclose(r[0], r[1]) + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v1(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) + + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v2(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) diff --git a/brainpy/_src/tests/test_access_methods.py b/brainpy/_src/tests/test_access_methods.py index 6d2109cbd..71eeaf83f 100644 --- a/brainpy/_src/tests/test_access_methods.py +++ b/brainpy/_src/tests/test_access_methods.py @@ -1,123 +1,124 @@ # -*- coding: utf-8 -*- import jax.numpy as jnp + import brainpy as bp bp.ode.set_default_odeint('rk4') class GABAa(bp.synapses.TwoEndConn): - def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., - alpha=12., beta=0.1, T=1.0, T_duration=1.0, 
**kwargs): - super(GABAa, self).__init__(pre=pre, post=post, conn=conn, **kwargs) + def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., + alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs): + super(GABAa, self).__init__(pre=pre, post=post, conn=conn, **kwargs) - # parameters - self.g_max = g_max - self.E = E - self.alpha = alpha - self.beta = beta - self.T = T - self.T_duration = T_duration - self.delay = delay + # parameters + self.g_max = g_max + self.E = E + self.alpha = alpha + self.beta = beta + self.T = T + self.T_duration = T_duration + self.delay = delay - # connections - self.conn_mat = self.conn.requires('conn_mat') - self.size = jnp.shape(self.conn_mat) + # connections + self.conn_mat = self.conn.requires('conn_mat') + self.size = jnp.shape(self.conn_mat) - # variables - self.t_last_pre_spike = jnp.ones(self.size) * -1e7 - self.s = jnp.zeros(self.size) + # variables + self.t_last_pre_spike = jnp.ones(self.size) * -1e7 + self.s = jnp.zeros(self.size) - self.int_s = bp.odeint(self.dev) + self.int_s = bp.odeint(self.dev) - def dev(self, s, t, TT, alpha, beta): - return alpha * TT * (1 - s) - beta * s + def dev(self, s, t, TT, alpha, beta): + return alpha * TT * (1 - s) - beta * s - def update(self, t, dt, **kwargs): - spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat - self.t_last_pre_spike = jnp.where(spike, t, self.t_last_pre_spike) - TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T - self.s = self.int_s(self.s, t, TT, self.alpha, self.beta) - self.post.inputs -= jnp.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E) + def update(self, t, dt, **kwargs): + spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat + self.t_last_pre_spike = jnp.where(spike, t, self.t_last_pre_spike) + TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T + self.s = self.int_s(self.s, t, TT, self.alpha, self.beta) + self.post.inputs -= jnp.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E) 
class HH(bp.dyn.NeuGroup): - def __init__(self, size, ENa=55., EK=-90., EL=-65, - C=1.0, gNa=35., gK=9., gL=0.1, V_th=20., - phi=5.0, **kwargs): - super(HH, self).__init__(size=size, **kwargs) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - self.phi = phi - - # variables - self.V = jnp.ones(self.num) * -65. - self.h = jnp.ones(self.num) * 0.6 - self.n = jnp.ones(self.num) * 0.32 - self.spikes = jnp.zeros(self.num) - self.inputs = jnp.zeros(self.num) - - self.integral = bp.odeint(self.dev) - - def dev(self, V, h, n, t, Iext): - alpha = 0.07 * jnp.exp(-(V + 58) / 20) - beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) - dhdt = alpha * (1 - h) - beta * h - - alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) - beta = 0.125 * jnp.exp(-(V + 44) / 80) - dndt = alpha * (1 - n) - beta * n - - m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) - m_beta = 4 * jnp.exp(-(V + 60) / 18) - m = m_alpha / (m_alpha + m_beta) - INa = self.gNa * m ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + Iext) / self.C - - return dVdt, self.phi * dhdt, self.phi * dndt - - def update(self, t, _i, **kwargs): - V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) - self.spikes[:] = (self.V < self.V_th) * (V >= self.V_th) - self.V[:] = V - self.h[:] = h - self.n[:] = n - self.inputs[:] = 0 + def __init__(self, size, ENa=55., EK=-90., EL=-65, + C=1.0, gNa=35., gK=9., gL=0.1, V_th=20., + phi=5.0, **kwargs): + super(HH, self).__init__(size=size, **kwargs) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + self.phi = phi + + # variables + self.V = jnp.ones(self.num) * -65. 
+ self.h = jnp.ones(self.num) * 0.6 + self.n = jnp.ones(self.num) * 0.32 + self.spikes = jnp.zeros(self.num) + self.inputs = jnp.zeros(self.num) + + self.integral = bp.odeint(self.dev) + + def dev(self, V, h, n, t, Iext): + alpha = 0.07 * jnp.exp(-(V + 58) / 20) + beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) + dhdt = alpha * (1 - h) - beta * h + + alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * jnp.exp(-(V + 44) / 80) + dndt = alpha * (1 - n) - beta * n + + m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * jnp.exp(-(V + 60) / 18) + m = m_alpha / (m_alpha + m_beta) + INa = self.gNa * m ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + Iext) / self.C + + return dVdt, self.phi * dhdt, self.phi * dndt + + def update(self, t, _i, **kwargs): + V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) + self.spikes[:] = (self.V < self.V_th) * (V >= self.V_th) + self.V[:] = V + self.h[:] = h + self.n[:] = n + self.inputs[:] = 0 def test1(): - bp.math.random.seed(123) - num = 10 - neu = HH(num) - neu.V = -70. + bp.math.random.normal(size=num) * 20 - - syn = GABAa(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) - syn.g_max = 0.1 / num - - net = bp.Network(neu=neu, syn=syn) - - for method in ['relative', 'absolute']: - print(f'Method: {method}\n') - print('vars:') - print('-----') - print('neu.vars()', list(neu.vars(method).keys())) - print('syn.vars()', list(syn.vars(method).keys())) - print('net.vars()', list(net.vars(method).keys())) - print() - - print('nodes:') - print('------') - print('neu.nodes()', list(neu.nodes(method).keys())) - print('syn.nodes()', list(syn.nodes(method).keys())) - print('net.nodes()', list(net.nodes(method).keys())) - print() + bp.math.random.seed(123) + num = 10 + neu = HH(num) + neu.V = -70. 
+ bp.math.random.normal(size=num) * 20 + + syn = GABAa(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) + syn.g_max = 0.1 / num + + net = bp.Network(neu=neu, syn=syn) + + for method in ['relative', 'absolute']: + print(f'Method: {method}\n') + print('vars:') + print('-----') + print('neu.vars()', list(neu.vars(method).keys())) + print('syn.vars()', list(syn.vars(method).keys())) + print('net.vars()', list(net.vars(method).keys())) + print() + + print('nodes:') + print('------') + print('neu.nodes()', list(neu.nodes(method).keys())) + print('syn.nodes()', list(syn.nodes(method).keys())) + print('net.nodes()', list(net.nodes(method).keys())) + print() diff --git a/brainpy/_src/tests/test_base_classes.py b/brainpy/_src/tests/test_base_classes.py index 83602cc18..6ac301fb7 100644 --- a/brainpy/_src/tests/test_base_classes.py +++ b/brainpy/_src/tests/test_base_classes.py @@ -3,68 +3,52 @@ import unittest import brainpy as bp -import brainpy.math as bm class TestDynamicalSystem(unittest.TestCase): - def test_delay(self): - A = bp.neurons.LIF(1) - B = bp.neurons.LIF(1) - C = bp.neurons.LIF(1) - A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1) - A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None) - net = bp.Network(A, B, C, A2B, A2C) - - runner = bp.DSRunner(net,) - runner.run(10.) - - - - def test_receive_update_output(self): - def aft_update(inp): - assert inp is not None - - hh = bp.dyn.HH(1) - hh.add_aft_update('aft_update', aft_update) - bp.share.save(i=0, t=0.) - hh(1.) - - - - def test_do_not_receive_update_output(self): - def aft_update(): - pass - - hh = bp.dyn.HH(1) - hh.add_aft_update('aft_update', bp.not_receive_update_output(aft_update)) - bp.share.save(i=0, t=0.) - hh(1.) - - - - def test_not_receive_update_input(self): - def bef_update(): - pass - - hh = bp.dyn.HH(1) - hh.add_bef_update('bef_update', bef_update) - bp.share.save(i=0, t=0.) - hh(1.) 
- - - - def test_receive_update_input(self): - def bef_update(inp): - assert inp is not None - - hh = bp.dyn.HH(1) - hh.add_bef_update('bef_update', bp.receive_update_input(bef_update)) - bp.share.save(i=0, t=0.) - hh(1.) - - - - - - - + def test_delay(self): + A = bp.neurons.LIF(1) + B = bp.neurons.LIF(1) + C = bp.neurons.LIF(1) + A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1) + A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None) + net = bp.Network(A, B, C, A2B, A2C) + + runner = bp.DSRunner(net, ) + runner.run(10.) + + def test_receive_update_output(self): + def aft_update(inp): + assert inp is not None + + hh = bp.dyn.HH(1) + hh.add_aft_update('aft_update', aft_update) + bp.share.save(i=0, t=0.) + hh(1.) + + def test_do_not_receive_update_output(self): + def aft_update(): + pass + + hh = bp.dyn.HH(1) + hh.add_aft_update('aft_update', bp.not_receive_update_output(aft_update)) + bp.share.save(i=0, t=0.) + hh(1.) + + def test_not_receive_update_input(self): + def bef_update(): + pass + + hh = bp.dyn.HH(1) + hh.add_bef_update('bef_update', bef_update) + bp.share.save(i=0, t=0.) + hh(1.) + + def test_receive_update_input(self): + def bef_update(inp): + assert inp is not None + + hh = bp.dyn.HH(1) + hh.add_bef_update('bef_update', bp.receive_update_input(bef_update)) + bp.share.save(i=0, t=0.) + hh(1.) 
diff --git a/brainpy/_src/tests/test_brainpy_deprecations.py b/brainpy/_src/tests/test_brainpy_deprecations.py index 9c38d485e..8e3c838fc 100644 --- a/brainpy/_src/tests/test_brainpy_deprecations.py +++ b/brainpy/_src/tests/test_brainpy_deprecations.py @@ -1,4 +1,5 @@ from absl.testing import parameterized + import brainpy bp_deprecated_names = list(brainpy.__deprecations.keys()) @@ -9,44 +10,44 @@ class Test(parameterized.TestCase): - @parameterized.product( - name=bp_deprecated_names - ) - def test_brainpy(self, name): - with self.assertWarns(DeprecationWarning): - getattr(brainpy, name) - - @parameterized.product( - name=mode_deprecated_names - ) - def test_brainpy_modes(self, name): - with self.assertWarns(DeprecationWarning): - getattr(brainpy.modes, name) - - @parameterized.product( - name=tools_deprecated_names - ) - def test_brainpy_tools(self, name): - with self.assertWarns(DeprecationWarning): - getattr(brainpy.tools, name) - - @parameterized.product( - name=train_deprecated_names - ) - def test_brainpy_train(self, name): - with self.assertWarns(DeprecationWarning): - getattr(brainpy.train, name) - - # @parameterized.product( - # name=dyn_deprecated_names - # ) - # def test_brainpy_dyn(self, name): - # with self.assertWarns(DeprecationWarning): - # getattr(brainpy.dyn, name) - # - @parameterized.product( - name=intg_deprecated_names - ) - def test_brainpy_intg(self, name): - with self.assertWarns(DeprecationWarning): - getattr(brainpy.integrators, name) + @parameterized.product( + name=bp_deprecated_names + ) + def test_brainpy(self, name): + with self.assertWarns(DeprecationWarning): + getattr(brainpy, name) + + @parameterized.product( + name=mode_deprecated_names + ) + def test_brainpy_modes(self, name): + with self.assertWarns(DeprecationWarning): + getattr(brainpy.modes, name) + + @parameterized.product( + name=tools_deprecated_names + ) + def test_brainpy_tools(self, name): + with self.assertWarns(DeprecationWarning): + getattr(brainpy.tools, name) + 
+ @parameterized.product( + name=train_deprecated_names + ) + def test_brainpy_train(self, name): + with self.assertWarns(DeprecationWarning): + getattr(brainpy.train, name) + + # @parameterized.product( + # name=dyn_deprecated_names + # ) + # def test_brainpy_dyn(self, name): + # with self.assertWarns(DeprecationWarning): + # getattr(brainpy.dyn, name) + # + @parameterized.product( + name=intg_deprecated_names + ) + def test_brainpy_intg(self, name): + with self.assertWarns(DeprecationWarning): + getattr(brainpy.integrators, name) diff --git a/brainpy/_src/tests/test_check.py b/brainpy/_src/tests/test_check.py index a04105486..9d24f2f31 100644 --- a/brainpy/_src/tests/test_check.py +++ b/brainpy/_src/tests/test_check.py @@ -7,45 +7,45 @@ class TestUtils(unittest.TestCase): - def test_check_shape(self): - all_shapes = [ - (1, 2, 3), - (1, 4), - (10, 2, 4) - ] - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=-1) - self.assertEqual(free_shape, [3, 4, 4]) - self.assertEqual(fixed_shapes, [10, 2]) - - def test_check_shape2(self): - all_shapes = [ - (1, 2, 3, 8,), - (10, 1, 4, 10), - (10, 2, 4, 100) - ] - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[2, -1]) - print(free_shape) - print(fixed_shapes) - self.assertEqual(free_shape, [[3, 8], [4, 10], [4, 100]]) - self.assertEqual(fixed_shapes, [10, 2]) - - def test_check_shape3(self): - all_shapes = [ - (1, 2, 3, 8,), - (10, 1, 4, 10), - (10, 2, 4, 100) - ] - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, 2, -1]) - print(free_shape) - print(fixed_shapes) - self.assertEqual(free_shape, [[1, 3, 8], [10, 4, 10], [10, 4, 100]]) - self.assertEqual(fixed_shapes, [2]) - - def test_check_shape4(self): - all_shapes = [ - (1, 2, 3, 8,), - (10, 1, 4, 10), - (10, 2, 4, 100) - ] - with self.assertRaises(ValueError): - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, -1]) + def test_check_shape(self): + all_shapes = [ + (1, 2, 3), + (1, 
4), + (10, 2, 4) + ] + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=-1) + self.assertEqual(free_shape, [3, 4, 4]) + self.assertEqual(fixed_shapes, [10, 2]) + + def test_check_shape2(self): + all_shapes = [ + (1, 2, 3, 8,), + (10, 1, 4, 10), + (10, 2, 4, 100) + ] + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[2, -1]) + print(free_shape) + print(fixed_shapes) + self.assertEqual(free_shape, [[3, 8], [4, 10], [4, 100]]) + self.assertEqual(fixed_shapes, [10, 2]) + + def test_check_shape3(self): + all_shapes = [ + (1, 2, 3, 8,), + (10, 1, 4, 10), + (10, 2, 4, 100) + ] + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, 2, -1]) + print(free_shape) + print(fixed_shapes) + self.assertEqual(free_shape, [[1, 3, 8], [10, 4, 10], [10, 4, 100]]) + self.assertEqual(fixed_shapes, [2]) + + def test_check_shape4(self): + all_shapes = [ + (1, 2, 3, 8,), + (10, 1, 4, 10), + (10, 2, 4, 100) + ] + with self.assertRaises(ValueError): + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, -1]) diff --git a/brainpy/_src/tests/test_delay.py b/brainpy/_src/tests/test_delay.py index b2eb97690..7d23e4d23 100644 --- a/brainpy/_src/tests/test_delay.py +++ b/brainpy/_src/tests/test_delay.py @@ -6,53 +6,52 @@ class TestVarDelay(unittest.TestCase): - def test_delay1(self): - bp.math.random.seed() - a = bp.math.Variable((10, 20)) - delay = bp.VarDelay(a, ) - delay.register_entry('a', 1.) - delay.register_entry('b', 2.) - delay.register_entry('c', None) - with self.assertRaises(KeyError): - delay.register_entry('c', 10.) - - def test_rotation_delay(self): - a = bp.math.Variable((1,)) - rotation_delay = bp.VarDelay(a) - t0 = 0. 
- t1, n1 = 1., 10 - t2, n2 = 2., 20 - - rotation_delay.register_entry('a', t0) - rotation_delay.register_entry('b', t1) - rotation_delay.register_entry('c', t2) - - print() - for i in range(100): - bp.share.save(i=i) - a.value = jnp.ones((1,)) * i - rotation_delay() - self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) - self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) - self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) - - def test_concat_delay(self): - a = bp.math.Variable((1,)) - rotation_delay = bp.VarDelay(a, method='concat') - t0 = 0. - t1, n1 = 1., 10 - t2, n2 = 2., 20 - - rotation_delay.register_entry('a', t0) - rotation_delay.register_entry('b', t1) - rotation_delay.register_entry('c', t2) - - print() - for i in range(100): - bp.share.save(i=i) - a.value = jnp.ones((1,)) * i - rotation_delay() - self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) - self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) - self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) - \ No newline at end of file + def test_delay1(self): + bp.math.random.seed() + a = bp.math.Variable((10, 20)) + delay = bp.VarDelay(a, ) + delay.register_entry('a', 1.) + delay.register_entry('b', 2.) + delay.register_entry('c', None) + with self.assertRaises(KeyError): + delay.register_entry('c', 10.) + + def test_rotation_delay(self): + a = bp.math.Variable((1,)) + rotation_delay = bp.VarDelay(a) + t0 = 0. 
+ t1, n1 = 1., 10 + t2, n2 = 2., 20 + + rotation_delay.register_entry('a', t0) + rotation_delay.register_entry('b', t1) + rotation_delay.register_entry('c', t2) + + print() + for i in range(100): + bp.share.save(i=i) + a.value = jnp.ones((1,)) * i + rotation_delay() + self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) + self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) + + def test_concat_delay(self): + a = bp.math.Variable((1,)) + rotation_delay = bp.VarDelay(a, method='concat') + t0 = 0. + t1, n1 = 1., 10 + t2, n2 = 2., 20 + + rotation_delay.register_entry('a', t0) + rotation_delay.register_entry('b', t1) + rotation_delay.register_entry('c', t2) + + print() + for i in range(100): + bp.share.save(i=i) + a.value = jnp.ones((1,)) * i + rotation_delay() + self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) + self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py index c2f422ee4..97586a0ce 100644 --- a/brainpy/_src/tests/test_dyn_runner.py +++ b/brainpy/_src/tests/test_dyn_runner.py @@ -1,81 +1,80 @@ # -*- coding: utf-8 -*- -import pytest import unittest + import brainpy as bp import brainpy.math as bm class TestDSRunner(unittest.TestCase): - def test1(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 - - ds = ExampleDS() - runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) 
- - def test_t_and_dt(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 * bp.share['dt'] - - runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_DSView(self): - class EINet(bp.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet, self).__init__() - - # network size - num_exc = int(800 * scale) - num_inh = int(200 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - - bm.random.seed() - - net = EINet(scale=1., method='exp_auto') - # with JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) 
- - # without JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) + def test1(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 + + ds = ExampleDS() + runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_t_and_dt(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 * bp.share['dt'] + + runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_DSView(self): + class EINet(bp.Network): + def __init__(self, scale=1.0, method='exp_auto'): + super(EINet, self).__init__() + + # network size + num_exc = int(800 * scale) + num_inh = int(200 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) + self.E = bp.neurons.LIF(num_exc, **pars, method=method) + self.I = bp.neurons.LIF(num_inh, **pars, method=method) + self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. + self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. 
+ + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + + bm.random.seed() + + net = EINet(scale=1., method='exp_auto') + # with JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) + + # without JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) class TestMemoryEfficient(unittest.TestCase): - pass - + pass diff --git a/brainpy/_src/tests/test_dynsys.py b/brainpy/_src/tests/test_dynsys.py index f8605380e..8c087b25e 100644 --- a/brainpy/_src/tests/test_dynsys.py +++ b/brainpy/_src/tests/test_dynsys.py @@ -4,52 +4,52 @@ def test1(): - class A(bp.DynamicalSystem): - def update(self, x=None): - # print(tdi) - print(x) + class A(bp.DynamicalSystem): + def update(self, x=None): + # print(tdi) + print(x) - A()({}, 10.) + A()({}, 10.) def test2(): - class B(bp.DynamicalSystem): - def update(self, tdi, x=None): - print(tdi) - print(x) + class B(bp.DynamicalSystem): + def update(self, tdi, x=None): + print(tdi) + print(x) - B()({}, 10.) - B()(10.) + B()({}, 10.) + B()(10.) 
def test3(): - class A(bp.DynamicalSystem): - def update(self, x=None): - # print(tdi) - print('A:', x) + class A(bp.DynamicalSystem): + def update(self, x=None): + # print(tdi) + print('A:', x) - class B(A): - def update(self, tdi, x=None): - print('B:', tdi, x) - super().update(x) + class B(A): + def update(self, tdi, x=None): + print('B:', tdi, x) + super().update(x) - B()(dict(), 1.) - B()(1.) + B()(dict(), 1.) + B()(1.) class TestResetLevelDecorator(unittest.TestCase): - _max_level = 10 # Define the maximum level for testing purposes + _max_level = 10 # Define the maximum level for testing purposes - @bp.reset_level(5) - def test_function_with_reset_level_5(self): - self.assertEqual(self.test_function_with_reset_level_5.reset_level, 5) + @bp.reset_level(5) + def test_function_with_reset_level_5(self): + self.assertEqual(self.test_function_with_reset_level_5.reset_level, 5) - def test1(self): - with self.assertRaises(ValueError): - @bp.reset_level(12) # This should raise a ValueError - def test_function_with_invalid_reset_level(self): - pass # Call the function here to trigger the ValueError + def test1(self): + with self.assertRaises(ValueError): + @bp.reset_level(12) # This should raise a ValueError + def test_function_with_invalid_reset_level(self): + pass # Call the function here to trigger the ValueError - @bp.reset_level(-3) - def test_function_with_negative_reset_level(self): - self.assertEqual(self.test_function_with_negative_reset_level.reset_level, self._max_level - 3) + @bp.reset_level(-3) + def test_function_with_negative_reset_level(self): + self.assertEqual(self.test_function_with_negative_reset_level.reset_level, self._max_level - 3) diff --git a/brainpy/_src/tests/test_helper.py b/brainpy/_src/tests/test_helper.py index d8c85010b..9fe154356 100644 --- a/brainpy/_src/tests/test_helper.py +++ b/brainpy/_src/tests/test_helper.py @@ -1,30 +1,28 @@ -import brainpy as bp - import unittest +import brainpy as bp -class TestResetLevel(unittest.TestCase): 
- - def test1(self): - class Level0(bp.DynamicalSystem): - @bp.reset_level(0) - def reset_state(self, *args, **kwargs): - print('Level 0') - - class Level1(bp.DynamicalSystem): - @bp.reset_level(1) - def reset_state(self, *args, **kwargs): - print('Level 1') - - class Net(bp.DynamicalSystem): - def __init__(self): - super().__init__() - self.l0 = Level0() - self.l1 = Level1() - self.l0_2 = Level0() - self.l1_2 = Level1() - - net = Net() - net.reset() +class TestResetLevel(unittest.TestCase): + def test1(self): + class Level0(bp.DynamicalSystem): + @bp.reset_level(0) + def reset_state(self, *args, **kwargs): + print('Level 0') + + class Level1(bp.DynamicalSystem): + @bp.reset_level(1) + def reset_state(self, *args, **kwargs): + print('Level 1') + + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__() + self.l0 = Level0() + self.l1 = Level1() + self.l0_2 = Level0() + self.l1_2 = Level1() + + net = Net() + net.reset() diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py index e864fd647..b6023b7a1 100644 --- a/brainpy/_src/tests/test_mixin.py +++ b/brainpy/_src/tests/test_mixin.py @@ -1,53 +1,50 @@ +import unittest + import brainpy as bp import brainpy.math as bm -import unittest - class TestParamDesc(unittest.TestCase): - def test1(self): - a = bp.dyn.Expon(1) - self.assertTrue(not isinstance(a, bp.mixin.ParamDescriber[bp.dyn.Expon])) - self.assertTrue(not isinstance(a, bp.mixin.ParamDescriber[bp.DynamicalSystem])) + def test1(self): + a = bp.dyn.Expon(1) + self.assertTrue(not isinstance(a, bp.mixin.ParamDescriber[bp.dyn.Expon])) + self.assertTrue(not isinstance(a, bp.mixin.ParamDescriber[bp.DynamicalSystem])) - def test2(self): - a = bp.dyn.Expon.desc(1) - self.assertTrue(isinstance(a, bp.mixin.ParamDescriber[bp.dyn.Expon])) - self.assertTrue(isinstance(a, bp.mixin.ParamDescriber[bp.DynamicalSystem])) + def test2(self): + a = bp.dyn.Expon.desc(1) + self.assertTrue(isinstance(a, bp.mixin.ParamDescriber[bp.dyn.Expon])) + 
self.assertTrue(isinstance(a, bp.mixin.ParamDescriber[bp.DynamicalSystem])) class TestJointType(unittest.TestCase): - def test1(self): - T = bp.mixin.JointType[bp.DynamicalSystem] - self.assertTrue(isinstance(bp.dnn.Layer(), T)) + def test1(self): + T = bp.mixin.JointType[bp.DynamicalSystem] + self.assertTrue(isinstance(bp.dnn.Layer(), T)) - T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc] - self.assertTrue(isinstance(bp.dyn.Expon(1), T)) + T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc] + self.assertTrue(isinstance(bp.dyn.Expon(1), T)) - def test2(self): - T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc] - self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDescriber[T])) - self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDescriber[T])) + def test2(self): + T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc] + self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDescriber[T])) + self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDescriber[T])) class TestDelayRegister(unittest.TestCase): - # def test11(self): - # lif = bp.dyn.Lif(10) - # with self.assertWarns(UserWarning): - # lif.register_delay('pre.spike', 10, lif.spike) - # - # with self.assertWarns(UserWarning): - # lif.get_delay_data('pre.spike', 10) - - def test2(self): - bp.share.save(i=0) - lif = bp.dyn.Lif(10) - lif.register_local_delay('spike', 'a', delay_time=10.) - data = lif.get_local_delay('spike', 'a') - self.assertTrue(bm.allclose(data, bm.zeros(10))) - - with self.assertRaises(AttributeError): - lif.register_local_delay('a', 'a', 10.) - - - + # def test11(self): + # lif = bp.dyn.Lif(10) + # with self.assertWarns(UserWarning): + # lif.register_delay('pre.spike', 10, lif.spike) + # + # with self.assertWarns(UserWarning): + # lif.get_delay_data('pre.spike', 10) + + def test2(self): + bp.share.save(i=0) + lif = bp.dyn.Lif(10) + lif.register_local_delay('spike', 'a', delay_time=10.) 
+ data = lif.get_local_delay('spike', 'a') + self.assertTrue(bm.allclose(data, bm.zeros(10))) + + with self.assertRaises(AttributeError): + lif.register_local_delay('a', 'a', 10.) diff --git a/brainpy/_src/tests/test_network.py b/brainpy/_src/tests/test_network.py index 5ef5c2026..553dfdabb 100644 --- a/brainpy/_src/tests/test_network.py +++ b/brainpy/_src/tests/test_network.py @@ -1,51 +1,48 @@ -import brainpy as bp import unittest +import brainpy as bp -class TestNetDefinition(unittest.TestCase): - def test_define_net1(self): - E = bp.neurons.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., - tau=20., tau_ref=5., method='exp_auto', - V_initializer=bp.init.Normal(-60., 2.)) - - I = bp.neurons.LIF(800, V_rest=-60., V_th=-50., V_reset=-60., - tau=20., tau_ref=5., method='exp_auto', - V_initializer=bp.init.Normal(-60., 2.)) - - E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), g_max=0.6, - tau=5., output=bp.synouts.COBA(E=0.), - method='exp_auto') - E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02), g_max=0.6, - tau=5., output=bp.synouts.COBA(E=0.), - method='exp_auto') +class TestNetDefinition(unittest.TestCase): + def test_define_net1(self): + E = bp.neurons.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., + tau=20., tau_ref=5., method='exp_auto', + V_initializer=bp.init.Normal(-60., 2.)) - I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), g_max=6.7, - tau=10., output=bp.synouts.COBA(E=-80.), - method='exp_auto') + I = bp.neurons.LIF(800, V_rest=-60., V_th=-50., V_reset=-60., + tau=20., tau_ref=5., method='exp_auto', + V_initializer=bp.init.Normal(-60., 2.)) - I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), g_max=6.7, - tau=10., output=bp.synouts.COBA(E=-80.), - method='exp_auto') + E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), g_max=0.6, + tau=5., output=bp.synouts.COBA(E=0.), + method='exp_auto') - net = bp.Network(E2E, E2I, I2E, I2I, E=E, I=I) + E2I = bp.synapses.Exponential(E, I, 
bp.conn.FixedProb(prob=0.02), g_max=0.6, + tau=5., output=bp.synouts.COBA(E=0.), + method='exp_auto') - runner1 = bp.DSRunner(net, - monitors=['E.spike', 'I.spike'], - inputs=[('E.input', 20.), ('I.input', 20.)]) + I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), g_max=6.7, + tau=10., output=bp.synouts.COBA(E=-80.), + method='exp_auto') - runner2 = bp.DSRunner(net, - monitors=[('E.spike', E.spike), ('I.spike', I.spike)], - inputs=[(E.input, 20.), (I.input, 20.)]) + I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), g_max=6.7, + tau=10., output=bp.synouts.COBA(E=-80.), + method='exp_auto') - runner3 = bp.DSRunner(net, - monitors=[('E.spike', E.spike), 'I.spike'], - inputs=[(E.input, 20.), (I.input, 20.)]) + net = bp.Network(E2E, E2I, I2E, I2I, E=E, I=I) - runner4 = bp.DSRunner(net, - monitors={'E.spike': E.spike, 'I.spike': I.spike}, - inputs=[(E.input, 20.), (I.input, 20.)]) + runner1 = bp.DSRunner(net, + monitors=['E.spike', 'I.spike'], + inputs=[('E.input', 20.), ('I.input', 20.)]) - + runner2 = bp.DSRunner(net, + monitors=[('E.spike', E.spike), ('I.spike', I.spike)], + inputs=[(E.input, 20.), (I.input, 20.)]) + runner3 = bp.DSRunner(net, + monitors=[('E.spike', E.spike), 'I.spike'], + inputs=[(E.input, 20.), (I.input, 20.)]) + runner4 = bp.DSRunner(net, + monitors={'E.spike': E.spike, 'I.spike': I.spike}, + inputs=[(E.input, 20.), (I.input, 20.)]) diff --git a/brainpy/_src/tests/test_pickle.py b/brainpy/_src/tests/test_pickle.py index bc2c77f1c..26cc92f9a 100644 --- a/brainpy/_src/tests/test_pickle.py +++ b/brainpy/_src/tests/test_pickle.py @@ -1,22 +1,21 @@ # -*- coding: utf-8 -*- -import brainpy as bp - +import pickle import unittest -import pickle +import brainpy as bp class TestPickle(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestPickle, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs): + super(TestPickle, self).__init__(*args, **kwargs) - self.pre = bp.neurons.LIF(10) - self.post 
= bp.neurons.LIF(20) - self.syn = bp.synapses.TwoEndConn(self.pre, self.post, bp.conn.FixedProb(0.2)) - self.net = bp.DynSysGroup(self.pre, self.post, self.syn) + self.pre = bp.neurons.LIF(10) + self.post = bp.neurons.LIF(20) + self.syn = bp.synapses.TwoEndConn(self.pre, self.post, bp.conn.FixedProb(0.2)) + self.net = bp.DynSysGroup(self.pre, self.post, self.syn) - def test_net(self): - self.skipTest('Currently do not support') - with open('data/net.pickle', 'wb') as f: - pickle.dump(self.net, f) + def test_net(self): + self.skipTest('Currently do not support') + with open('data/net.pickle', 'wb') as f: + pickle.dump(self.net, f) diff --git a/brainpy/_src/tests/test_slice_view.py b/brainpy/_src/tests/test_slice_view.py index 1383c1a6c..5a12763d5 100644 --- a/brainpy/_src/tests/test_slice_view.py +++ b/brainpy/_src/tests/test_slice_view.py @@ -1,47 +1,47 @@ # -*- coding: utf-8 -*- +import unittest + import brainpy as bp import brainpy.math as bm -import unittest class TestSliceView(unittest.TestCase): - def test_lif(self): - lif = bp.neurons.LIF(10) - lif_tile = lif[5:] - print(lif_tile.V.shape) - print(lif_tile.varshape) - - print('Before modification: ') - print(lif.V) - lif_tile.V += 10. - - self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 10.]))) - print('After modification 1: ') - print(lif.V) - - lif_tile.V.value = bm.ones(5) * 40. - self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 40.]))) - print('After modification 2: ') - print(lif.V) - - def test_lif_train_mode(self): - lif = bp.neurons.LIF(10, mode=bm.training_mode) - lif_tile = lif[5:] - print(lif_tile.V.shape) - print(lif_tile.varshape) - - print('Before modification: ') - print(lif.V) - lif_tile.V += 10. - - self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 10.]))) - print('After modification 1: ') - print(lif.V) - - lif_tile.V.value = bm.ones((1, 5)) * 40. 
- self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 40.]))) - print('After modification 2: ') - print(lif.V) - + def test_lif(self): + lif = bp.neurons.LIF(10) + lif_tile = lif[5:] + print(lif_tile.V.shape) + print(lif_tile.varshape) + + print('Before modification: ') + print(lif.V) + lif_tile.V += 10. + + self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 10.]))) + print('After modification 1: ') + print(lif.V) + + lif_tile.V.value = bm.ones(5) * 40. + self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 40.]))) + print('After modification 2: ') + print(lif.V) + + def test_lif_train_mode(self): + lif = bp.neurons.LIF(10, mode=bm.training_mode) + lif_tile = lif[5:] + print(lif_tile.V.shape) + print(lif_tile.varshape) + + print('Before modification: ') + print(lif.V) + lif_tile.V += 10. + + self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 10.]))) + print('After modification 1: ') + print(lif.V) + + lif_tile.V.value = bm.ones((1, 5)) * 40. 
+ self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 40.]))) + print('After modification 2: ') + print(lif.V) diff --git a/brainpy/_src/tools/__init__.py b/brainpy/_src/tools/__init__.py index 99d278f24..020766807 100644 --- a/brainpy/_src/tools/__init__.py +++ b/brainpy/_src/tools/__init__.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- from .codes import * -from .others import * from .dicts import * +from .install import * +from .math_util import * +from .others import * from .others import * from .package import * -from .math_util import * -from .install import * diff --git a/brainpy/_src/tools/codes.py b/brainpy/_src/tools/codes.py index 3a0a1a4ea..2c82cf679 100644 --- a/brainpy/_src/tools/codes.py +++ b/brainpy/_src/tools/codes.py @@ -6,93 +6,92 @@ BrainPyObject = None - __all__ = [ - 'repr_dict', - 'repr_object', - 'repr_context', - 'copy_doc', - 'code_lines_to_func', - - # tools for code string - 'get_identifiers', - 'indent', - 'deindent', - 'word_replace', - - # other tools - 'is_lambda_function', - 'get_main_code', - 'get_func_source', - 'change_func_name', + 'repr_dict', + 'repr_object', + 'repr_context', + 'copy_doc', + 'code_lines_to_func', + + # tools for code string + 'get_identifiers', + 'indent', + 'deindent', + 'word_replace', + + # other tools + 'is_lambda_function', + 'get_main_code', + 'get_func_source', + 'change_func_name', ] def repr_dict(dict_obj: dict): - ret = [f'{k}={v}' for k, v in dict_obj.items()] - return ', '.join(ret) + ret = [f'{k}={v}' for k, v in dict_obj.items()] + return ', '.join(ret) def repr_object(x): - global BrainPyObject - if BrainPyObject is None: - from brainpy.math import BrainPyObject - if isinstance(x, BrainPyObject): - return repr(x) - elif callable(x): - signature = inspect.signature(x) - args = [f'{k}={v.default}' for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty] - args = ', '.join(args) - while not hasattr(x, '__name__'): - if not hasattr(x, 
'func'): - break - x = x.func # Handle functools.partial - if not hasattr(x, '__name__') and hasattr(x, '__class__'): - return x.__class__.__name__ - if args: - return f'{x.__name__}(*, {args})' - return x.__name__ - else: - x = repr(x).split('\n') - x = [x[0]] + [' ' + y for y in x[1:]] - return '\n'.join(x) + global BrainPyObject + if BrainPyObject is None: + from brainpy.math import BrainPyObject + if isinstance(x, BrainPyObject): + return repr(x) + elif callable(x): + signature = inspect.signature(x) + args = [f'{k}={v.default}' for k, v in signature.parameters.items() + if v.default is not inspect.Parameter.empty] + args = ', '.join(args) + while not hasattr(x, '__name__'): + if not hasattr(x, 'func'): + break + x = x.func # Handle functools.partial + if not hasattr(x, '__name__') and hasattr(x, '__class__'): + return x.__class__.__name__ + if args: + return f'{x.__name__}(*, {args})' + return x.__name__ + else: + x = repr(x).split('\n') + x = [x[0]] + [' ' + y for y in x[1:]] + return '\n'.join(x) def repr_context(repr_str, indent): - splits = repr_str.split('\n') - splits = [(s if i == 0 else (indent + s)) for i, s in enumerate(splits)] - return '\n'.join(splits) + splits = repr_str.split('\n') + splits = [(s if i == 0 else (indent + s)) for i, s in enumerate(splits)] + return '\n'.join(splits) def copy_doc(source_f): - def copy(target_f): - target_f.__doc__ = source_f.__doc__ - return target_f + def copy(target_f): + target_f.__doc__ = source_f.__doc__ + return target_f - return copy + return copy def code_lines_to_func(lines, func_name, func_args, scope, remind=''): - lines_for_compile = [f' {line}' for line in lines] - code_for_compile = '\n'.join(lines_for_compile) - code = f'def {func_name}({", ".join(func_args)}):\n' + \ - f' try:\n' + \ - f'{code_for_compile}\n' + \ - f' except Exception as e:\n' - code += ' exc_type, exc_obj, exc_tb = sys.exc_info()\n' - code += ' line_no = exc_tb.tb_lineno\n' - code += ' raise ValueError(f"Error occurred in line 
{line_no}: {code_for_debug} {str(e)} {remind}")' - lines_for_debug = [f'[{i + 1:3d}] {line}' for i, line in enumerate(code.split('\n'))] - code_for_debug = '\n'.join(lines_for_debug) - scope['code_for_debug'] = '\n\n' + code_for_debug + '\n\n' - scope['remind'] = '\n' + remind + '\n' - try: - exec(compile(code, '', 'exec'), scope) - except Exception as e: - raise ValueError(f'Compilation function error: \n\n{code}') from e - func = scope[func_name] - return code, func + lines_for_compile = [f' {line}' for line in lines] + code_for_compile = '\n'.join(lines_for_compile) + code = f'def {func_name}({", ".join(func_args)}):\n' + \ + f' try:\n' + \ + f'{code_for_compile}\n' + \ + f' except Exception as e:\n' + code += ' exc_type, exc_obj, exc_tb = sys.exc_info()\n' + code += ' line_no = exc_tb.tb_lineno\n' + code += ' raise ValueError(f"Error occurred in line {line_no}: {code_for_debug} {str(e)} {remind}")' + lines_for_debug = [f'[{i + 1:3d}] {line}' for i, line in enumerate(code.split('\n'))] + code_for_debug = '\n'.join(lines_for_debug) + scope['code_for_debug'] = '\n\n' + code_for_debug + '\n\n' + scope['remind'] = '\n' + remind + '\n' + try: + exec(compile(code, '', 'exec'), scope) + except Exception as e: + raise ValueError(f'Compilation function error: \n\n{code}') from e + func = scope[func_name] + return code, func ###################################### @@ -101,98 +100,98 @@ def code_lines_to_func(lines, func_name, func_args, scope, remind=''): def get_identifiers(expr, include_numbers=False): - """ - Return all the identifiers in a given string ``expr``, that is everything - that matches a programming language variable like expression, which is - here implemented as the regexp ``\\b[A-Za-z_][A-Za-z0-9_]*\\b``. - - Parameters:: - - expr : str - The string to analyze - include_numbers : bool, optional - Whether to include number literals in the output. Defaults to ``False``. 
- - Returns:: - - identifiers : set - A set of all the identifiers (and, optionally, numbers) in `expr`. - - Examples:: - - >>> expr = '3-a*_b+c5+8+f(A - .3e-10, tau_2)*17' - >>> ids = get_identifiers(expr) - >>> print(sorted(list(ids))) - ['A', '_b', 'a', 'c5', 'f', 'tau_2'] - >>> ids = get_identifiers(expr, include_numbers=True) - >>> print(sorted(list(ids))) - ['.3e-10', '17', '3', '8', 'A', '_b', 'a', 'c5', 'f', 'tau_2'] - """ - - _ID_KEYWORDS = {'and', 'or', 'not', 'True', 'False'} - identifiers = set(re.findall(r'\b[A-Za-z_][A-Za-z0-9_.]*\b', expr)) - # identifiers = set(re.findall(r'\b[A-Za-z_][.?[A-Za-z0-9_]*]*\b', expr)) - if include_numbers: - # only the number, not a + or - - pattern = r'(?<=[^A-Za-z_])[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?|^[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?' - numbers = set(re.findall(pattern, expr)) - else: - numbers = set() - return (identifiers - _ID_KEYWORDS) | numbers + """ + Return all the identifiers in a given string ``expr``, that is everything + that matches a programming language variable like expression, which is + here implemented as the regexp ``\\b[A-Za-z_][A-Za-z0-9_]*\\b``. + + Parameters:: + + expr : str + The string to analyze + include_numbers : bool, optional + Whether to include number literals in the output. Defaults to ``False``. + + Returns:: + + identifiers : set + A set of all the identifiers (and, optionally, numbers) in `expr`. 
+ + Examples:: + + >>> expr = '3-a*_b+c5+8+f(A - .3e-10, tau_2)*17' + >>> ids = get_identifiers(expr) + >>> print(sorted(list(ids))) + ['A', '_b', 'a', 'c5', 'f', 'tau_2'] + >>> ids = get_identifiers(expr, include_numbers=True) + >>> print(sorted(list(ids))) + ['.3e-10', '17', '3', '8', 'A', '_b', 'a', 'c5', 'f', 'tau_2'] + """ + + _ID_KEYWORDS = {'and', 'or', 'not', 'True', 'False'} + identifiers = set(re.findall(r'\b[A-Za-z_][A-Za-z0-9_.]*\b', expr)) + # identifiers = set(re.findall(r'\b[A-Za-z_][.?[A-Za-z0-9_]*]*\b', expr)) + if include_numbers: + # only the number, not a + or - + pattern = r'(?<=[^A-Za-z_])[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?|^[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?' + numbers = set(re.findall(pattern, expr)) + else: + numbers = set() + return (identifiers - _ID_KEYWORDS) | numbers def indent(text, num_tabs=1, spaces_per_tab=4, tab=None): - if tab is None: - tab = ' ' * spaces_per_tab - indent_ = tab * num_tabs - indented_string = indent_ + text.replace('\n', '\n' + indent_) - return indented_string + if tab is None: + tab = ' ' * spaces_per_tab + indent_ = tab * num_tabs + indented_string = indent_ + text.replace('\n', '\n' + indent_) + return indented_string def deindent(text, num_tabs=None, spaces_per_tab=4, docstring=False): - text = text.replace('\t', ' ' * spaces_per_tab) - lines = text.split('\n') - # if it's a docstring, we search for the common tabulation starting from - # line 1, otherwise we use all lines - if docstring: - start = 1 - else: - start = 0 - if docstring and len(lines) < 2: # nothing to do - return text - # Find the minimum indentation level - if num_tabs is not None: - indent_level = num_tabs * spaces_per_tab - else: - line_seq = [len(line) - len(line.lstrip()) for line in lines[start:] if len(line.strip())] - if len(line_seq) == 0: - indent_level = 0 + text = text.replace('\t', ' ' * spaces_per_tab) + lines = text.split('\n') + # if it's a docstring, we search for the common tabulation starting from + # line 1, otherwise we 
use all lines + if docstring: + start = 1 else: - indent_level = min(line_seq) - # remove the common indentation - lines[start:] = [line[indent_level:] for line in lines[start:]] - return '\n'.join(lines) + start = 0 + if docstring and len(lines) < 2: # nothing to do + return text + # Find the minimum indentation level + if num_tabs is not None: + indent_level = num_tabs * spaces_per_tab + else: + line_seq = [len(line) - len(line.lstrip()) for line in lines[start:] if len(line.strip())] + if len(line_seq) == 0: + indent_level = 0 + else: + indent_level = min(line_seq) + # remove the common indentation + lines[start:] = [line[indent_level:] for line in lines[start:]] + return '\n'.join(lines) def word_replace(expr, substitutions, exclude_dot=True): - """Applies a dict of word substitutions. - - The dict ``substitutions`` consists of pairs ``(word, rep)`` where each - word ``word`` appearing in ``expr`` is replaced by ``rep``. Here a 'word' - means anything matching the regexp ``\\bword\\b``. - - Examples:: - - >>> expr = 'a*_b+c5+8+f(A)' - >>> print(word_replace(expr, {'a':'banana', 'f':'func'})) - banana*_b+c5+8+func(A) - """ - for var, replace_var in substitutions.items(): - if exclude_dot: - expr = re.sub(r'\b(?>> expr = 'a*_b+c5+8+f(A)' + >>> print(word_replace(expr, {'a':'banana', 'f':'func'})) + banana*_b+c5+8+func(A) + """ + for var, replace_var in substitutions.items(): + if exclude_dot: + expr = re.sub(r'\b(?" + Returns:: + + bool + True of False. + """ + return isinstance(func, LambdaType) and func.__name__ == "" def get_func_source(func): - code = inspect.getsource(func) - # remove @ - try: - start = code.index('def ') - code = code[start:] - except ValueError: - pass - return code + code = inspect.getsource(func) + # remove @ + try: + start = code.index('def ') + code = code[start:] + except ValueError: + pass + return code def get_main_code(func, codes=None): - """Get the main function _code string. 
- - For lambda function, return the - - Parameters:: - - func : callable, Optional, int, float - - Returns:: - - """ - if func is None: - return '' - elif callable(func): - if is_lambda_function(func): - codes = (codes or get_func_source(func)) - splits = codes.split(':') - if len(splits) != 2: - raise ValueError(f'Can not parse function: \n{codes}') - return f'return {splits[1]}' - + """Get the main function _code string. + + For lambda function, return the + + Parameters:: + + func : callable, Optional, int, float + + Returns:: + + """ + if func is None: + return '' + elif callable(func): + if is_lambda_function(func): + codes = (codes or get_func_source(func)) + splits = codes.split(':') + if len(splits) != 2: + raise ValueError(f'Can not parse function: \n{codes}') + return f'return {splits[1]}' + + else: + codes = (codes.split('\n') or inspect.getsourcelines(func)[0]) + idx = 0 + for line in codes: + idx += 1 + line = line.replace(' ', '') + if '):' in line: + break + else: + code = "\n".join(codes) + raise ValueError(f'Can not parse function: \n{code}') + return ''.join(codes[idx:]) else: - codes = (codes.split('\n') or inspect.getsourcelines(func)[0]) - idx = 0 - for line in codes: - idx += 1 - line = line.replace(' ', '') - if '):' in line: - break - else: - code = "\n".join(codes) - raise ValueError(f'Can not parse function: \n{code}') - return ''.join(codes[idx:]) - else: - raise ValueError(f'Unknown function type: {type(func)}.') + raise ValueError(f'Unknown function type: {type(func)}.') diff --git a/brainpy/_src/tools/dicts.py b/brainpy/_src/tools/dicts.py index 23835909e..e9c3732b5 100644 --- a/brainpy/_src/tools/dicts.py +++ b/brainpy/_src/tools/dicts.py @@ -4,181 +4,181 @@ from typing import Union, Dict, Sequence import numpy as np -from jax.tree_util import register_pytree_node from brainstate._compatible_import import safe_zip +from jax.tree_util import register_pytree_node __all__ = [ - 'DotDict', + 'DotDict', ] class DotDict(dict): - """Python 
dictionaries with advanced dot notation access. - - For example: - - >>> d = DotDict({'a': 10, 'b': 20}) - >>> d.a - 10 - >>> d['a'] - 10 - >>> d.c # this will raise a KeyError - KeyError: 'c' - >>> d.c = 30 # but you can assign a value to a non-existing item - >>> d.c - 30 - - In general, all attributes will be included ad keys in the dict. - For example, if you add an attribute to specify what variable names - you have: - - >>> d.names = ('a', 'b') - - This attribute `names` will cause error when you treat the object as - a PyTree. - - >>> from jax import jit - >>> f = jit(lambda x: x) - >>> f(d) - TypeError: Argument 'a' of type is not a valid JAX type. + """Python dictionaries with advanced dot notation access. - """ + For example: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.__dict__ = self + >>> d = DotDict({'a': 10, 'b': 20}) + >>> d.a + 10 + >>> d['a'] + 10 + >>> d.c # this will raise a KeyError + KeyError: 'c' + >>> d.c = 30 # but you can assign a value to a non-existing item + >>> d.c + 30 - def copy(self) -> 'DotDict': - return type(self)(super().copy()) + In general, all attributes will be included ad keys in the dict. + For example, if you add an attribute to specify what variable names + you have: - def to_numpy(self): - """Change all values to numpy arrays.""" - for key in tuple(self.keys()): - self[key] = np.asarray(self[key]) + >>> d.names = ('a', 'b') - def update(self, *args, **kwargs): - super().update(*args, **kwargs) - return self + This attribute `names` will cause error when you treat the object as + a PyTree. - def __add__(self, other): - """Merging two dicts. + >>> from jax import jit + >>> f = jit(lambda x: x) + >>> f(d) + TypeError: Argument 'a' of type is not a valid JAX type. - Parameters:: - - other: dict - The other dict instance. - - Returns:: - - gather: Collector - The new collector. 
""" - gather = type(self)(self) - gather.update(other) - return gather - - def __sub__(self, other: Union[Dict, Sequence]): - """Remove other item in the collector. - Parameters:: - - other: dict, sequence - The items to remove. - - Returns:: - - gather: Collector - The new collector. - """ - if not isinstance(other, (dict, tuple, list)): - raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.') - gather = type(self)(self) - if isinstance(other, dict): - for key, val in other.items(): - if key in gather: - if id(val) != id(gather[key]): - raise ValueError(f'Cannot remove {key}, because we got two different values: ' - f'{val} != {gather[key]}') - gather.pop(key) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + def copy(self) -> 'DotDict': + return type(self)(super().copy()) + + def to_numpy(self): + """Change all values to numpy arrays.""" + for key in tuple(self.keys()): + self[key] = np.asarray(self[key]) + + def update(self, *args, **kwargs): + super().update(*args, **kwargs) + return self + + def __add__(self, other): + """Merging two dicts. + + Parameters:: + + other: dict + The other dict instance. + + Returns:: + + gather: Collector + The new collector. + """ + gather = type(self)(self) + gather.update(other) + return gather + + def __sub__(self, other: Union[Dict, Sequence]): + """Remove other item in the collector. + + Parameters:: + + other: dict, sequence + The items to remove. + + Returns:: + + gather: Collector + The new collector. 
+ """ + if not isinstance(other, (dict, tuple, list)): + raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.') + gather = type(self)(self) + if isinstance(other, dict): + for key, val in other.items(): + if key in gather: + if id(val) != id(gather[key]): + raise ValueError(f'Cannot remove {key}, because we got two different values: ' + f'{val} != {gather[key]}') + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') + elif isinstance(other, (list, tuple)): + id_to_keys = {} + for k, v in self.items(): + id_ = id(v) + if id_ not in id_to_keys: + id_to_keys[id_] = [] + id_to_keys[id_].append(k) + + keys_to_remove = [] + for key in other: + if isinstance(key, str): + keys_to_remove.append(key) + else: + keys_to_remove.extend(id_to_keys[id(key)]) + + for key in set(keys_to_remove): + if key in gather: + gather.pop(key) + else: + raise ValueError(f'Cannot remove {key}, because we do not find it ' + f'in {self.keys()}.') else: - raise ValueError(f'Cannot remove {key}, because we do not find it ' - f'in {self.keys()}.') - elif isinstance(other, (list, tuple)): - id_to_keys = {} - for k, v in self.items(): - id_ = id(v) - if id_ not in id_to_keys: - id_to_keys[id_] = [] - id_to_keys[id_].append(k) - - keys_to_remove = [] - for key in other: - if isinstance(key, str): - keys_to_remove.append(key) - else: - keys_to_remove.extend(id_to_keys[id(key)]) - - for key in set(keys_to_remove): - if key in gather: - gather.pop(key) - else: - raise ValueError(f'Cannot remove {key}, because we do not find it ' - f'in {self.keys()}.') - else: - raise KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}') - return gather - - def subset(self, var_type): - """Get the subset of the (key, value) pair. 
- - ``subset()`` can be used to get a subset of some class: - - >>> import brainpy as bp - >>> - >>> some_collector = DotDict() - >>> - >>> # get all trainable variables - >>> some_collector.subset(bp.math.TrainVar) - >>> - >>> # get all Variable - >>> some_collector.subset(bp.math.Variable) - - or, it can be used to get a subset of integrators: - - >>> # get all ODE integrators - >>> some_collector.subset(bp.ode.ODEIntegrator) - - Parameters:: - - var_type : type - The type/class to match. - """ - gather = type(self)() - for key, value in self.items(): - if isinstance(value, var_type): - gather[key] = value - return gather - - def unique(self): - """Get a new type of collector with unique values. - - If one value is assigned to two or more keys, - then only one pair of (key, value) will be returned. - """ - gather = type(self)() - seen = set() - for k, v in self.items(): - if id(v) not in seen: - seen.add(id(v)) - gather[k] = v - return gather - - def __hash__(self): - return hash(tuple(sorted(self.items()))) + raise KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}') + return gather + + def subset(self, var_type): + """Get the subset of the (key, value) pair. + + ``subset()`` can be used to get a subset of some class: + + >>> import brainpy as bp + >>> + >>> some_collector = DotDict() + >>> + >>> # get all trainable variables + >>> some_collector.subset(bp.math.TrainVar) + >>> + >>> # get all Variable + >>> some_collector.subset(bp.math.Variable) + + or, it can be used to get a subset of integrators: + + >>> # get all ODE integrators + >>> some_collector.subset(bp.ode.ODEIntegrator) + + Parameters:: + + var_type : type + The type/class to match. + """ + gather = type(self)() + for key, value in self.items(): + if isinstance(value, var_type): + gather[key] = value + return gather + + def unique(self): + """Get a new type of collector with unique values. 
+ + If one value is assigned to two or more keys, + then only one pair of (key, value) will be returned. + """ + gather = type(self)() + seen = set() + for k, v in self.items(): + if id(v) not in seen: + seen.add(id(v)) + gather[k] = v + return gather + + def __hash__(self): + return hash(tuple(sorted(self.items()))) register_pytree_node( - DotDict, - lambda x: (x.values(), x.keys()), - lambda keys, values: DotDict(safe_zip(keys, values)) + DotDict, + lambda x: (x.values(), x.keys()), + lambda keys, values: DotDict(safe_zip(keys, values)) ) diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py index cbc710dba..378f2515e 100644 --- a/brainpy/_src/tools/functions.py +++ b/brainpy/_src/tools/functions.py @@ -4,189 +4,189 @@ from types import MethodType __all__ = [ - 'compose', 'pipe' + 'compose', 'pipe' ] def identity(x): - """ Identity function. Return x + """ Identity function. Return x - >>> identity(3) - 3 - """ - return x + >>> identity(3) + 3 + """ + return x def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): - """ Like @property, but returns ``classval`` when used as a class attribute - - >>> class MyClass(object): - ... '''The class docstring''' - ... @instanceproperty(classval=__doc__) - ... def __doc__(self): - ... return 'An object docstring' - ... @instanceproperty - ... def val(self): - ... return 42 - ... - >>> MyClass.__doc__ - 'The class docstring' - >>> MyClass.val is None - True - >>> obj = MyClass() - >>> obj.__doc__ - 'An object docstring' - >>> obj.val - 42 - """ - if fget is None: - return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, - classval=classval) - return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, - classval=classval) + """ Like @property, but returns ``classval`` when used as a class attribute + + >>> class MyClass(object): + ... '''The class docstring''' + ... @instanceproperty(classval=__doc__) + ... def __doc__(self): + ... 
return 'An object docstring' + ... @instanceproperty + ... def val(self): + ... return 42 + ... + >>> MyClass.__doc__ + 'The class docstring' + >>> MyClass.val is None + True + >>> obj = MyClass() + >>> obj.__doc__ + 'An object docstring' + >>> obj.val + 42 + """ + if fget is None: + return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, + classval=classval) + return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, + classval=classval) class InstanceProperty(property): - """ Like @property, but returns ``classval`` when used as a class attribute + """ Like @property, but returns ``classval`` when used as a class attribute - Should not be used directly. Use ``instanceproperty`` instead. - """ + Should not be used directly. Use ``instanceproperty`` instead. + """ - def __init__(self, fget=None, fset=None, fdel=None, doc=None, - classval=None): - self.classval = classval - property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) + def __init__(self, fget=None, fset=None, fdel=None, doc=None, + classval=None): + self.classval = classval + property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) - def __get__(self, obj, type=None): - if obj is None: - return self.classval - return property.__get__(self, obj, type) + def __get__(self, obj, type=None): + if obj is None: + return self.classval + return property.__get__(self, obj, type) - def __reduce__(self): - state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) - return InstanceProperty, state + def __reduce__(self): + state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) + return InstanceProperty, state class Compose(object): - """ A composition of functions - - See Also: - compose - """ - __slots__ = 'first', 'funcs' - - def __init__(self, funcs): - funcs = tuple(reversed(funcs)) - self.first = funcs[0] - self.funcs = funcs[1:] - - def __call__(self, *args, **kwargs): - ret = self.first(*args, **kwargs) - for f in self.funcs: - ret = f(ret) - 
return ret - - def __getstate__(self): - return self.first, self.funcs - - def __setstate__(self, state): - self.first, self.funcs = state - - @instanceproperty(classval=__doc__) - def __doc__(self): - def composed_doc(*fs): - """Generate a docstring for the composition of fs. - """ - if not fs: - # Argument name for the docstring. - return '*args, **kwargs' - - return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) - - try: - return ( - 'lambda *args, **kwargs: ' + - composed_doc(*reversed((self.first,) + self.funcs)) - ) - except AttributeError: - # One of our callables does not have a `__name__`, whatever. - return 'A composition of functions' - - @property - def __name__(self): - try: - return '_of_'.join( - (f.__name__ for f in reversed((self.first,) + self.funcs)) - ) - except AttributeError: - return type(self).__name__ - - def __repr__(self): - return '{.__class__.__name__}{!r}'.format( - self, tuple(reversed((self.first,) + self.funcs))) - - def __eq__(self, other): - if isinstance(other, Compose): - return other.first == self.first and other.funcs == self.funcs - return NotImplemented - - def __ne__(self, other): - equality = self.__eq__(other) - return NotImplemented if equality is NotImplemented else not equality - - def __hash__(self): - return hash(self.first) ^ hash(self.funcs) - - # Mimic the descriptor behavior of python functions. - # i.e. let Compose be called as a method when bound to a class. 
- # adapted from - # docs.python.org/3/howto/descriptor.html#functions-and-methods - def __get__(self, obj, objtype=None): - return self if obj is None else MethodType(self, obj) - - # introspection with Signature is only possible from py3.3+ - @instanceproperty - def __signature__(self): - base = inspect.signature(self.first) - last = inspect.signature(self.funcs[-1]) - return base.replace(return_annotation=last.return_annotation) - - __wrapped__ = instanceproperty(attrgetter('first')) + """ A composition of functions + + See Also: + compose + """ + __slots__ = 'first', 'funcs' + + def __init__(self, funcs): + funcs = tuple(reversed(funcs)) + self.first = funcs[0] + self.funcs = funcs[1:] + + def __call__(self, *args, **kwargs): + ret = self.first(*args, **kwargs) + for f in self.funcs: + ret = f(ret) + return ret + + def __getstate__(self): + return self.first, self.funcs + + def __setstate__(self, state): + self.first, self.funcs = state + + @instanceproperty(classval=__doc__) + def __doc__(self): + def composed_doc(*fs): + """Generate a docstring for the composition of fs. + """ + if not fs: + # Argument name for the docstring. + return '*args, **kwargs' + + return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) + + try: + return ( + 'lambda *args, **kwargs: ' + + composed_doc(*reversed((self.first,) + self.funcs)) + ) + except AttributeError: + # One of our callables does not have a `__name__`, whatever. 
+ return 'A composition of functions' + + @property + def __name__(self): + try: + return '_of_'.join( + (f.__name__ for f in reversed((self.first,) + self.funcs)) + ) + except AttributeError: + return type(self).__name__ + + def __repr__(self): + return '{.__class__.__name__}{!r}'.format( + self, tuple(reversed((self.first,) + self.funcs))) + + def __eq__(self, other): + if isinstance(other, Compose): + return other.first == self.first and other.funcs == self.funcs + return NotImplemented + + def __ne__(self, other): + equality = self.__eq__(other) + return NotImplemented if equality is NotImplemented else not equality + + def __hash__(self): + return hash(self.first) ^ hash(self.funcs) + + # Mimic the descriptor behavior of python functions. + # i.e. let Compose be called as a method when bound to a class. + # adapted from + # docs.python.org/3/howto/descriptor.html#functions-and-methods + def __get__(self, obj, objtype=None): + return self if obj is None else MethodType(self, obj) + + # introspection with Signature is only possible from py3.3+ + @instanceproperty + def __signature__(self): + base = inspect.signature(self.first) + last = inspect.signature(self.funcs[-1]) + return base.replace(return_annotation=last.return_annotation) + + __wrapped__ = instanceproperty(attrgetter('first')) def compose(*funcs): - """ Compose functions to operate in series. + """ Compose functions to operate in series. - Returns a function that applies other functions in sequence. + Returns a function that applies other functions in sequence. - Functions are applied from right to left so that - ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. + Functions are applied from right to left so that + ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. - If no arguments are provided, the identity function (f(x) = x) is returned. + If no arguments are provided, the identity function (f(x) = x) is returned. 
- >>> inc = lambda i: i + 1 - >>> compose(str, inc)(3) - '4' - """ - if not funcs: - return identity - if len(funcs) == 1: - return funcs[0] - else: - return Compose(funcs) + >>> inc = lambda i: i + 1 + >>> compose(str, inc)(3) + '4' + """ + if not funcs: + return identity + if len(funcs) == 1: + return funcs[0] + else: + return Compose(funcs) def pipe(*funcs): - """ Pipe a value through a sequence of functions + """ Pipe a value through a sequence of functions - I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` + I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` - We think of the value as progressing through a pipe of several - transformations, much like pipes in UNIX + We think of the value as progressing through a pipe of several + transformations, much like pipes in UNIX - >>> double = lambda i: 2 * i - >>> pipe(double, str)(3) - '6' - """ - return compose(*reversed(funcs)) + >>> double = lambda i: 2 * i + >>> pipe(double, str)(3) + '6' + """ + return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/install.py b/brainpy/_src/tools/install.py index 68981a5ec..27c4285ad 100644 --- a/brainpy/_src/tools/install.py +++ b/brainpy/_src/tools/install.py @@ -1,9 +1,7 @@ - __all__ = [ - 'jaxlib_install_info', + 'jaxlib_install_info', ] - jaxlib_install_info = ''' BrainPy needs jaxlib, please install it. diff --git a/brainpy/_src/tools/math_util.py b/brainpy/_src/tools/math_util.py index 3b1cdce1c..7aab09350 100644 --- a/brainpy/_src/tools/math_util.py +++ b/brainpy/_src/tools/math_util.py @@ -2,17 +2,15 @@ import numpy as np - __all__ = [ - 'format_seed' + 'format_seed' ] def format_seed(seed=None): - """Get the random sed. - """ - if seed is None: - return np.random.randint(0, int(1e7)) - else: - return seed - + """Get the random sed. 
+ """ + if seed is None: + return np.random.randint(0, int(1e7)) + else: + return seed diff --git a/brainpy/_src/tools/others.py b/brainpy/_src/tools/others.py index 79b205340..49e10dbc4 100644 --- a/brainpy/_src/tools/others.py +++ b/brainpy/_src/tools/others.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -import collections.abc import _thread as thread +import collections.abc import threading from typing import Optional, Tuple, Callable, Union, Sequence, TypeVar, Any @@ -11,28 +11,28 @@ from tqdm.auto import tqdm __all__ = [ - 'one_of', - 'replicate', - 'not_customized', - 'to_size', - 'size2num', - 'timeout', - 'init_progress_bar', + 'one_of', + 'replicate', + 'not_customized', + 'to_size', + 'size2num', + 'timeout', + 'init_progress_bar', ] def one_of(default: Any, *choices, names: Sequence[str] = None): - names = [f'arg{i}' for i in range(len(choices))] if names is None else names - res = default - has_chosen = False - for c in choices: - if c is not None: - if has_chosen: - raise ValueError(f'Provide one of {names}, but we got {list(zip(choices, names))}') - else: - has_chosen = True - res = c - return res + names = [f'arg{i}' for i in range(len(choices))] if names is None else names + res = default + has_chosen = False + for c in choices: + if c is not None: + if has_chosen: + raise ValueError(f'Provide one of {names}, but we got {list(zip(choices, names))}') + else: + has_chosen = True + res = c + return res T = TypeVar('T') @@ -43,144 +43,143 @@ def replicate( num_replicate: int, name: str, ) -> Tuple[T, ...]: - """Replicates entry in `element` `num_replicate` if needed.""" - if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence): - return (element,) * num_replicate - elif len(element) == 1: - return tuple(element * num_replicate) - elif len(element) == num_replicate: - return tuple(element) - else: - raise TypeError(f"{name} must be a scalar or sequence of length 1 or " - f"sequence of length {num_replicate}.") + 
"""Replicates entry in `element` `num_replicate` if needed.""" + if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence): + return (element,) * num_replicate + elif len(element) == 1: + return tuple(element * num_replicate) + elif len(element) == num_replicate: + return tuple(element) + else: + raise TypeError(f"{name} must be a scalar or sequence of length 1 or " + f"sequence of length {num_replicate}.") def not_customized(fun: Callable) -> Callable: - """Marks the given module method is not implemented. + """Marks the given module method is not implemented. - Methods wrapped in @not_customized can define submodules directly within the method. + Methods wrapped in @not_customized can define submodules directly within the method. - For instance:: + For instance:: - @not_customized - init_fb(self): - ... + @not_customized + init_fb(self): + ... - @not_customized - def feedback(self): - ... - """ - fun.not_customized = True - return fun + @not_customized + def feedback(self): + ... 
+ """ + fun.not_customized = True + return fun def size2num(size): - if isinstance(size, (int, np.integer)): - return size - elif isinstance(size, (tuple, list)): - a = 1 - for b in size: - a *= b - return a - else: - raise ValueError(f'Do not support type {type(size)}: {size}') + if isinstance(size, (int, np.integer)): + return size + elif isinstance(size, (tuple, list)): + a = 1 + for b in size: + a *= b + return a + else: + raise ValueError(f'Do not support type {type(size)}: {size}') def to_size(x) -> Optional[Tuple[int]]: - if isinstance(x, (tuple, list)): - return tuple(x) - if isinstance(x, (int, np.integer)): - return (x,) - if x is None: - return x - raise ValueError(f'Cannot make a size for {x}') + if isinstance(x, (tuple, list)): + return tuple(x) + if isinstance(x, (int, np.integer)): + return (x,) + if x is None: + return x + raise ValueError(f'Cannot make a size for {x}') def timeout(s): - """Add a timeout parameter to a function and return it. + """Add a timeout parameter to a function and return it. - Parameters:: + Parameters:: - s : float - Time limit in seconds. + s : float + Time limit in seconds. - Returns:: + Returns:: - func : callable - Functional results. Or, raise an error of KeyboardInterrupt. - """ + func : callable + Functional results. Or, raise an error of KeyboardInterrupt. 
+ """ - def outer(fn): - def inner(*args, **kwargs): - timer = threading.Timer(s, thread.interrupt_main) - timer.start() - try: - result = fn(*args, **kwargs) - finally: - timer.cancel() - return result + def outer(fn): + def inner(*args, **kwargs): + timer = threading.Timer(s, thread.interrupt_main) + timer.start() + try: + result = fn(*args, **kwargs) + finally: + timer.cancel() + return result - return inner + return inner - return outer + return outer def init_progress_bar(duration, dt, report=0.01, message=None): - """Setup a progress bar.""" - if message is None: - message = f"Running a duration of {duration}" - - num_samples = int(duration / dt) - print_rate = int(duration * report / dt) - remainder = num_samples % print_rate - - tqdm_bars = {} - - def _define_tqdm(arg, transform): - tqdm_bars[0] = tqdm(np.arange(0, duration, dt)) - tqdm_bars[0].set_description(message, refresh=False) - - def _update_tqdm(num_processed, transform): - tqdm_bars[0].update(num_processed * dt) - - def _update_progress_bar(num_iter): - _ = lax.cond( - num_iter == 0, - lambda _: host_callback.id_tap(_define_tqdm, None, result=num_iter), - lambda _: num_iter, - operand=None, - ) - - _ = lax.cond( - # update tqdm every multiple of `print_rate` except at the end - (num_iter % print_rate == 0) & (num_iter != num_samples - remainder), - lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=num_iter), - lambda _: num_iter, - operand=None, - ) - - _ = lax.cond( - # update tqdm by `remainder` - num_iter == num_samples - remainder, - lambda _: host_callback.id_tap(_update_tqdm, remainder, result=num_iter), - lambda _: num_iter, - operand=None, - ) - - def _close_tqdm(arg, transform): - tqdm_bars[0].close() - - def close_tqdm(iter_num): - return lax.cond( - iter_num == num_samples - 1, - lambda _: host_callback.id_tap(_close_tqdm, None, result=None), - lambda _: None, - operand=None, - ) - - def _progress_bar(iter_num): - _update_progress_bar(iter_num) - close_tqdm(iter_num) - - 
return _progress_bar - + """Setup a progress bar.""" + if message is None: + message = f"Running a duration of {duration}" + + num_samples = int(duration / dt) + print_rate = int(duration * report / dt) + remainder = num_samples % print_rate + + tqdm_bars = {} + + def _define_tqdm(arg, transform): + tqdm_bars[0] = tqdm(np.arange(0, duration, dt)) + tqdm_bars[0].set_description(message, refresh=False) + + def _update_tqdm(num_processed, transform): + tqdm_bars[0].update(num_processed * dt) + + def _update_progress_bar(num_iter): + _ = lax.cond( + num_iter == 0, + lambda _: host_callback.id_tap(_define_tqdm, None, result=num_iter), + lambda _: num_iter, + operand=None, + ) + + _ = lax.cond( + # update tqdm every multiple of `print_rate` except at the end + (num_iter % print_rate == 0) & (num_iter != num_samples - remainder), + lambda _: host_callback.id_tap(_update_tqdm, print_rate, result=num_iter), + lambda _: num_iter, + operand=None, + ) + + _ = lax.cond( + # update tqdm by `remainder` + num_iter == num_samples - remainder, + lambda _: host_callback.id_tap(_update_tqdm, remainder, result=num_iter), + lambda _: num_iter, + operand=None, + ) + + def _close_tqdm(arg, transform): + tqdm_bars[0].close() + + def close_tqdm(iter_num): + return lax.cond( + iter_num == num_samples - 1, + lambda _: host_callback.id_tap(_close_tqdm, None, result=None), + lambda _: None, + operand=None, + ) + + def _progress_bar(iter_num): + _update_progress_bar(iter_num) + close_tqdm(iter_num) + + return _progress_bar diff --git a/brainpy/_src/tools/package.py b/brainpy/_src/tools/package.py index c459ecfac..e793b0646 100644 --- a/brainpy/_src/tools/package.py +++ b/brainpy/_src/tools/package.py @@ -3,41 +3,38 @@ import numpy as np try: - import numba + import numba except (ImportError, ModuleNotFoundError): - numba = None - + numba = None __all__ = [ - 'numba_jit', - 'numba_seed', - 'numba_range', - 'SUPPORT_NUMBA', + 'numba_jit', + 'numba_seed', + 'numba_range', + 'SUPPORT_NUMBA', ] - - 
SUPPORT_NUMBA = numba is not None def numba_jit(f=None, **kwargs): - if f is None: - return lambda f: (f if (numba is None) else numba.njit(f, **kwargs)) - else: - if numba is None: - return f + if f is None: + return lambda f: (f if (numba is None) else numba.njit(f, **kwargs)) else: - return numba.njit(f) + if numba is None: + return f + else: + return numba.njit(f) @numba_jit def _seed(seed): - np.random.seed(seed) + np.random.seed(seed) def numba_seed(seed): - if numba is not None and seed is not None: - _seed(seed) + if numba is not None and seed is not None: + _seed(seed) numba_range = numba.prange if SUPPORT_NUMBA else range diff --git a/brainpy/_src/tools/progress.py b/brainpy/_src/tools/progress.py index 13b6a1574..c216c5515 100644 --- a/brainpy/_src/tools/progress.py +++ b/brainpy/_src/tools/progress.py @@ -17,503 +17,503 @@ def func_dump(func): - """Serializes a user defined function. - - Args: - func: the function to serialize. - - Returns: - A tuple `(code, defaults, closure)`. - """ - if os.name == "nt": - raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") - code = codecs.encode(raw_code, "base64").decode("ascii") - else: - raw_code = marshal.dumps(func.__code__) - code = codecs.encode(raw_code, "base64").decode("ascii") - defaults = func.__defaults__ - if func.__closure__: - closure = tuple(c.cell_contents for c in func.__closure__) - else: - closure = None - return code, defaults, closure + """Serializes a user defined function. + Args: + func: the function to serialize. -def func_load(code, defaults=None, closure=None, globs=None): - """Deserializes a user defined function. - - Args: - code: bytecode of the function. - defaults: defaults of the function. - closure: closure of the function. - globs: dictionary of global objects. + Returns: + A tuple `(code, defaults, closure)`. 
+ """ + if os.name == "nt": + raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") + code = codecs.encode(raw_code, "base64").decode("ascii") + else: + raw_code = marshal.dumps(func.__code__) + code = codecs.encode(raw_code, "base64").decode("ascii") + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure - Returns: - A function object. - """ - if isinstance(code, (tuple, list)): # unpack previous dump - code, defaults, closure = code - if isinstance(defaults, list): - defaults = tuple(defaults) - def ensure_value_to_cell(value): - """Ensures that a value is converted to a python cell object. +def func_load(code, defaults=None, closure=None, globs=None): + """Deserializes a user defined function. Args: - value: Any value that needs to be casted to the cell type + code: bytecode of the function. + defaults: defaults of the function. + closure: closure of the function. + globs: dictionary of global objects. Returns: - A value wrapped as a cell object (see function "func_load") + A function object. 
""" - - def dummy_fn(): - value # just access it so it gets captured in .__closure__ - - cell_value = dummy_fn.__closure__[0] - if not isinstance(value, type(cell_value)): - return cell_value - return value - - if closure is not None: - closure = tuple(ensure_value_to_cell(_) for _ in closure) - try: - raw_code = codecs.decode(code.encode("ascii"), "base64") - except (UnicodeEncodeError, binascii.Error): - raw_code = code.encode("raw_unicode_escape") - code = marshal.loads(raw_code) - if globs is None: - globs = globals() - return python_types.FunctionType( - code, globs, name=code.co_name, argdefs=defaults, closure=closure - ) + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) + + def ensure_value_to_cell(value): + """Ensures that a value is converted to a python cell object. + + Args: + value: Any value that needs to be casted to the cell type + + Returns: + A value wrapped as a cell object (see function "func_load") + """ + + def dummy_fn(): + value # just access it so it gets captured in .__closure__ + + cell_value = dummy_fn.__closure__[0] + if not isinstance(value, type(cell_value)): + return cell_value + return value + + if closure is not None: + closure = tuple(ensure_value_to_cell(_) for _ in closure) + try: + raw_code = codecs.decode(code.encode("ascii"), "base64") + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode("raw_unicode_escape") + code = marshal.loads(raw_code) + if globs is None: + globs = globals() + return python_types.FunctionType( + code, globs, name=code.co_name, argdefs=defaults, closure=closure + ) class Progbar: - """Displays a progress bar. - - Args: - target: Total number of steps expected, None if unknown. - width: Progress bar width on screen. - verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) - stateful_metrics: Iterable of string names of metrics that should *not* - be averaged over time. 
Metrics in this list will be displayed as-is. - All others will be averaged by the progbar before display. - interval: Minimum visual progress update interval (in seconds). - unit_name: Display name for step counts (usually "step" or "sample"). - """ - - def __init__( - self, - target, - width=30, - verbose=1, - interval=0.05, - stateful_metrics=None, - unit_name="step", - ): - self.target = target - self.width = width - self.verbose = verbose - self.interval = interval - self.unit_name = unit_name - if stateful_metrics: - self.stateful_metrics = set(stateful_metrics) - else: - self.stateful_metrics = set() - - self._dynamic_display = ( - (hasattr(sys.stdout, "isatty") and sys.stdout.isatty()) - or "ipykernel" in sys.modules - or "posix" in sys.modules - or "PYCHARM_HOSTED" in os.environ - ) - self._total_width = 0 - self._seen_so_far = 0 - # We use a dict + list to avoid garbage collection - # issues found in OrderedDict - self._values = {} - self._values_order = [] - self._start = time.time() - self._last_update = 0 - self._time_at_epoch_start = self._start - self._time_at_epoch_end = None - self._time_after_first_step = None - - def update(self, current, values=None, finalize=None): - """Updates the progress bar. + """Displays a progress bar. Args: - current: Index of current step. - values: List of tuples: `(name, value_for_last_step)`. If `name` is - in `stateful_metrics`, `value_for_last_step` will be displayed - as-is. Else, an average of the metric over time will be - displayed. - finalize: Whether this is the last update for the progress bar. If - `None`, uses `current >= self.target`. Defaults to `None`. + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that should *not* + be averaged over time. Metrics in this list will be displayed as-is. 
+ All others will be averaged by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + unit_name: Display name for step counts (usually "step" or "sample"). """ - if finalize is None: - if self.target is None: - finalize = False - else: - finalize = current >= self.target - - values = values or [] - for k, v in values: - if k not in self._values_order: - self._values_order.append(k) - if k not in self.stateful_metrics: - # In the case that progress bar doesn't have a target value in - # the first epoch, both on_batch_end and on_epoch_end will be - # called, which will cause 'current' and 'self._seen_so_far' to - # have the same value. Force the minimal value to 1 here, - # otherwise stateful_metric will be 0s. - value_base = max(current - self._seen_so_far, 1) - if k not in self._values: - self._values[k] = [v * value_base, value_base] + + def __init__( + self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name="step", + ): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) else: - self._values[k][0] += v * value_base - self._values[k][1] += value_base - else: - # Stateful metrics output a numeric value. This representation - # means "take an average from a single value" but keeps the - # numeric formatting. 
- self._values[k] = [v, 1] - self._seen_so_far = current - - message = "" - now = time.time() - info = f" - {now - self._start:.0f}s" - if current == self.target: - self._time_at_epoch_end = now - if self.verbose == 1: - if now - self._last_update < self.interval and not finalize: - return - - prev_total_width = self._total_width - if self._dynamic_display: - message += "\b" * prev_total_width - message += "\r" - else: - message += "\n" - - if self.target is not None: - numdigits = int(np.log10(self.target)) + 1 - bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target) - prog = float(current) / self.target - prog_width = int(self.width * prog) - if prog_width > 0: - bar += "=" * (prog_width - 1) - if current < self.target: - bar += ">" - else: - bar += "=" - bar += "." * (self.width - prog_width) - bar += "]" - else: - bar = "%7d/Unknown" % current - - self._total_width = len(bar) - message += bar - - time_per_unit = self._estimate_step_duration(current, now) - - if self.target is None or finalize: - info += self._format_time(time_per_unit, self.unit_name) - else: - eta = time_per_unit * (self.target - current) - if eta > 3600: - eta_format = "%d:%02d:%02d" % ( - eta // 3600, - (eta % 3600) // 60, - eta % 60, - ) - elif eta > 60: - eta_format = "%d:%02d" % (eta // 60, eta % 60) + self.stateful_metrics = set() + + self._dynamic_display = ( + (hasattr(sys.stdout, "isatty") and sys.stdout.isatty()) + or "ipykernel" in sys.modules + or "posix" in sys.modules + or "PYCHARM_HOSTED" in os.environ + ) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + self._time_at_epoch_start = self._start + self._time_at_epoch_end = None + self._time_after_first_step = None + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. 
+ + Args: + current: Index of current step. + values: List of tuples: `(name, value_for_last_step)`. If `name` is + in `stateful_metrics`, `value_for_last_step` will be displayed + as-is. Else, an average of the metric over time will be + displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, uses `current >= self.target`. Defaults to `None`. + """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in + # the first epoch, both on_batch_end and on_epoch_end will be + # called, which will cause 'current' and 'self._seen_so_far' to + # have the same value. Force the minimal value to 1 here, + # otherwise stateful_metric will be 0s. + value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. 
+ self._values[k] = [v, 1] + self._seen_so_far = current + + message = "" + now = time.time() + info = f" - {now - self._start:.0f}s" + if current == self.target: + self._time_at_epoch_end = now + if self.verbose == 1: + if now - self._last_update < self.interval and not finalize: + return + + prev_total_width = self._total_width + if self._dynamic_display: + message += "\b" * prev_total_width + message += "\r" + else: + message += "\n" + + if self.target is not None: + numdigits = int(np.log10(self.target)) + 1 + bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target) + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += "=" * (prog_width - 1) + if current < self.target: + bar += ">" + else: + bar += "=" + bar += "." * (self.width - prog_width) + bar += "]" + else: + bar = "%7d/Unknown" % current + + self._total_width = len(bar) + message += bar + + time_per_unit = self._estimate_step_duration(current, now) + + if self.target is None or finalize: + info += self._format_time(time_per_unit, self.unit_name) + else: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = "%d:%02d:%02d" % ( + eta // 3600, + (eta % 3600) // 60, + eta % 60, + ) + elif eta > 60: + eta_format = "%d:%02d" % (eta // 60, eta % 60) + else: + eta_format = "%ds" % eta + + info = f" - ETA: {eta_format}" + + for k in self._values_order: + info += f" - {k}:" + if isinstance(self._values[k], list): + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if abs(avg) > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + else: + info += f" {self._values[k]}" + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += " " * (prev_total_width - self._total_width) + + if finalize: + info += "\n" + + message += info + print_msg(message, line_break=False) + message = "" + + elif self.verbose == 2: + if finalize: + numdigits = int(np.log10(self.target)) + 1 + count = 
("%" + str(numdigits) + "d/%d") % (current, self.target) + info = count + info + for k in self._values_order: + info += f" - {k}:" + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if avg > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + if self._time_at_epoch_end: + time_per_epoch = ( + self._time_at_epoch_end - self._time_at_epoch_start + ) + avg_time_per_step = time_per_epoch / self.target + self._time_at_epoch_start = now + self._time_at_epoch_end = None + info += " -" + self._format_time(time_per_epoch, "epoch") + info += " -" + self._format_time( + avg_time_per_step, self.unit_name + ) + info += "\n" + message += info + print_msg(message, line_break=False) + message = "" + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + def _format_time(self, time_per_unit, unit_name): + """format a given duration to display to the user. + + Given the duration, this function formats it in either milliseconds + or seconds and displays the unit (i.e. ms/step or s/epoch) + Args: + time_per_unit: the duration to display + unit_name: the name of the unit to display + Returns: + a string with the correctly formatted duration and units + """ + formatted = "" + if time_per_unit >= 1 or time_per_unit == 0: + formatted += f" {time_per_unit:.0f}s/{unit_name}" + elif time_per_unit >= 1e-3: + formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}" else: - eta_format = "%ds" % eta - - info = f" - ETA: {eta_format}" - - for k in self._values_order: - info += f" - {k}:" - if isinstance(self._values[k], list): - avg = np.mean( - self._values[k][0] / max(1, self._values[k][1]) - ) - if abs(avg) > 1e-3: - info += f" {avg:.4f}" - else: - info += f" {avg:.4e}" + formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}" + return formatted + + def _estimate_step_duration(self, current, now): + """Estimate the duration of a single step. 
+ + Given the step number `current` and the corresponding time `now` this + function returns an estimate for how long a single step takes. If this + is called before one step has been completed (i.e. `current == 0`) then + zero is given as an estimate. The duration estimate ignores the duration + of the (assumed to be non-representative) first step for estimates when + more steps are available (i.e. `current>1`). + + Args: + current: Index of current step. + now: The current time. + + Returns: Estimate of the duration of a single step. + """ + if current: + # there are a few special scenarios here: + # 1) somebody is calling the progress bar without ever supplying + # step 1 + # 2) somebody is calling the progress bar and supplies step one + # multiple times, e.g. as part of a finalizing call + # in these cases, we just fall back to the simple calculation + if self._time_after_first_step is not None and current > 1: + time_per_unit = (now - self._time_after_first_step) / ( + current - 1 + ) + else: + time_per_unit = (now - self._start) / current + + if current == 1: + self._time_after_first_step = now + return time_per_unit else: - info += f" {self._values[k]}" - - self._total_width += len(info) - if prev_total_width > self._total_width: - info += " " * (prev_total_width - self._total_width) - - if finalize: - info += "\n" - - message += info - print_msg(message, line_break=False) - message = "" - - elif self.verbose == 2: - if finalize: - numdigits = int(np.log10(self.target)) + 1 - count = ("%" + str(numdigits) + "d/%d") % (current, self.target) - info = count + info - for k in self._values_order: - info += f" - {k}:" - avg = np.mean( - self._values[k][0] / max(1, self._values[k][1]) - ) - if avg > 1e-3: - info += f" {avg:.4f}" - else: - info += f" {avg:.4e}" - if self._time_at_epoch_end: - time_per_epoch = ( - self._time_at_epoch_end - self._time_at_epoch_start - ) - avg_time_per_step = time_per_epoch / self.target - self._time_at_epoch_start = now - 
self._time_at_epoch_end = None - info += " -" + self._format_time(time_per_epoch, "epoch") - info += " -" + self._format_time( - avg_time_per_step, self.unit_name - ) - info += "\n" - message += info - print_msg(message, line_break=False) - message = "" + return 0 - self._last_update = now + def _update_stateful_metrics(self, stateful_metrics): + self.stateful_metrics = self.stateful_metrics.union(stateful_metrics) - def add(self, n, values=None): - self.update(self._seen_so_far + n, values) - def _format_time(self, time_per_unit, unit_name): - """format a given duration to display to the user. +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). - Given the duration, this function formats it in either milliseconds - or seconds and displays the unit (i.e. ms/step or s/epoch) Args: - time_per_unit: the duration to display - unit_name: the name of the unit to display + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + Returns: - a string with the correctly formatted duration and units + A list of tuples of array indices. """ - formatted = "" - if time_per_unit >= 1 or time_per_unit == 0: - formatted += f" {time_per_unit:.0f}s/{unit_name}" - elif time_per_unit >= 1e-3: - formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}" - else: - formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}" - return formatted + num_batches = int(np.ceil(size / float(batch_size))) + return [ + (i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches) + ] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. - def _estimate_step_duration(self, current, now): - """Estimate the duration of a single step. 
+ This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list - Given the step number `current` and the corresponding time `now` this - function returns an estimate for how long a single step takes. If this - is called before one step has been completed (i.e. `current == 0`) then - zero is given as an estimate. The duration estimate ignores the duration - of the (assumed to be non-representative) first step for estimates when - more steps are available (i.e. `current>1`). + Can also work on list/array of indices: `slice_arrays(x, indices)` Args: - current: Index of current step. - now: The current time. + arrays: Single array or list of arrays. + start: can be an integer index (start index) or a list/array of indices + stop: integer (stop index); should be None if `start` was a list. - Returns: Estimate of the duration of a single step. + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. """ - if current: - # there are a few special scenarios here: - # 1) somebody is calling the progress bar without ever supplying - # step 1 - # 2) somebody is calling the progress bar and supplies step one - # multiple times, e.g. as part of a finalizing call - # in these cases, we just fall back to the simple calculation - if self._time_after_first_step is not None and current > 1: - time_per_unit = (now - self._time_after_first_step) / ( - current - 1 + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError( + "The stop argument has to be None if the value of start " + f"is a list. 
Received start={start}, stop={stop}" ) - else: - time_per_unit = (now - self._start) / current - - if current == 1: - self._time_after_first_step = now - return time_per_unit + elif isinstance(arrays, list): + if hasattr(start, "__len__"): + # hdf5 datasets only support list objects as indices + if hasattr(start, "shape"): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + return [ + None + if x is None + else None + if not hasattr(x, "__getitem__") + else x[start:stop] + for x in arrays + ] else: - return 0 - - def _update_stateful_metrics(self, stateful_metrics): - self.stateful_metrics = self.stateful_metrics.union(stateful_metrics) - - -def make_batches(size, batch_size): - """Returns a list of batch indices (tuples of indices). - - Args: - size: Integer, total size of the data to slice into batches. - batch_size: Integer, batch size. - - Returns: - A list of tuples of array indices. - """ - num_batches = int(np.ceil(size / float(batch_size))) - return [ - (i * batch_size, min(size, (i + 1) * batch_size)) - for i in range(0, num_batches) - ] - - -def slice_arrays(arrays, start=None, stop=None): - """Slice an array or list of arrays. - - This takes an array-like, or a list of - array-likes, and outputs: - - arrays[start:stop] if `arrays` is an array-like - - [x[start:stop] for x in arrays] if `arrays` is a list - - Can also work on list/array of indices: `slice_arrays(x, indices)` - - Args: - arrays: Single array or list of arrays. - start: can be an integer index (start index) or a list/array of indices - stop: integer (stop index); should be None if `start` was a list. - - Returns: - A slice of the array(s). - - Raises: - ValueError: If the value of start is a list and stop is not None. - """ - if arrays is None: - return [None] - if isinstance(start, list) and stop is not None: - raise ValueError( - "The stop argument has to be None if the value of start " - f"is a list. 
Received start={start}, stop={stop}" - ) - elif isinstance(arrays, list): - if hasattr(start, "__len__"): - # hdf5 datasets only support list objects as indices - if hasattr(start, "shape"): - start = start.tolist() - return [None if x is None else x[start] for x in arrays] - return [ - None - if x is None - else None - if not hasattr(x, "__getitem__") - else x[start:stop] - for x in arrays - ] - else: - if hasattr(start, "__len__"): - if hasattr(start, "shape"): - start = start.tolist() - return arrays[start] - if hasattr(start, "__getitem__"): - return arrays[start:stop] - return [None] + if hasattr(start, "__len__"): + if hasattr(start, "shape"): + start = start.tolist() + return arrays[start] + if hasattr(start, "__getitem__"): + return arrays[start:stop] + return [None] def to_list(x): - """Normalizes a list/tensor into a list. + """Normalizes a list/tensor into a list. - If a tensor is passed, we return - a list of size 1 containing the tensor. + If a tensor is passed, we return + a list of size 1 containing the tensor. - Args: - x: target object to be normalized. + Args: + x: target object to be normalized. - Returns: - A list. - """ - if isinstance(x, list): - return x - return [x] + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] def to_snake_case(name): - intermediate = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) - insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower() - # If the class is private the name starts with "_" which is not secure - # for creating scopes. We prefix the name with "private" in this case. - if insecure[0] != "_": - return insecure - return "private" + insecure + intermediate = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower() + # If the class is private the name starts with "_" which is not secure + # for creating scopes. We prefix the name with "private" in this case. 
+ if insecure[0] != "_": + return insecure + return "private" + insecure def check_for_unexpected_keys(name, input_dict, expected_values): - unknown = set(input_dict.keys()).difference(expected_values) - if unknown: - raise ValueError( - f"Unknown entries in {name} dictionary: {list(unknown)}. " - f"Only expected following keys: {expected_values}" - ) + unknown = set(input_dict.keys()).difference(expected_values) + if unknown: + raise ValueError( + f"Unknown entries in {name} dictionary: {list(unknown)}. " + f"Only expected following keys: {expected_values}" + ) def validate_kwargs( kwargs, allowed_kwargs, error_message="Keyword argument not understood:" ): - """Checks that all keyword arguments are in the set of allowed keys.""" - for kwarg in kwargs: - if kwarg not in allowed_kwargs: - raise TypeError(error_message, kwarg) + """Checks that all keyword arguments are in the set of allowed keys.""" + for kwarg in kwargs: + if kwarg not in allowed_kwargs: + raise TypeError(error_message, kwarg) def default(method): - """Decorates a method to detect overrides in subclasses.""" - method._is_default = True - return method + """Decorates a method to detect overrides in subclasses.""" + method._is_default = True + return method def is_default(method): - """Check if a method is decorated with the `default` wrapper.""" - return getattr(method, "_is_default", False) + """Check if a method is decorated with the `default` wrapper.""" + return getattr(method, "_is_default", False) def populate_dict_with_module_objects(target_dict, modules, obj_filter): - for module in modules: - for name in dir(module): - obj = getattr(module, name) - if obj_filter(obj): - target_dict[name] = obj + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if obj_filter(obj): + target_dict[name] = obj class LazyLoader(python_types.ModuleType): - """Lazily import a module, mainly to avoid pulling in large dependencies.""" + """Lazily import a module, mainly to avoid 
pulling in large dependencies.""" - def __init__(self, local_name, parent_module_globals, name): - self._local_name = local_name - self._parent_module_globals = parent_module_globals - super().__init__(name) + def __init__(self, local_name, parent_module_globals, name): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + super().__init__(name) - def _load(self): - """Load the module and insert it into the parent's globals.""" - # Import the target module and insert it into the parent's namespace - module = importlib.import_module(self.__name__) - self._parent_module_globals[self._local_name] = module - # Update this object's dict so that if someone keeps a reference to the - # LazyLoader, lookups are efficient (__getattr__ is only called on - # lookups that fail). - self.__dict__.update(module.__dict__) - return module + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on + # lookups that fail). 
+ self.__dict__.update(module.__dict__) + return module - def __getattr__(self, item): - module = self._load() - return getattr(module, item) + def __getattr__(self, item): + module = self._load() + return getattr(module, item) def print_msg(message, line_break=True): - """Print the message to absl logging or stdout.""" - if line_break: - sys.stdout.write(message + "\n") - else: - sys.stdout.write(message) - sys.stdout.flush() + """Print the message to absl logging or stdout.""" + if line_break: + sys.stdout.write(message + "\n") + else: + sys.stdout.write(message) + sys.stdout.flush() diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py index 885a3d88d..be655c3b5 100644 --- a/brainpy/_src/tools/tests/test_functions.py +++ b/brainpy/_src/tools/tests/test_functions.py @@ -1,4 +1,3 @@ - import unittest import brainpy as bp @@ -6,19 +5,15 @@ class TestFunction(unittest.TestCase): - def test_compose(self): - f = lambda a: a + 1 - g = lambda a: a * 10 - fun1 = bp.tools.compose(f, g) - fun2 = bp.tools.pipe(g, f) - - arr = bm.random.randn(10) - r1 = fun1(arr) - r2 = fun2(arr) - groundtruth = f(g(arr)) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, groundtruth)) - - - - + def test_compose(self): + f = lambda a: a + 1 + g = lambda a: a * 10 + fun1 = bp.tools.compose(f, g) + fun2 = bp.tools.pipe(g, f) + + arr = bm.random.randn(10) + r1 = fun1(arr) + r2 = fun2(arr) + groundtruth = f(g(arr)) + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, groundtruth)) diff --git a/brainpy/_src/train/_utils.py b/brainpy/_src/train/_utils.py index c6052ba28..78ceb4870 100644 --- a/brainpy/_src/train/_utils.py +++ b/brainpy/_src/train/_utils.py @@ -3,11 +3,11 @@ import jax.numpy as jnp import brainpy.math as bm -from brainpy.check import is_dict_data from brainpy._src.dynsys import DynamicalSystem +from brainpy.check import is_dict_data __all__ = [ - 'format_ys' + 'format_ys' ] msg = ''' @@ -30,35 
+30,35 @@ def data(batch_size): def format_ys(cls, ys): - if isinstance(ys, (bm.BaseArray, jnp.ndarray)): - if len(cls.train_nodes) == 1: - ys = {cls.train_nodes[0].name: ys} - else: - raise ValueError(f'The network\n {cls.target} \nhas {len(cls.train_nodes)} ' - f'training nodes, while we only got one target data.') - is_dict_data(ys, key_type=str, val_type=(bm.BaseArray, jnp.ndarray)) + if isinstance(ys, (bm.BaseArray, jnp.ndarray)): + if len(cls.train_nodes) == 1: + ys = {cls.train_nodes[0].name: ys} + else: + raise ValueError(f'The network\n {cls.target} \nhas {len(cls.train_nodes)} ' + f'training nodes, while we only got one target data.') + is_dict_data(ys, key_type=str, val_type=(bm.BaseArray, jnp.ndarray)) - # check data path - abs_node_names = [node.name for node in cls.train_nodes] - formatted_ys = {} - ys_not_included = {} - for k, v in ys.items(): - if k in abs_node_names: - formatted_ys[k] = v - else: - ys_not_included[k] = v - if len(ys_not_included): - rel_nodes = cls.target.nodes('relative', level=-1, include_self=True).subset(DynamicalSystem).unique() - for k, v in ys_not_included.items(): - if k in rel_nodes: - formatted_ys[rel_nodes[k].name] = v - else: - raise ValueError(f'Unknown target "{k}" for fitting.') + # check data path + abs_node_names = [node.name for node in cls.train_nodes] + formatted_ys = {} + ys_not_included = {} + for k, v in ys.items(): + if k in abs_node_names: + formatted_ys[k] = v + else: + ys_not_included[k] = v + if len(ys_not_included): + rel_nodes = cls.target.nodes('relative', level=-1, include_self=True).subset(DynamicalSystem).unique() + for k, v in ys_not_included.items(): + if k in rel_nodes: + formatted_ys[rel_nodes[k].name] = v + else: + raise ValueError(f'Unknown target "{k}" for fitting.') - # check data shape - for key, val in formatted_ys.items(): - if val.ndim < 3: - raise ValueError("Targets must be a tensor with shape of " - "(batch, time, feature, ...) 
or (time, batch, feature, ...)" - f"but we got {val.shape}") - return formatted_ys + # check data shape + for key, val in formatted_ys.items(): + if val.ndim < 3: + raise ValueError("Targets must be a tensor with shape of " + "(batch, time, feature, ...) or (time, batch, feature, ...)" + f"but we got {val.shape}") + return formatted_ys diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py index c583bcf40..92c9aea5a 100644 --- a/brainpy/_src/train/back_propagation.py +++ b/brainpy/_src/train/back_propagation.py @@ -14,611 +14,611 @@ from brainpy import optim from brainpy import tools from brainpy._src.context import share -from brainpy._src.helpers import clear_input from brainpy._src.dynsys import DynamicalSystem +from brainpy._src.helpers import clear_input from brainpy._src.running import constants as c from brainpy.errors import UnsupportedError, NoLongerSupportError from brainpy.types import ArrayType, Output from ._utils import msg from .base import DSTrainer - __all__ = [ - 'BPTT', - 'BPFF', + 'BPTT', + 'BPFF', ] def _is_brainpy_array(s): - return isinstance(s, bm.BaseArray) + return isinstance(s, bm.BaseArray) class BPTrainer(DSTrainer): - """Trainer implementing back-propagation algorithm for supervised trasks. - - For more parameters, users should refer to :py:class:`~.DSRunner`. - - Parameters:: - - target: DynamicalSystem - The target model to train. - loss_fun: str, callable - The loss function. If it is a string, it should be the - function chosen from ``brainpy.losses`` module. Otherwise, - a callable function which receives argument of `(predicts, targets)` - should be provided. - loss_has_aux: bool - To indicate whether the `loss_fun` returns auxiliary data. - loss_auto_run: bool - pass - optimizer: optim.Optimizer - The optimizer used for training. - numpy_mon_after_run: bool - Make the monitored results as NumPy arrays. - logger: Any - A file-like object (stream); defaults to the current `sys.stdout`. 
- shuffle_data: bool - .. deprecated:: 2.2.4.1 - Control the data shuffling by user self. - seed: int - .. deprecated:: 2.2.4.1 - Control the data shuffling by user self. - - kwargs: Any - Other general parameters please see :py:class:`~.DSRunner`. - """ - - def __init__( - self, - target: DynamicalSystem, - loss_fun: Union[str, Callable], # loss function - optimizer: optim.Optimizer = None, # optimizer - loss_has_aux: bool = False, # loss auxiliary - loss_auto_run: bool = True, # loss auxiliary - - # ------------- - # API deprecated - seed: int = None, # deprecated - shuffle_data: bool = None, # deprecated - - **kwargs, - ): - super().__init__(target=target, **kwargs) - - if shuffle_data is not None: - raise NoLongerSupportError( - f''' - "shuffle_data" is no longer supported. ' - To be general, users should shuffle their data by themself. - - See https://github.com/brainpy/BrainPy/releases/tag/V2.3.1 - for the solution of how to fix this. - ''' - ) - if seed is not None: - NoLongerSupportError('"seed" is no longer supported. 
' - 'Please shuffle your data by yourself.') - - # jit settings - if isinstance(self._origin_jit, bool): - self.jit[c.PREDICT_PHASE] = self.jit.get(c.PREDICT_PHASE, self._origin_jit) - self.jit[c.LOSS_PHASE] = self.jit.get(c.LOSS_PHASE, self._origin_jit) - self.jit[c.FIT_PHASE] = self.jit.get(c.FIT_PHASE, self._origin_jit) - else: - self.jit[c.PREDICT_PHASE] = self._origin_jit.get(c.PREDICT_PHASE, True) - self.jit[c.LOSS_PHASE] = self._origin_jit.get(c.LOSS_PHASE, True) - self.jit[c.FIT_PHASE] = self._origin_jit.get(c.FIT_PHASE, True) - - # optimizer - if optimizer is None: - lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) - optimizer = optim.Adam(lr=lr) - self.optimizer: optim.Optimizer = optimizer - if len(self.optimizer.vars_to_train) == 0: - self.optimizer.register_train_vars(self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique()) - - # loss function - self.loss_has_aux = loss_has_aux - if isinstance(loss_fun, str): - loss_fun = getattr(losses, loss_fun) - elif callable(loss_fun): - loss_fun = loss_fun - else: - raise UnsupportedError(f'Do not support {type(loss_fun)} to specify the loss function. 
' - f'We only support str and callable function.') - self._loss_func = loss_fun - self.loss_auto_run = loss_auto_run - - # loss data - self._report_train_metrics = dict() - self._report_test_metrics = dict() - self._detailed_train_metrics = dict() - self._detailed_test_metrics = dict() - - # functions - self._jit_step_func_grad = bm.jit(self._step_func_grad, static_argnums=(0,)) - self._jit_step_func_loss = bm.jit(self._step_func_loss, static_argnums=(0,)) - self._jit_step_func_fit = bm.jit(self._step_func_fit, static_argnums=(0,)) - - def __repr__(self): - name = self.__class__.__name__ - prefix = ' ' * len(name) - return (f'{name}(target={self.target}, \n\t' - f'{prefix}jit={self.jit}, \n\t' - f'{prefix}loss={self._loss_func}, \n\t' - f'{prefix}optimizer={self.optimizer})') - - def get_hist_metric(self, phase='fit', metric='loss', which='report'): - """Get history losses.""" - assert phase in [c.FIT_PHASE, c.TEST_PHASE, c.TRAIN_PHASE, c.PREDICT_PHASE] - assert which in ['report', 'detailed'] - if phase in [c.FIT_PHASE, c.TRAIN_PHASE]: - if which == 'report': - return self._report_train_metrics.get(metric, None) - elif which == 'detailed': - return self._detailed_train_metrics.get(metric, None) - elif phase in [c.TEST_PHASE, c.PREDICT_PHASE]: - if which == 'report': - return self._report_test_metrics.get(metric, None) - elif which == 'detailed': - return self._detailed_test_metrics.get(metric, None) - - @property - def train_losses(self): - return self.get_hist_metric(phase='fit') - - @property - def test_losses(self): - return self.get_hist_metric(phase='test') - - def fit( - self, - train_data: Union[Callable, Iterable], - test_data: Optional[Union[Callable, Iterable]] = None, - num_epoch: int = 100, - num_report: int = -1, - reset_state: bool = True, - shared_args: Optional[Dict] = None, - fun_after_report: Optional[Callable] = None, - - # ------ - # API deprecated - batch_size: int = None, - ): - """Fit the target model according to the given training data. 
+ """Trainer implementing back-propagation algorithm for supervised trasks. - Parameters:: + For more parameters, users should refer to :py:class:`~.DSRunner`. - train_data: callable, iterable - It can be a callable function, or a tuple/list representing `(X, Y)` data. - - Callable. This function should return a pair of `(X, Y)` data. - - Iterable. It should be a pair of `(X, Y)` train set. - - ``X``: should be a tensor or a dict of tensors with the shape of - `(num_sample, num_time, ...)`, where `num_sample` is - the number of samples, `num_time` is the number of the time step, - and `num_feature` is the number of features. - - - ``Y``: Target values. A tensor or a dict of tensors. - - If the shape of each tensor is `(num_sample, num_feature)`, - then we will only fit the model with the only last output. - - If the shape of each tensor is `(num_sample, num_time, num_feature)`, - then the fitting happens on the whole data series. - test_data: callable, iterable, optional - Same as ``train_data``. - num_epoch: int - The number of training epoch. Default 100. - num_report: int - The number of step to report the progress. - If `num_report=-1`, it will report the training progress each epoch. - reset_state: bool - Whether reset the initial states of the target model. - shared_args: dict - The shared keyword arguments for the target models. - fun_after_report: optional, Callable - The function to call after each report of `fit` phase or `test` phase. - The function should receive three arguments: - - ``idx`` for the indicator the current the running index. (If ``report=-1``, - The running index is the epoch. Otherwise, is the 'fit_idx' for 'fit' phase - and 'test_idx' for 'test' phase). - - ``metrics``: the metrics defined in the loss function - - ``phase``: to indicate the phase of 'fit' or 'test'. - - .. versionadded:: 2.3.1 - batch_size: int + Parameters:: + target: DynamicalSystem + The target model to train. + loss_fun: str, callable + The loss function. 
If it is a string, it should be the + function chosen from ``brainpy.losses`` module. Otherwise, + a callable function which receives argument of `(predicts, targets)` + should be provided. + loss_has_aux: bool + To indicate whether the `loss_fun` returns auxiliary data. + loss_auto_run: bool + pass + optimizer: optim.Optimizer + The optimizer used for training. + numpy_mon_after_run: bool + Make the monitored results as NumPy arrays. + logger: Any + A file-like object (stream); defaults to the current `sys.stdout`. + shuffle_data: bool .. deprecated:: 2.2.4.1 - Please set batch size in your dataset. + Control the data shuffling by user self. + seed: int + .. deprecated:: 2.2.4.1 + Control the data shuffling by user self. + kwargs: Any + Other general parameters please see :py:class:`~.DSRunner`. """ - if shared_args is None: - shared_args = dict() - shared_args['fit'] = shared_args.get('fit', True) - shared_args = tools.DotDict(shared_args) - - if batch_size is not None: - raise NoLongerSupportError('Please set batch size in your data. 
' - 'Specifically, make an iterable dataset ' - 'which return a batch of (X, Y) data.') - if isinstance(train_data, (tuple, list)): - if len(train_data) == 2: - raise UnsupportedError(msg) - - if fun_after_report is not None: - assert callable(fun_after_report), ('\n' - 'Unknown "fun_after_report", ' - 'it should be a callable function receiving ' - 'three arguments: idx, metrics, phase') - - if shared_args is None: - shared_args = dict() - shared_args['fit'] = shared_args.get('fit', True) - - true_progress_bar = self.progress_bar - self.progress_bar = False - - # training the model - detailed_train_metric = dict() - report_train_metric = dict() - detailed_test_metric = dict() - report_test_metric = dict() - - fit_i, fit_t = 0, 0 - test_i, test_t = 0, 0 - for epoch_idx in range(num_epoch): - - # training set - fit_t0 = time.time() - fit_epoch_metric = dict(loss=[]) - _training_data = train_data() if callable(train_data) else train_data - if hasattr(_training_data, '__len__'): - bar = tqdm(total=len(_training_data)) - else: - bar = None - - for x, y in _training_data: - # reset state - if reset_state: - self.target.reset(self._get_input_batch_size(x)) - self.reset_state() - - # training - res = self.f_train(shared_args, x, y) - - # loss - fit_epoch_metric['loss'].append(res[0]) - if self.loss_has_aux: - if not isinstance(res[1], dict): - raise TypeError(f'Auxiliary data in loss function should be a dict. 
But we got {type(res)}') - for k, v in res[1].items(): - if k not in fit_epoch_metric: - fit_epoch_metric[k] = [] - fit_epoch_metric[k].append(v) - if bar is not None: - bar.update(1) - - # report - fit_i += 1 - if num_report > 0 and fit_i % num_report == 0: - fit_t1 = time.time() - aux = {} - for k, v in fit_epoch_metric.items(): - aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) - if k not in report_train_metric: - report_train_metric[k] = [] - detailed_train_metric[k] = [] - report_train_metric[k].append(aux[k]) - detailed_train_metric[k].extend(v) - v.clear() - _report = (f'Train {fit_i} steps, use {fit_t + fit_t1 - fit_t0:.4f} s' + - ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - if bar is not None: - bar.set_description(_report, refresh=True) - else: - print(_report) - if fun_after_report is not None: - fun_after_report(fit_i, aux, 'fit') - fit_t0 = time.time() - fit_t = 0 - - if num_report <= 0: - fit_t1 = time.time() - aux = {} - for k, v in fit_epoch_metric.items(): - aux[k] = np.mean(np.asarray(v)) - if k not in report_train_metric: - report_train_metric[k] = [] - detailed_train_metric[k] = [] - report_train_metric[k].append(aux[k]) - detailed_train_metric[k].extend(v) - v.clear() - _report = (f'Train {epoch_idx} epoch, use {fit_t1 - fit_t0:.4f} s' + - ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - if bar is not None: - bar.set_description(_report, refresh=True) + + def __init__( + self, + target: DynamicalSystem, + loss_fun: Union[str, Callable], # loss function + optimizer: optim.Optimizer = None, # optimizer + loss_has_aux: bool = False, # loss auxiliary + loss_auto_run: bool = True, # loss auxiliary + + # ------------- + # API deprecated + seed: int = None, # deprecated + shuffle_data: bool = None, # deprecated + + **kwargs, + ): + super().__init__(target=target, **kwargs) + + if shuffle_data is not None: + raise NoLongerSupportError( + f''' + "shuffle_data" is no longer supported. 
' + To be general, users should shuffle their data by themself. + + See https://github.com/brainpy/BrainPy/releases/tag/V2.3.1 + for the solution of how to fix this. + ''' + ) + if seed is not None: + NoLongerSupportError('"seed" is no longer supported. ' + 'Please shuffle your data by yourself.') + + # jit settings + if isinstance(self._origin_jit, bool): + self.jit[c.PREDICT_PHASE] = self.jit.get(c.PREDICT_PHASE, self._origin_jit) + self.jit[c.LOSS_PHASE] = self.jit.get(c.LOSS_PHASE, self._origin_jit) + self.jit[c.FIT_PHASE] = self.jit.get(c.FIT_PHASE, self._origin_jit) else: - print(_report) - if fun_after_report is not None: - fun_after_report(epoch_idx, aux, 'fit') - else: - fit_t = time.time() - fit_t0 - self.optimizer.lr.step_epoch() - if bar is not None: bar.close() - - # testing set - if test_data is not None: - test_t0 = time.time() - test_epoch_metric = dict(loss=[]) - _testing_data = test_data() if callable(test_data) else test_data - if hasattr(_testing_data, '__len__'): - bar = tqdm(total=len(_testing_data)) + self.jit[c.PREDICT_PHASE] = self._origin_jit.get(c.PREDICT_PHASE, True) + self.jit[c.LOSS_PHASE] = self._origin_jit.get(c.LOSS_PHASE, True) + self.jit[c.FIT_PHASE] = self._origin_jit.get(c.FIT_PHASE, True) + + # optimizer + if optimizer is None: + lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) + optimizer = optim.Adam(lr=lr) + self.optimizer: optim.Optimizer = optimizer + if len(self.optimizer.vars_to_train) == 0: + self.optimizer.register_train_vars( + self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique()) + + # loss function + self.loss_has_aux = loss_has_aux + if isinstance(loss_fun, str): + loss_fun = getattr(losses, loss_fun) + elif callable(loss_fun): + loss_fun = loss_fun else: - bar = None - for x, y in _testing_data: - # reset state - if reset_state: - self.target.reset(self._get_input_batch_size(x)) - self.reset_state() + raise UnsupportedError(f'Do not support {type(loss_fun)} to specify 
the loss function. ' + f'We only support str and callable function.') + self._loss_func = loss_fun + self.loss_auto_run = loss_auto_run + + # loss data + self._report_train_metrics = dict() + self._report_test_metrics = dict() + self._detailed_train_metrics = dict() + self._detailed_test_metrics = dict() + + # functions + self._jit_step_func_grad = bm.jit(self._step_func_grad, static_argnums=(0,)) + self._jit_step_func_loss = bm.jit(self._step_func_loss, static_argnums=(0,)) + self._jit_step_func_fit = bm.jit(self._step_func_fit, static_argnums=(0,)) + + def __repr__(self): + name = self.__class__.__name__ + prefix = ' ' * len(name) + return (f'{name}(target={self.target}, \n\t' + f'{prefix}jit={self.jit}, \n\t' + f'{prefix}loss={self._loss_func}, \n\t' + f'{prefix}optimizer={self.optimizer})') + + def get_hist_metric(self, phase='fit', metric='loss', which='report'): + """Get history losses.""" + assert phase in [c.FIT_PHASE, c.TEST_PHASE, c.TRAIN_PHASE, c.PREDICT_PHASE] + assert which in ['report', 'detailed'] + if phase in [c.FIT_PHASE, c.TRAIN_PHASE]: + if which == 'report': + return self._report_train_metrics.get(metric, None) + elif which == 'detailed': + return self._detailed_train_metrics.get(metric, None) + elif phase in [c.TEST_PHASE, c.PREDICT_PHASE]: + if which == 'report': + return self._report_test_metrics.get(metric, None) + elif which == 'detailed': + return self._detailed_test_metrics.get(metric, None) + + @property + def train_losses(self): + return self.get_hist_metric(phase='fit') + + @property + def test_losses(self): + return self.get_hist_metric(phase='test') + + def fit( + self, + train_data: Union[Callable, Iterable], + test_data: Optional[Union[Callable, Iterable]] = None, + num_epoch: int = 100, + num_report: int = -1, + reset_state: bool = True, + shared_args: Optional[Dict] = None, + fun_after_report: Optional[Callable] = None, + + # ------ + # API deprecated + batch_size: int = None, + ): + """Fit the target model according to the 
given training data. + + Parameters:: + + train_data: callable, iterable + It can be a callable function, or a tuple/list representing `(X, Y)` data. + - Callable. This function should return a pair of `(X, Y)` data. + - Iterable. It should be a pair of `(X, Y)` train set. + - ``X``: should be a tensor or a dict of tensors with the shape of + `(num_sample, num_time, ...)`, where `num_sample` is + the number of samples, `num_time` is the number of the time step, + and `num_feature` is the number of features. + + - ``Y``: Target values. A tensor or a dict of tensors. + - If the shape of each tensor is `(num_sample, num_feature)`, + then we will only fit the model with the only last output. + - If the shape of each tensor is `(num_sample, num_time, num_feature)`, + then the fitting happens on the whole data series. + test_data: callable, iterable, optional + Same as ``train_data``. + num_epoch: int + The number of training epoch. Default 100. + num_report: int + The number of step to report the progress. + If `num_report=-1`, it will report the training progress each epoch. + reset_state: bool + Whether reset the initial states of the target model. + shared_args: dict + The shared keyword arguments for the target models. + fun_after_report: optional, Callable + The function to call after each report of `fit` phase or `test` phase. + The function should receive three arguments: + - ``idx`` for the indicator the current the running index. (If ``report=-1``, + The running index is the epoch. Otherwise, is the 'fit_idx' for 'fit' phase + and 'test_idx' for 'test' phase). + - ``metrics``: the metrics defined in the loss function + - ``phase``: to indicate the phase of 'fit' or 'test'. + + .. versionadded:: 2.3.1 + batch_size: int + + .. deprecated:: 2.2.4.1 + Please set batch size in your dataset. 
+ + """ + if shared_args is None: + shared_args = dict() + shared_args['fit'] = shared_args.get('fit', True) + shared_args = tools.DotDict(shared_args) + + if batch_size is not None: + raise NoLongerSupportError('Please set batch size in your data. ' + 'Specifically, make an iterable dataset ' + 'which return a batch of (X, Y) data.') + if isinstance(train_data, (tuple, list)): + if len(train_data) == 2: + raise UnsupportedError(msg) - # testing - res = self.f_loss(shared_args, x, y) - - # loss - if self.loss_has_aux: - test_epoch_metric['loss'].append(res[0]) - if not isinstance(res[1], dict): - raise TypeError(f'Auxiliary data in loss function should be a dict. But we got {type(res)}') - for k, v in res[1].items(): - if k not in test_epoch_metric: - test_epoch_metric[k] = [] - test_epoch_metric[k].append(v) - else: - test_epoch_metric['loss'].append(res) - - if bar is not None: bar.update(1) - - # report - test_i += 1 - if num_report > 0 and test_i % num_report == 0: - test_t1 = time.time() - aux = {} - for k, v in test_epoch_metric.items(): - aux[k] = np.mean(np.asarray(v)) - if k not in report_test_metric: - report_test_metric[k] = [] - detailed_test_metric[k] = [] - report_test_metric[k].append(aux[k]) - detailed_test_metric[k].extend(v) - v.clear() - _report = (f'Test {test_i} steps, use {test_t + test_t1 - test_t0:.4f} s' + - ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - if bar is not None: - bar.set_description(_report, refresh=True) + if fun_after_report is not None: + assert callable(fun_after_report), ('\n' + 'Unknown "fun_after_report", ' + 'it should be a callable function receiving ' + 'three arguments: idx, metrics, phase') + + if shared_args is None: + shared_args = dict() + shared_args['fit'] = shared_args.get('fit', True) + + true_progress_bar = self.progress_bar + self.progress_bar = False + + # training the model + detailed_train_metric = dict() + report_train_metric = dict() + detailed_test_metric = dict() + 
report_test_metric = dict() + + fit_i, fit_t = 0, 0 + test_i, test_t = 0, 0 + for epoch_idx in range(num_epoch): + + # training set + fit_t0 = time.time() + fit_epoch_metric = dict(loss=[]) + _training_data = train_data() if callable(train_data) else train_data + if hasattr(_training_data, '__len__'): + bar = tqdm(total=len(_training_data)) else: - print(_report) - if fun_after_report is not None: - fun_after_report(test_i, aux, 'test') - test_t0 = time.time() - test_t = 0 - - if num_report <= 0: - test_t1 = time.time() - aux = {} - for k, v in test_epoch_metric.items(): - aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) - if k not in report_test_metric: - report_test_metric[k] = [] - detailed_test_metric[k] = [] - report_test_metric[k].append(aux[k]) - detailed_test_metric[k].extend(v) - v.clear() - _report = (f'Test {epoch_idx} epoch, use {test_t1 - test_t0:.4f} s' + - ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - if bar is not None: - bar.set_description(_report, refresh=True) - else: - print(_report) - if fun_after_report is not None: - fun_after_report(epoch_idx, aux, 'test') - else: - test_t = time.time() - test_t0 - - if bar is not None: bar.close() - - # finally - self._report_train_metrics = {k: np.asarray(v) for k, v in report_train_metric.items()} - self._detailed_train_metrics = {k: np.asarray(v) for k, v in detailed_train_metric.items()} - self._report_test_metrics = {k: np.asarray(v) for k, v in report_test_metric.items()} - self._detailed_test_metrics = {k: np.asarray(v) for k, v in detailed_test_metric.items()} - self.progress_bar = true_progress_bar - - def _step_func_grad(self, shared_args, inputs, targets): - tran_vars = self.target.train_vars().unique() - grad_f = bm.grad(self._step_func_loss, - grad_vars=tran_vars, - return_value=True, - has_aux=self.loss_has_aux) - return grad_f(shared_args, inputs, targets) - - def _step_func_loss(self, shared_args, inputs, targets): - raise NotImplementedError - - @property - def f_loss(self): 
- return self._jit_step_func_loss if self.jit[c.LOSS_PHASE] else self._step_func_loss - - def _step_func_fit(self, shared_args, inputs, targets): - raise NotImplementedError - - @property - def f_train(self): - return self._jit_step_func_fit if self.jit[c.FIT_PHASE] else self._step_func_fit - - @property - def f_grad(self): - return self._jit_step_func_grad if self.jit[c.FIT_PHASE] else self._step_func_grad + bar = None + + for x, y in _training_data: + # reset state + if reset_state: + self.target.reset(self._get_input_batch_size(x)) + self.reset_state() + + # training + res = self.f_train(shared_args, x, y) + + # loss + fit_epoch_metric['loss'].append(res[0]) + if self.loss_has_aux: + if not isinstance(res[1], dict): + raise TypeError(f'Auxiliary data in loss function should be a dict. But we got {type(res)}') + for k, v in res[1].items(): + if k not in fit_epoch_metric: + fit_epoch_metric[k] = [] + fit_epoch_metric[k].append(v) + if bar is not None: + bar.update(1) + + # report + fit_i += 1 + if num_report > 0 and fit_i % num_report == 0: + fit_t1 = time.time() + aux = {} + for k, v in fit_epoch_metric.items(): + aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) + if k not in report_train_metric: + report_train_metric[k] = [] + detailed_train_metric[k] = [] + report_train_metric[k].append(aux[k]) + detailed_train_metric[k].extend(v) + v.clear() + _report = (f'Train {fit_i} steps, use {fit_t + fit_t1 - fit_t0:.4f} s' + + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) + if fun_after_report is not None: + fun_after_report(fit_i, aux, 'fit') + fit_t0 = time.time() + fit_t = 0 + + if num_report <= 0: + fit_t1 = time.time() + aux = {} + for k, v in fit_epoch_metric.items(): + aux[k] = np.mean(np.asarray(v)) + if k not in report_train_metric: + report_train_metric[k] = [] + detailed_train_metric[k] = [] + report_train_metric[k].append(aux[k]) + 
detailed_train_metric[k].extend(v) + v.clear() + _report = (f'Train {epoch_idx} epoch, use {fit_t1 - fit_t0:.4f} s' + + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) + if fun_after_report is not None: + fun_after_report(epoch_idx, aux, 'fit') + else: + fit_t = time.time() - fit_t0 + self.optimizer.lr.step_epoch() + if bar is not None: bar.close() + + # testing set + if test_data is not None: + test_t0 = time.time() + test_epoch_metric = dict(loss=[]) + _testing_data = test_data() if callable(test_data) else test_data + if hasattr(_testing_data, '__len__'): + bar = tqdm(total=len(_testing_data)) + else: + bar = None + for x, y in _testing_data: + # reset state + if reset_state: + self.target.reset(self._get_input_batch_size(x)) + self.reset_state() + + # testing + res = self.f_loss(shared_args, x, y) + + # loss + if self.loss_has_aux: + test_epoch_metric['loss'].append(res[0]) + if not isinstance(res[1], dict): + raise TypeError(f'Auxiliary data in loss function should be a dict. 
But we got {type(res)}') + for k, v in res[1].items(): + if k not in test_epoch_metric: + test_epoch_metric[k] = [] + test_epoch_metric[k].append(v) + else: + test_epoch_metric['loss'].append(res) + + if bar is not None: bar.update(1) + + # report + test_i += 1 + if num_report > 0 and test_i % num_report == 0: + test_t1 = time.time() + aux = {} + for k, v in test_epoch_metric.items(): + aux[k] = np.mean(np.asarray(v)) + if k not in report_test_metric: + report_test_metric[k] = [] + detailed_test_metric[k] = [] + report_test_metric[k].append(aux[k]) + detailed_test_metric[k].extend(v) + v.clear() + _report = (f'Test {test_i} steps, use {test_t + test_t1 - test_t0:.4f} s' + + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) + if fun_after_report is not None: + fun_after_report(test_i, aux, 'test') + test_t0 = time.time() + test_t = 0 + + if num_report <= 0: + test_t1 = time.time() + aux = {} + for k, v in test_epoch_metric.items(): + aux[k] = jnp.mean(bm.as_jax(bm.asarray(v))) + if k not in report_test_metric: + report_test_metric[k] = [] + detailed_test_metric[k] = [] + report_test_metric[k].append(aux[k]) + detailed_test_metric[k].extend(v) + v.clear() + _report = (f'Test {epoch_idx} epoch, use {test_t1 - test_t0:.4f} s' + + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) + if fun_after_report is not None: + fun_after_report(epoch_idx, aux, 'test') + else: + test_t = time.time() - test_t0 + + if bar is not None: bar.close() + + # finally + self._report_train_metrics = {k: np.asarray(v) for k, v in report_train_metric.items()} + self._detailed_train_metrics = {k: np.asarray(v) for k, v in detailed_train_metric.items()} + self._report_test_metrics = {k: np.asarray(v) for k, v in report_test_metric.items()} + self._detailed_test_metrics = {k: np.asarray(v) 
for k, v in detailed_test_metric.items()} + self.progress_bar = true_progress_bar + + def _step_func_grad(self, shared_args, inputs, targets): + tran_vars = self.target.train_vars().unique() + grad_f = bm.grad(self._step_func_loss, + grad_vars=tran_vars, + return_value=True, + has_aux=self.loss_has_aux) + return grad_f(shared_args, inputs, targets) + + def _step_func_loss(self, shared_args, inputs, targets): + raise NotImplementedError + + @property + def f_loss(self): + return self._jit_step_func_loss if self.jit[c.LOSS_PHASE] else self._step_func_loss + + def _step_func_fit(self, shared_args, inputs, targets): + raise NotImplementedError + + @property + def f_train(self): + return self._jit_step_func_fit if self.jit[c.FIT_PHASE] else self._step_func_fit + + @property + def f_grad(self): + return self._jit_step_func_grad if self.jit[c.FIT_PHASE] else self._step_func_grad class BPTT(BPTrainer): - """The trainer implementing the back-propagation through time (BPTT) - algorithm for training dyamical systems. - - For more parameters, users should refer to :py:class:`~.DSRunner`. + """The trainer implementing the back-propagation through time (BPTT) + algorithm for training dyamical systems. - Parameters:: + For more parameters, users should refer to :py:class:`~.DSRunner`. - target: DynamicalSystem - The target model to train. + Parameters:: - loss_fun: str, callable - The loss function. + target: DynamicalSystem + The target model to train. + + loss_fun: str, callable + The loss function. + + - If it is a string, it should be the function chosen from ``brainpy.losses`` module. + - Otherwise, a callable function which receives argument of ``(predicts, targets)`` + should be provided. + + .. note:: + If ``monitors`` has been set in the trainer, the ``predicts`` contains two + parts: the network history prediction outputs, and the monitored values. + + see BrainPy examples for more information. 
+ loss_has_aux: bool + To indicate whether the loss function returns auxiliary data expect the loss. + Moreover, all auxiliary data should be a dict, whose key is used for logging + item name and its data is used for the corresponding value. + For example, + + .. code-block:: python + + def loss_fun(predicts, targets): + return loss, {'acc': acc, 'spike_num': spike_num} + optimizer: Optimizer + The optimizer used for training. Should be an instance of :py:class:`~.Optimizer`. + numpy_mon_after_run: bool + Make the monitored results as NumPy arrays. + logger: Any + A file-like object (stream). Used to output the running results. Default is the current `sys.stdout`. + data_first_axis: str + To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the + time length (``data_first_axis='T'``). + """ - - If it is a string, it should be the function chosen from ``brainpy.losses`` module. - - Otherwise, a callable function which receives argument of ``(predicts, targets)`` - should be provided. + def _step_func_loss(self, shared_args, inputs, targets): + num_step = self._get_input_time_step(xs=inputs) + indices = np.arange(self.i0, self.i0 + num_step, dtype=np.int_) + if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B': + inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), inputs, is_leaf=lambda x: isinstance(x, bm.BaseArray)) + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + outs, mons = self._predict(indices, *inputs, shared_args=shared_args) + predicts = (outs, mons) if len(mons) > 0 else outs + return self._loss_func(predicts, targets) - .. note:: - If ``monitors`` has been set in the trainer, the ``predicts`` contains two - parts: the network history prediction outputs, and the monitored values. - - see BrainPy examples for more information. - loss_has_aux: bool - To indicate whether the loss function returns auxiliary data expect the loss. 
- Moreover, all auxiliary data should be a dict, whose key is used for logging - item name and its data is used for the corresponding value. - For example, - - .. code-block:: python - - def loss_fun(predicts, targets): - return loss, {'acc': acc, 'spike_num': spike_num} - optimizer: Optimizer - The optimizer used for training. Should be an instance of :py:class:`~.Optimizer`. - numpy_mon_after_run: bool - Make the monitored results as NumPy arrays. - logger: Any - A file-like object (stream). Used to output the running results. Default is the current `sys.stdout`. - data_first_axis: str - To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the - time length (``data_first_axis='T'``). - """ - - def _step_func_loss(self, shared_args, inputs, targets): - num_step = self._get_input_time_step(xs=inputs) - indices = np.arange(self.i0, self.i0 + num_step, dtype=np.int_) - if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B': - inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), inputs, is_leaf=lambda x: isinstance(x, bm.BaseArray)) - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - outs, mons = self._predict(indices, *inputs, shared_args=shared_args) - predicts = (outs, mons) if len(mons) > 0 else outs - return self._loss_func(predicts, targets) - - def _step_func_fit(self, shared_args, inputs, targets): - res = self.f_grad(shared_args, inputs, targets) - self.optimizer.update(res[0]) - return res[1:] + def _step_func_fit(self, shared_args, inputs, targets): + res = self.f_grad(shared_args, inputs, targets) + self.optimizer.update(res[0]) + return res[1:] class BPFF(BPTrainer): - """ - The trainer implementing back propagation algorithm - for feedforward neural networks. - - For more parameters, users should refer to :py:class:`~.DSRunner`. 
- - """ - - def _step_func_loss(self, shared_args, inputs, targets): - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - outputs, mon = self._step_func_predict(*inputs, shared_args=shared_args) - outs = (outputs, mon) if len(mon) > 0 else outputs - loss = self._loss_func(outs, targets) - return loss - - def _step_func_fit(self, shared_args, inputs, targets): - res = self.f_grad(shared_args, inputs, targets) - self.optimizer.update(res[0]) - return res[1:] - - def _step_func_predict(self, *x, shared_args=None): - assert self.data_first_axis == 'B', (f'There is no time dimension when ' - f'using the trainer {self.__class__.__name__}.') - if shared_args is not None: - assert isinstance(shared_args, dict) - share.save(**shared_args) - share.save(dt=self.dt) - - # input step - clear_input(self.target) - self._step_func_input() - - # dynamics update step - out = self.target(*x) - - # monitor step - mon = self._step_func_monitor() - # share.clear_shargs() - return out, mon - - def _fun_predict(self, *inputs, shared_args=None): - if self.jit['predict']: - return self._jit_step_func_predict(*inputs, shared_args=shared_args) - else: - return self._step_func_predict(*inputs, shared_args=shared_args) - - def predict( - self, - inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], - reset_state: bool = True, - shared_args: Dict = None, - eval_time: bool = False - ) -> Output: - """Predict a series of input data with the given target model. - - This function use the JIT compilation to accelerate the model simulation. - Moreover, it can automatically monitor the node variables, states, inputs, - feedbacks and its output. - - Parameters:: - - inputs: ArrayType, dict - The feedforward input data. It must be a 3-dimensional data - which has the shape of `(num_sample, num_time, num_feature)`. - reset_state: bool - Whether reset the model states. - shared_args: optional, dict - The shared arguments across different layers. 
- eval_time: bool - Evaluate the time used for running. + """ + The trainer implementing back propagation algorithm + for feedforward neural networks. - Returns:: + For more parameters, users should refer to :py:class:`~.DSRunner`. - output: ArrayType, dict - The model output. """ - if shared_args is None: - shared_args = dict() - shared_args['fit'] = shared_args.get('fit', False) - shared_args = tools.DotDict(shared_args) - - # reset the model states - if reset_state: - self.target.reset(self._get_input_batch_size(xs=inputs)) - self.reset_state() - # init monitor - for key in self._monitors.keys(): - self.mon[key] = [] # reshape the monitor items - # prediction - if not isinstance(inputs, (tuple, list)): - inputs = (inputs,) - if eval_time: t0 = time.time() - outs, hists = self._fun_predict(*inputs, shared_args=shared_args) - if eval_time: t1 = time.time() - # post-running for monitors - for key in hists.keys(): - self.mon[key] = bm.asarray(hists[key]) - if self.numpy_mon_after_run: - for key in hists.keys(): - self.mon[key] = np.asarray(self.mon[key]) - return (t1 - t0, outs) if eval_time else outs + + def _step_func_loss(self, shared_args, inputs, targets): + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + outputs, mon = self._step_func_predict(*inputs, shared_args=shared_args) + outs = (outputs, mon) if len(mon) > 0 else outputs + loss = self._loss_func(outs, targets) + return loss + + def _step_func_fit(self, shared_args, inputs, targets): + res = self.f_grad(shared_args, inputs, targets) + self.optimizer.update(res[0]) + return res[1:] + + def _step_func_predict(self, *x, shared_args=None): + assert self.data_first_axis == 'B', (f'There is no time dimension when ' + f'using the trainer {self.__class__.__name__}.') + if shared_args is not None: + assert isinstance(shared_args, dict) + share.save(**shared_args) + share.save(dt=self.dt) + + # input step + clear_input(self.target) + self._step_func_input() + + # dynamics update step + out = 
self.target(*x) + + # monitor step + mon = self._step_func_monitor() + # share.clear_shargs() + return out, mon + + def _fun_predict(self, *inputs, shared_args=None): + if self.jit['predict']: + return self._jit_step_func_predict(*inputs, shared_args=shared_args) + else: + return self._step_func_predict(*inputs, shared_args=shared_args) + + def predict( + self, + inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], + reset_state: bool = True, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Predict a series of input data with the given target model. + + This function use the JIT compilation to accelerate the model simulation. + Moreover, it can automatically monitor the node variables, states, inputs, + feedbacks and its output. + + Parameters:: + + inputs: ArrayType, dict + The feedforward input data. It must be a 3-dimensional data + which has the shape of `(num_sample, num_time, num_feature)`. + reset_state: bool + Whether reset the model states. + shared_args: optional, dict + The shared arguments across different layers. + eval_time: bool + Evaluate the time used for running. + + Returns:: + + output: ArrayType, dict + The model output. 
+ """ + if shared_args is None: + shared_args = dict() + shared_args['fit'] = shared_args.get('fit', False) + shared_args = tools.DotDict(shared_args) + + # reset the model states + if reset_state: + self.target.reset(self._get_input_batch_size(xs=inputs)) + self.reset_state() + # init monitor + for key in self._monitors.keys(): + self.mon[key] = [] # reshape the monitor items + # prediction + if not isinstance(inputs, (tuple, list)): + inputs = (inputs,) + if eval_time: t0 = time.time() + outs, hists = self._fun_predict(*inputs, shared_args=shared_args) + if eval_time: t1 = time.time() + # post-running for monitors + for key in hists.keys(): + self.mon[key] = bm.asarray(hists[key]) + if self.numpy_mon_after_run: + for key in hists.keys(): + self.mon[key] = np.asarray(self.mon[key]) + return (t1 - t0, outs) if eval_time else outs diff --git a/brainpy/_src/train/base.py b/brainpy/_src/train/base.py index a660b9792..e6e79a886 100644 --- a/brainpy/_src/train/base.py +++ b/brainpy/_src/train/base.py @@ -10,40 +10,40 @@ from brainpy.types import Output __all__ = [ - 'DSTrainer', + 'DSTrainer', ] class DSTrainer(DSRunner): - """Structural Trainer for Dynamical Systems. + """Structural Trainer for Dynamical Systems. - For more parameters, users should refer to :py:class:`~.DSRunner`. + For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters:: - target: DynamicalSystem - The training target. + target: DynamicalSystem + The training target. - kwargs: Any - Other general parameters in :py:class:`~.DSRunner`. + kwargs: Any + Other general parameters in :py:class:`~.DSRunner`. 
- """ + """ - target: DynamicalSystem - '''The training target.''' + target: DynamicalSystem + '''The training target.''' - train_nodes: Sequence[DynamicalSystem] # need to be initialized by subclass - '''All children nodes in this training target.''' + train_nodes: Sequence[DynamicalSystem] # need to be initialized by subclass + '''All children nodes in this training target.''' - def __init__( - self, - target: DynamicalSystem, - **kwargs - ): - super().__init__(target=target, **kwargs) + def __init__( + self, + target: DynamicalSystem, + **kwargs + ): + super().__init__(target=target, **kwargs) - if not isinstance(self.target.mode, bm.BatchingMode): - raise NoLongerSupportError(f''' + if not isinstance(self.target.mode, bm.BatchingMode): + raise NoLongerSupportError(f''' From version 2.3.1, DSTrainer must receive a DynamicalSystem instance with the computing mode of {bm.batching_mode} or {bm.training_mode}. @@ -51,51 +51,51 @@ def __init__( for the solution of how to fix this. ''') - # jit - if isinstance(self._origin_jit, bool): - self.jit[c.PREDICT_PHASE] = self._origin_jit - self.jit[c.FIT_PHASE] = self._origin_jit - else: - self.jit[c.PREDICT_PHASE] = self._origin_jit.get(c.PREDICT_PHASE, True) - self.jit[c.FIT_PHASE] = self._origin_jit.get(c.FIT_PHASE, True) - - def predict( - self, - inputs: Any, - reset_state: bool = False, - shared_args: Optional[Dict] = None, - eval_time: bool = False - ) -> Output: - """Prediction function. - - Parameters:: - - inputs: ArrayType, sequence of ArrayType, dict of ArrayType - The input values. - reset_state: bool - Reset the target state before running. - eval_time: bool - Whether we evaluate the running time or not? - shared_args: dict - The shared arguments across nodes. - - Returns:: - - output: ArrayType, sequence of ArrayType, dict of ArrayType - The running output. 
- """ - if shared_args is None: - shared_args = dict() - shared_args['fit'] = shared_args.get('fit', False) - return super().predict(inputs=inputs, - reset_state=reset_state, - shared_args=shared_args, - eval_time=eval_time) - - def fit( - self, - train_data: Any, - reset_state: bool = False, - shared_args: Dict = None - ) -> Output: # need to be implemented by subclass - raise NotImplementedError('Must implement the fit function. ') + # jit + if isinstance(self._origin_jit, bool): + self.jit[c.PREDICT_PHASE] = self._origin_jit + self.jit[c.FIT_PHASE] = self._origin_jit + else: + self.jit[c.PREDICT_PHASE] = self._origin_jit.get(c.PREDICT_PHASE, True) + self.jit[c.FIT_PHASE] = self._origin_jit.get(c.FIT_PHASE, True) + + def predict( + self, + inputs: Any, + reset_state: bool = False, + shared_args: Optional[Dict] = None, + eval_time: bool = False + ) -> Output: + """Prediction function. + + Parameters:: + + inputs: ArrayType, sequence of ArrayType, dict of ArrayType + The input values. + reset_state: bool + Reset the target state before running. + eval_time: bool + Whether we evaluate the running time or not? + shared_args: dict + The shared arguments across nodes. + + Returns:: + + output: ArrayType, sequence of ArrayType, dict of ArrayType + The running output. + """ + if shared_args is None: + shared_args = dict() + shared_args['fit'] = shared_args.get('fit', False) + return super().predict(inputs=inputs, + reset_state=reset_state, + shared_args=shared_args, + eval_time=eval_time) + + def fit( + self, + train_data: Any, + reset_state: bool = False, + shared_args: Dict = None + ) -> Output: # need to be implemented by subclass + raise NotImplementedError('Must implement the fit function. 
') diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index 1fa8b423e..dab014e4a 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -18,234 +18,234 @@ from .base import DSTrainer __all__ = [ - 'OfflineTrainer', - 'RidgeTrainer', + 'OfflineTrainer', + 'RidgeTrainer', ] class OfflineTrainer(DSTrainer): - """Offline trainer for models with recurrent dynamics. - - For more parameters, users should refer to :py:class:`~.DSRunner`. - - Parameters:: - - target: DynamicalSystem - The target model to train. - fit_method: OfflineAlgorithm, Callable, dict, str - The fitting method applied to the target model. - - It can be a string, which specify the shortcut name of the training algorithm. - Like, ``fit_method='ridge'`` means using the Ridge regression method. - All supported fitting methods can be obtained through - :py:func:`~get_supported_offline_methods`. - - It can be a dict, whose "name" item specifies the name of the training algorithm, - and the others parameters specify the initialization parameters of the algorithm. - For example, ``fit_method={'name': 'ridge', 'alpha': 0.1}``. - - It can be an instance of :py:class:`brainpy.algorithms.OfflineAlgorithm`. - For example, ``fit_meth=bp.algorithms.RidgeRegression(alpha=0.1)``. - - It can also be a callable function, which receives three arguments "targets", "x" and "y". - For example, ``fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0]``. - - kwargs: Any - Other general parameters please see :py:class:`~.DSRunner`. 
- """ - - def __init__( - self, - target: DynamicalSystem, - fit_method: Union[OfflineAlgorithm, Callable, Dict, str] = None, - **kwargs - ): - self._true_numpy_mon_after_run = kwargs.get('numpy_mon_after_run', True) - kwargs['numpy_mon_after_run'] = False - super().__init__(target=target, **kwargs) - - # get all trainable nodes - nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() - self.train_nodes = tuple([node for node in nodes.values() if isinstance(node.mode, bm.TrainingMode)]) - if len(self.train_nodes) == 0: - raise ValueError('Found no trainable nodes.') - - # check the required interface in the trainable nodes - self._check_interface() - - # training method - if fit_method is None: - fit_method = RidgeRegression(alpha=1e-7) - elif isinstance(fit_method, str): - fit_method = get(fit_method)() - elif isinstance(fit_method, dict): - name = fit_method.pop('name') - fit_method = get(name)(**fit_method) - if not callable(fit_method): - raise ValueError(f'"train_method" must be an instance of callable function, ' - f'but we got {type(fit_method)}.') - self.fit_method = fit_method - - # set the training method - for node in self.train_nodes: - node.offline_fit_by = fit_method - # training function - self._jit_fun_train = bm.jit(self._fun_train) - - def __repr__(self): - name = self.__class__.__name__ - prefix = ' ' * len(name) - return (f'{name}(target={self.target}, \n\t' - f'{prefix}fit_method={self.fit_method})') - - def predict( - self, - inputs: Any, - reset_state: bool = False, - shared_args: Dict = None, - eval_time: bool = False - ) -> Output: - """Prediction function. - - What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that - the `inputs_are_batching` is default `True`. + """Offline trainer for models with recurrent dynamics. - Parameters:: - - inputs: ArrayType - The input values. - reset_state: bool - Reset the target state before running. 
- eval_time: bool - Whether we evaluate the running time or not? - shared_args: dict - The shared arguments across nodes. - - Returns:: - - output: ArrayType - The running output. - """ - outs = super().predict(inputs=inputs, reset_state=reset_state, - eval_time=eval_time, shared_args=shared_args) - for node in self.train_nodes: - node.fit_record.clear() - return outs - - def fit( - self, - train_data: Sequence, - reset_state: bool = False, - shared_args: Dict = None, - ) -> Output: - """Fit the target model according to the given training and testing data. + For more parameters, users should refer to :py:class:`~.DSRunner`. Parameters:: - train_data: sequence of data - It should be a pair of `(X, Y)` train set. - - ``X``: should be a tensor or a dict of tensors with the shape of - `(num_sample, num_time, num_feature)`, where `num_sample` is - the number of samples, `num_time` is the number of the time step, - and `num_feature` is the number of features. - - ``Y``: Target values. A tensor or a dict of tensors. - - If the shape of each tensor is `(num_sample, num_feature)`, - then we will only fit the model with the only last output. - - If the shape of each tensor is `(num_sample, num_time, num_feature)`, - then the fitting happens on the whole data series. - reset_state: bool - Whether reset the initial states of the target model. - shared_args: dict - The shared keyword arguments for the target models. + target: DynamicalSystem + The target model to train. + fit_method: OfflineAlgorithm, Callable, dict, str + The fitting method applied to the target model. + - It can be a string, which specify the shortcut name of the training algorithm. + Like, ``fit_method='ridge'`` means using the Ridge regression method. + All supported fitting methods can be obtained through + :py:func:`~get_supported_offline_methods`. 
+ - It can be a dict, whose "name" item specifies the name of the training algorithm, + and the others parameters specify the initialization parameters of the algorithm. + For example, ``fit_method={'name': 'ridge', 'alpha': 0.1}``. + - It can be an instance of :py:class:`brainpy.algorithms.OfflineAlgorithm`. + For example, ``fit_meth=bp.algorithms.RidgeRegression(alpha=0.1)``. + - It can also be a callable function, which receives three arguments "targets", "x" and "y". + For example, ``fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0]``. + + kwargs: Any + Other general parameters please see :py:class:`~.DSRunner`. """ - if shared_args is None: - shared_args = dict() - shared_args['fit'] = shared_args.get('fit', True) - shared_args = tools.DotDict(shared_args) - - # checking training and testing data - if not isinstance(train_data, (list, tuple)): - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a {type(train_data)}.") - if len(train_data) != 2: - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a sequence with length {len(train_data)}") - xs, ys = train_data - - # prediction, get all needed data - shared_args['fit'] = shared_args.get('fit', False) - outs = self.predict(inputs=xs, reset_state=reset_state, shared_args=shared_args) - - # check target data - ys = format_ys(self, ys) - - # init progress bar - if self.progress_bar: - self._pbar = tqdm.auto.tqdm(total=len(self.train_nodes)) - self._pbar.set_description(f"Train {len(self.train_nodes)} nodes: ", refresh=True) - - # training - monitor_data = dict() - for node in self.train_nodes: - key = f'{node.name}-fit_record' - monitor_data[key] = self.mon.get(key) - run_fun = self._jit_fun_train if self.jit['fit'] else self._fun_train - shared_args['fit'] = True - run_fun(monitor_data, ys) - del monitor_data - - # close the progress bar - if 
self.progress_bar: - self._pbar.close() - - # final things - for node in self.train_nodes: - # Only pop if the key exists - fit_record_key = f'{node.name}-fit_record' - if fit_record_key in self.mon: - self.mon.pop(fit_record_key) - node.fit_record.clear() # clear fit records - if self._true_numpy_mon_after_run: - for key in self.mon.keys(): - self.mon[key] = np.asarray(self.mon[key]) - - return outs - - def _fun_train(self, - monitor_data: Dict[str, ArrayType], - target_data: Dict[str, ArrayType], - shared_args: Dict = None): - if shared_args is None: - shared_args = dict() - share.save(**shared_args) - - for node in self.train_nodes: - fit_record_key = f'{node.name}-fit_record' - fit_record = monitor_data.get(fit_record_key, None) - targets = target_data[node.name] - node.offline_fit(targets, fit_record) - if self.progress_bar: - jax.debug.callback(lambda *args: self._pbar.update(), ()) - - def _step_func_monitor(self): - res = dict() - for key, val in self._monitors.items(): - if callable(val): - res[key] = _call_fun_with_share(val) - else: - (variable, idx) = val - if idx is None: - res[key] = variable.value - else: - res[key] = variable[bm.asarray(idx)] - if share.load('fit'): - for node in self.train_nodes: - res[f'{node.name}-fit_record'] = node.fit_record - return res - - def _check_interface(self): - for node in self.train_nodes: - if not isinstance(node, SupportOffline): - raise TypeError( - f''' + + def __init__( + self, + target: DynamicalSystem, + fit_method: Union[OfflineAlgorithm, Callable, Dict, str] = None, + **kwargs + ): + self._true_numpy_mon_after_run = kwargs.get('numpy_mon_after_run', True) + kwargs['numpy_mon_after_run'] = False + super().__init__(target=target, **kwargs) + + # get all trainable nodes + nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() + self.train_nodes = tuple([node for node in nodes.values() if isinstance(node.mode, bm.TrainingMode)]) + if len(self.train_nodes) == 0: + raise 
ValueError('Found no trainable nodes.') + + # check the required interface in the trainable nodes + self._check_interface() + + # training method + if fit_method is None: + fit_method = RidgeRegression(alpha=1e-7) + elif isinstance(fit_method, str): + fit_method = get(fit_method)() + elif isinstance(fit_method, dict): + name = fit_method.pop('name') + fit_method = get(name)(**fit_method) + if not callable(fit_method): + raise ValueError(f'"train_method" must be an instance of callable function, ' + f'but we got {type(fit_method)}.') + self.fit_method = fit_method + + # set the training method + for node in self.train_nodes: + node.offline_fit_by = fit_method + # training function + self._jit_fun_train = bm.jit(self._fun_train) + + def __repr__(self): + name = self.__class__.__name__ + prefix = ' ' * len(name) + return (f'{name}(target={self.target}, \n\t' + f'{prefix}fit_method={self.fit_method})') + + def predict( + self, + inputs: Any, + reset_state: bool = False, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Prediction function. + + What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that + the `inputs_are_batching` is default `True`. + + Parameters:: + + inputs: ArrayType + The input values. + reset_state: bool + Reset the target state before running. + eval_time: bool + Whether we evaluate the running time or not? + shared_args: dict + The shared arguments across nodes. + + Returns:: + + output: ArrayType + The running output. + """ + outs = super().predict(inputs=inputs, reset_state=reset_state, + eval_time=eval_time, shared_args=shared_args) + for node in self.train_nodes: + node.fit_record.clear() + return outs + + def fit( + self, + train_data: Sequence, + reset_state: bool = False, + shared_args: Dict = None, + ) -> Output: + """Fit the target model according to the given training and testing data. + + Parameters:: + + train_data: sequence of data + It should be a pair of `(X, Y)` train set. 
+ - ``X``: should be a tensor or a dict of tensors with the shape of + `(num_sample, num_time, num_feature)`, where `num_sample` is + the number of samples, `num_time` is the number of the time step, + and `num_feature` is the number of features. + - ``Y``: Target values. A tensor or a dict of tensors. + - If the shape of each tensor is `(num_sample, num_feature)`, + then we will only fit the model with the only last output. + - If the shape of each tensor is `(num_sample, num_time, num_feature)`, + then the fitting happens on the whole data series. + reset_state: bool + Whether reset the initial states of the target model. + shared_args: dict + The shared keyword arguments for the target models. + """ + if shared_args is None: + shared_args = dict() + shared_args['fit'] = shared_args.get('fit', True) + shared_args = tools.DotDict(shared_args) + + # checking training and testing data + if not isinstance(train_data, (list, tuple)): + raise ValueError(f"{self.__class__.__name__} only support " + f"training data with the format of (X, Y) pair, " + f"but we got a {type(train_data)}.") + if len(train_data) != 2: + raise ValueError(f"{self.__class__.__name__} only support " + f"training data with the format of (X, Y) pair, " + f"but we got a sequence with length {len(train_data)}") + xs, ys = train_data + + # prediction, get all needed data + shared_args['fit'] = shared_args.get('fit', False) + outs = self.predict(inputs=xs, reset_state=reset_state, shared_args=shared_args) + + # check target data + ys = format_ys(self, ys) + + # init progress bar + if self.progress_bar: + self._pbar = tqdm.auto.tqdm(total=len(self.train_nodes)) + self._pbar.set_description(f"Train {len(self.train_nodes)} nodes: ", refresh=True) + + # training + monitor_data = dict() + for node in self.train_nodes: + key = f'{node.name}-fit_record' + monitor_data[key] = self.mon.get(key) + run_fun = self._jit_fun_train if self.jit['fit'] else self._fun_train + shared_args['fit'] = True + 
run_fun(monitor_data, ys) + del monitor_data + + # close the progress bar + if self.progress_bar: + self._pbar.close() + + # final things + for node in self.train_nodes: + # Only pop if the key exists + fit_record_key = f'{node.name}-fit_record' + if fit_record_key in self.mon: + self.mon.pop(fit_record_key) + node.fit_record.clear() # clear fit records + if self._true_numpy_mon_after_run: + for key in self.mon.keys(): + self.mon[key] = np.asarray(self.mon[key]) + + return outs + + def _fun_train(self, + monitor_data: Dict[str, ArrayType], + target_data: Dict[str, ArrayType], + shared_args: Dict = None): + if shared_args is None: + shared_args = dict() + share.save(**shared_args) + + for node in self.train_nodes: + fit_record_key = f'{node.name}-fit_record' + fit_record = monitor_data.get(fit_record_key, None) + targets = target_data[node.name] + node.offline_fit(targets, fit_record) + if self.progress_bar: + jax.debug.callback(lambda *args: self._pbar.update(), ()) + + def _step_func_monitor(self): + res = dict() + for key, val in self._monitors.items(): + if callable(val): + res[key] = _call_fun_with_share(val) + else: + (variable, idx) = val + if idx is None: + res[key] = variable.value + else: + res[key] = variable[bm.asarray(idx)] + if share.load('fit'): + for node in self.train_nodes: + res[f'{node.name}-fit_record'] = node.fit_record + return res + + def _check_interface(self): + for node in self.train_nodes: + if not isinstance(node, SupportOffline): + raise TypeError( + f''' The node {node} @@ -253,21 +253,21 @@ def _check_interface(self): is set to be computing mode of {bm.training_mode} with {self.__class__.__name__}. However, {self.__class__.__name__} only support training nodes that are instances of {SupportOffline}. ''' - ) + ) class RidgeTrainer(OfflineTrainer): - """Trainer of ridge regression, also known as regression with Tikhonov regularization. + """Trainer of ridge regression, also known as regression with Tikhonov regularization. 
- For more parameters, users should refer to :py:class:`~.DSRunner`. + For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters:: - target: TrainingSystem, DynamicalSystem - The target model. - alpha: float - The regularization coefficient. - """ + target: TrainingSystem, DynamicalSystem + The target model. + alpha: float + The regularization coefficient. + """ - def __init__(self, target, alpha=1e-7, **kwargs): - super().__init__(target=target, fit_method=dict(name='ridge', alpha=alpha), **kwargs) + def __init__(self, target, alpha=1e-7, **kwargs): + super().__init__(target=target, fit_method=dict(name='ridge', alpha=alpha), **kwargs) diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index 016e96ddf..c20ffb02c 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -10,8 +10,8 @@ from brainpy import math as bm, tools from brainpy._src.context import share from brainpy._src.dynsys import DynamicalSystem -from brainpy._src.mixin import SupportOnline from brainpy._src.helpers import clear_input +from brainpy._src.mixin import SupportOnline from brainpy._src.runners import _call_fun_with_share from brainpy.algorithms.online import get, OnlineAlgorithm, RLS from brainpy.types import ArrayType, Output @@ -19,273 +19,273 @@ from .base import DSTrainer __all__ = [ - 'OnlineTrainer', - 'ForceTrainer', + 'OnlineTrainer', + 'ForceTrainer', ] class OnlineTrainer(DSTrainer): - """Online trainer for models with recurrent dynamics. - - For more parameters, users should refer to :py:class:`~.DSRunner`. - - Parameters:: - - target: DynamicalSystem - The target model to train. - - fit_method: OnlineAlgorithm, Callable, dict, str - The fitting method applied to the target model. - - - It can be a string, which specify the shortcut name of the training algorithm. - Like, ``fit_method='rls'`` means using the RLS method. 
- All supported fitting methods can be obtained through - :py:func:`~.get_supported_online_methods`. - - It can be a dict, whose "name" item specifies the name of the training algorithm, - and the others parameters specify the initialization parameters of the algorithm. - For example, ``fit_method={'name': 'rls', 'alpha': 0.1}``. - - It can be an instance of :py:class:`brainpy.algorithms.OnlineAlgorithm`. - For example, ``fit_meth=bp.algorithms.RLS(alpha=1e-5)``. - - It can also be a callable function. - - kwargs: Any - Other general parameters please see :py:class:`~.DSRunner`. - """ - - def __init__( - self, - target: DynamicalSystem, - fit_method: Union[OnlineAlgorithm, Callable, Dict, str] = None, - **kwargs - ): - super().__init__(target=target, **kwargs) - - # get all trainable nodes - nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() - self.train_nodes = tuple([node for node in nodes.values() if isinstance(node.mode, bm.TrainingMode)]) - if len(self.train_nodes) == 0: - raise ValueError('Found no trainable nodes.') - - # check the required interface in the trainable nodes - self._check_interface() - - # training method - if fit_method is None: - fit_method = RLS(alpha=1e-7) - elif isinstance(fit_method, str): - fit_method = get(fit_method)() - elif isinstance(fit_method, dict): - name = fit_method.pop('name') - fit_method = get(name)(**fit_method) - self.fit_method = fit_method - if not callable(fit_method): - raise ValueError(f'"train_method" must be an instance of callable function, ' - f'but we got {type(fit_method)}.') - - # set the training method - for node in self.train_nodes: - node.online_fit_by = fit_method - - # initialize the fitting method - for node in self.train_nodes: - node.online_init() - - # training function - self._f_fit_compiled = dict() - - def __repr__(self): - name = self.__class__.__name__ - indent = ' ' * len(name) - indent2 = indent + " " * len("target") - return 
(f'{name}(target={tools.repr_context(str(self.target), indent2)}, \n' - f'{indent}jit={self.jit}, \n' - f'{indent}fit_method={self.fit_method})') - - def predict( - self, - inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]], - reset_state: bool = False, - shared_args: Dict = None, - eval_time: bool = False - ) -> Output: - """Prediction function. - - What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that - the `inputs_are_batching` is default `True`. - - Parameters:: - - inputs: ArrayType - The input values. - reset_state: bool - Reset the target state before running. - shared_args: dict - The shared arguments across nodes. - eval_time: bool - Whether we evaluate the running time or not? + """Online trainer for models with recurrent dynamics. - Returns:: - - output: ArrayType - The running output. - """ - outs = super().predict(inputs=inputs, - reset_state=reset_state, - shared_args=shared_args, - eval_time=eval_time) - for node in self.train_nodes: - node.fit_record.clear() - return outs - - def fit( - self, - train_data: Sequence, - reset_state: bool = False, - shared_args: Dict = None, - ) -> Output: - if shared_args is None: shared_args = dict() - shared_args['fit'] = shared_args.get('fit', True) - shared_args = tools.DotDict(shared_args) - - # checking training and testing data - if not isinstance(train_data, (list, tuple)): - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a {type(train_data)}.") - if len(train_data) != 2: - raise ValueError(f"{self.__class__.__name__} only support " - f"training data with the format of (X, Y) pair, " - f"but we got a sequence with length {len(train_data)}") - xs, ys = train_data - - # reset the model states - if reset_state: - num_batch = self._get_input_batch_size(xs) - self.target.reset(num_batch) - self.reset_state() - - # format input/target data - ys = format_ys(self, ys) - num_step = 
self._get_input_time_step(xs=xs) - - indices = np.arange(self.i0, num_step + self.i0, dtype=np.int_) - if self.data_first_axis == 'B': - xs = tree_map(lambda x: bm.moveaxis(x, 0, 1), - xs, - is_leaf=lambda x: isinstance(x, bm.BaseArray)) - ys = tree_map(lambda y: bm.moveaxis(y, 0, 1), - ys, - is_leaf=lambda y: isinstance(y, bm.BaseArray)) - - # init monitor - for key in self._monitors.keys(): - self.mon[key] = [] # reshape the monitor items - - # init progress bar - if self.progress_bar: - self._pbar = tqdm.auto.tqdm(total=num_step) - self._pbar.set_description(f"Train {num_step} steps: ", refresh=True) - - # prediction - xs = (xs, ) if not isinstance(xs, (tuple, list)) else xs - outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args) - - # close the progress bar - if self.progress_bar: - self._pbar.close() - - # post-running for monitors - if self.numpy_mon_after_run: - hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.BaseArray)) - for key in hists.keys(): - self.mon[key] = hists[key] - self.i0 += num_step - return outs - - def _fit(self, - indices: ArrayType, - xs: Sequence, - ys: Dict[str, ArrayType], - shared_args: Dict = None): - """Predict the output according to the inputs. + For more parameters, users should refer to :py:class:`~.DSRunner`. Parameters:: - indices: ArrayType - The running indices. - ys: dict - Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - shared_args: optional, dict - The shared keyword arguments. - - Returns:: - - outputs, hists - A tuple of pair of (outputs, hists). + target: DynamicalSystem + The target model to train. + + fit_method: OnlineAlgorithm, Callable, dict, str + The fitting method applied to the target model. + + - It can be a string, which specify the shortcut name of the training algorithm. + Like, ``fit_method='rls'`` means using the RLS method. + All supported fitting methods can be obtained through + :py:func:`~.get_supported_online_methods`. 
+ - It can be a dict, whose "name" item specifies the name of the training algorithm, + and the others parameters specify the initialization parameters of the algorithm. + For example, ``fit_method={'name': 'rls', 'alpha': 0.1}``. + - It can be an instance of :py:class:`brainpy.algorithms.OnlineAlgorithm`. + For example, ``fit_meth=bp.algorithms.RLS(alpha=1e-5)``. + - It can also be a callable function. + + kwargs: Any + Other general parameters please see :py:class:`~.DSRunner`. """ - hists = bm.for_loop(functools.partial(self._step_func_fit, shared_args=shared_args), - (indices, xs, ys), - jit=self.jit['fit']) - hists = tree_map(lambda x: bm.moveaxis(x, 0, 1), - hists, - is_leaf=lambda x: isinstance(x, bm.BaseArray)) - return hists - - def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None): - if shared_args is None: - shared_args = dict() - share.save(t=i * self.dt, dt=self.dt, i=i, **shared_args) - - # input step - clear_input(self.target) - self._step_func_input() - - # update step - out = self.target(*xs) - - # monitor step - monitors = self._step_func_monitor() - for node in self.train_nodes: - fit_record = monitors.pop(f'{node.name}-fit_record') - target = ys[node.name] - node.online_fit(target, fit_record) - - # finally - if self.progress_bar: - jax.debug.callback(lambda *args: self._pbar.update(), ()) - return out, monitors - - def _check_interface(self): - for node in self.train_nodes: - if not isinstance(node, SupportOnline): - raise TypeError( - f'The node \n\n{node}\n\n' - f'is set to be trainable with {self.__class__.__name__} method. \n' - f'{self.__class__.__name__} only support training nodes that are instances ' - f'of {SupportOnline}. 
' - ) - - def _step_func_monitor(self): - res = dict() - for key, val in self._monitors.items(): - if callable(val): - res[key] = _call_fun_with_share(val) - else: - (variable, idx) = val - if idx is None: - res[key] = variable.value - else: - res[key] = variable[bm.asarray(idx)] - if share.load('fit'): - for node in self.train_nodes: - res[f'{node.name}-fit_record'] = node.fit_record - return res + + def __init__( + self, + target: DynamicalSystem, + fit_method: Union[OnlineAlgorithm, Callable, Dict, str] = None, + **kwargs + ): + super().__init__(target=target, **kwargs) + + # get all trainable nodes + nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique() + self.train_nodes = tuple([node for node in nodes.values() if isinstance(node.mode, bm.TrainingMode)]) + if len(self.train_nodes) == 0: + raise ValueError('Found no trainable nodes.') + + # check the required interface in the trainable nodes + self._check_interface() + + # training method + if fit_method is None: + fit_method = RLS(alpha=1e-7) + elif isinstance(fit_method, str): + fit_method = get(fit_method)() + elif isinstance(fit_method, dict): + name = fit_method.pop('name') + fit_method = get(name)(**fit_method) + self.fit_method = fit_method + if not callable(fit_method): + raise ValueError(f'"train_method" must be an instance of callable function, ' + f'but we got {type(fit_method)}.') + + # set the training method + for node in self.train_nodes: + node.online_fit_by = fit_method + + # initialize the fitting method + for node in self.train_nodes: + node.online_init() + + # training function + self._f_fit_compiled = dict() + + def __repr__(self): + name = self.__class__.__name__ + indent = ' ' * len(name) + indent2 = indent + " " * len("target") + return (f'{name}(target={tools.repr_context(str(self.target), indent2)}, \n' + f'{indent}jit={self.jit}, \n' + f'{indent}fit_method={self.fit_method})') + + def predict( + self, + inputs: Union[ArrayType, Sequence[ArrayType], 
Dict[str, ArrayType]], + reset_state: bool = False, + shared_args: Dict = None, + eval_time: bool = False + ) -> Output: + """Prediction function. + + What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that + the `inputs_are_batching` is default `True`. + + Parameters:: + + inputs: ArrayType + The input values. + reset_state: bool + Reset the target state before running. + shared_args: dict + The shared arguments across nodes. + eval_time: bool + Whether we evaluate the running time or not? + + Returns:: + + output: ArrayType + The running output. + """ + outs = super().predict(inputs=inputs, + reset_state=reset_state, + shared_args=shared_args, + eval_time=eval_time) + for node in self.train_nodes: + node.fit_record.clear() + return outs + + def fit( + self, + train_data: Sequence, + reset_state: bool = False, + shared_args: Dict = None, + ) -> Output: + if shared_args is None: shared_args = dict() + shared_args['fit'] = shared_args.get('fit', True) + shared_args = tools.DotDict(shared_args) + + # checking training and testing data + if not isinstance(train_data, (list, tuple)): + raise ValueError(f"{self.__class__.__name__} only support " + f"training data with the format of (X, Y) pair, " + f"but we got a {type(train_data)}.") + if len(train_data) != 2: + raise ValueError(f"{self.__class__.__name__} only support " + f"training data with the format of (X, Y) pair, " + f"but we got a sequence with length {len(train_data)}") + xs, ys = train_data + + # reset the model states + if reset_state: + num_batch = self._get_input_batch_size(xs) + self.target.reset(num_batch) + self.reset_state() + + # format input/target data + ys = format_ys(self, ys) + num_step = self._get_input_time_step(xs=xs) + + indices = np.arange(self.i0, num_step + self.i0, dtype=np.int_) + if self.data_first_axis == 'B': + xs = tree_map(lambda x: bm.moveaxis(x, 0, 1), + xs, + is_leaf=lambda x: isinstance(x, bm.BaseArray)) + ys = tree_map(lambda y: bm.moveaxis(y, 0, 1), 
+ ys, + is_leaf=lambda y: isinstance(y, bm.BaseArray)) + + # init monitor + for key in self._monitors.keys(): + self.mon[key] = [] # reshape the monitor items + + # init progress bar + if self.progress_bar: + self._pbar = tqdm.auto.tqdm(total=num_step) + self._pbar.set_description(f"Train {num_step} steps: ", refresh=True) + + # prediction + xs = (xs,) if not isinstance(xs, (tuple, list)) else xs + outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args) + + # close the progress bar + if self.progress_bar: + self._pbar.close() + + # post-running for monitors + if self.numpy_mon_after_run: + hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.BaseArray)) + for key in hists.keys(): + self.mon[key] = hists[key] + self.i0 += num_step + return outs + + def _fit(self, + indices: ArrayType, + xs: Sequence, + ys: Dict[str, ArrayType], + shared_args: Dict = None): + """Predict the output according to the inputs. + + Parameters:: + + indices: ArrayType + The running indices. + ys: dict + Each tensor should have the shape of `(num_time, num_batch, num_feature)`. + shared_args: optional, dict + The shared keyword arguments. + + Returns:: + + outputs, hists + A tuple of pair of (outputs, hists). 
+ """ + hists = bm.for_loop(functools.partial(self._step_func_fit, shared_args=shared_args), + (indices, xs, ys), + jit=self.jit['fit']) + hists = tree_map(lambda x: bm.moveaxis(x, 0, 1), + hists, + is_leaf=lambda x: isinstance(x, bm.BaseArray)) + return hists + + def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None): + if shared_args is None: + shared_args = dict() + share.save(t=i * self.dt, dt=self.dt, i=i, **shared_args) + + # input step + clear_input(self.target) + self._step_func_input() + + # update step + out = self.target(*xs) + + # monitor step + monitors = self._step_func_monitor() + for node in self.train_nodes: + fit_record = monitors.pop(f'{node.name}-fit_record') + target = ys[node.name] + node.online_fit(target, fit_record) + + # finally + if self.progress_bar: + jax.debug.callback(lambda *args: self._pbar.update(), ()) + return out, monitors + + def _check_interface(self): + for node in self.train_nodes: + if not isinstance(node, SupportOnline): + raise TypeError( + f'The node \n\n{node}\n\n' + f'is set to be trainable with {self.__class__.__name__} method. \n' + f'{self.__class__.__name__} only support training nodes that are instances ' + f'of {SupportOnline}. 
' + ) + + def _step_func_monitor(self): + res = dict() + for key, val in self._monitors.items(): + if callable(val): + res[key] = _call_fun_with_share(val) + else: + (variable, idx) = val + if idx is None: + res[key] = variable.value + else: + res[key] = variable[bm.asarray(idx)] + if share.load('fit'): + for node in self.train_nodes: + res[f'{node.name}-fit_record'] = node.fit_record + return res class ForceTrainer(OnlineTrainer): - """FORCE learning.""" + """FORCE learning.""" - def __init__(self, target, alpha=1., **kwargs): - super(ForceTrainer, self).__init__(target=target, - fit_method=RLS(alpha=alpha), - **kwargs) + def __init__(self, target, alpha=1., **kwargs): + super(ForceTrainer, self).__init__(target=target, + fit_method=RLS(alpha=alpha), + **kwargs) diff --git a/brainpy/_src/transform.py b/brainpy/_src/transform.py index c5070ea19..a9e04ef8a 100644 --- a/brainpy/_src/transform.py +++ b/brainpy/_src/transform.py @@ -14,276 +14,275 @@ from brainpy.types import PyTree __all__ = [ - 'LoopOverTime', + 'LoopOverTime', ] class LoopOverTime(DynamicalSystem): - """Transform a single step :py:class:`~.DynamicalSystem` - into a multiple-step forward propagation :py:class:`~.BrainPyObject`. - - .. note:: - - This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`. - - If the `target` has a batching mode, before sending the data into the wrapped object, - reset the state (``.reset_state(batch_size)``) with the same batch size as in the given data. - - - For more flexible customization, we recommend users to use :py:func:`~.for_loop`, - or :py:class:`~.DSRunner`. 
- - Examples:: - - This model can be used for network training: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> - >>> n_time, n_batch, n_in = 30, 128, 100 - >>> model = bp.Sequential(l1=bp.layers.RNNCell(n_in, 20), - >>> l2=bm.relu, - >>> l3=bp.layers.RNNCell(20, 2)) - >>> over_time = bp.LoopOverTime(model, data_first_axis='T') - >>> over_time.reset_state(n_batch) - (30, 128, 2) - >>> - >>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in)) - >>> print(hist_l3.shape) - >>> - >>> # monitor the "l1" layer state - >>> over_time = bp.LoopOverTime(model, out_vars=model['l1'].state, data_first_axis='T') - >>> over_time.reset_state(n_batch) - >>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in)) - >>> print(hist_l3.shape) - (30, 128, 2) - >>> print(hist_l1.shape) - (30, 128, 20) - - It is also able to used in brain simulation models: - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> - >>> hh = bp.neurons.HH(1) - >>> over_time = bp.LoopOverTime(hh, out_vars=hh.V) - >>> - >>> # running with a given duration - >>> _, potentials = over_time(100.) - >>> plt.plot(bm.as_numpy(potentials), label='with given duration') - >>> - >>> # running with the given inputs - >>> _, potentials = over_time(bm.ones(1000) * 5) - >>> plt.plot(bm.as_numpy(potentials), label='with given inputs') - >>> plt.legend() - >>> plt.show() - - - Parameters:: - - target: DynamicalSystem - The target to transform. - no_state: bool - Denoting whether the `target` has the shared argument or not. - - - For ANN layers which are no_state, like :py:class:`~.Dense` or :py:class:`~.Conv2d`, - set `no_state=True` is high efficiently. This is because :math:`Y[t]` only relies on - :math:`X[t]`, and it is not necessary to calculate :math:`Y[t]` step-bt-step. 
- For this case, we reshape the input from `shape = [T, N, *]` to `shape = [TN, *]`, - send data to the object, and reshape output to `shape = [T, N, *]`. - In this way, the calculation over different time is parralelized. - - out_vars: PyTree - The variables to monitor over the time loop. - t0: float, optional - The start time to run the system. If None, ``t`` will be no longer generated in the loop. - i0: int, optional - The start index to run the system. If None, ``i`` will be no longer generated in the loop. - dt: float - The time step. - shared_arg: dict - The shared arguments across the nodes. - For instance, `shared_arg={'fit': False}` for the prediction phase. - data_first_axis: str - Denoting the type of the first axis of input data. - If ``'T'``, we treat the data as `(time, ...)`. - If ``'B'``, we treat the data as `(batch, time, ...)` when the `target` is in Batching mode. - Default is ``'T'``. - name: str - The transformed object name. - """ - - def __init__( - self, - target: DynamicalSystem, - out_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, - no_state: bool = False, - t0: Optional[float] = 0., - i0: Optional[int] = 0, - dt: Optional[float] = None, - shared_arg: Optional[Dict] = None, - data_first_axis: str = 'T', - name: str = None, - jit: bool = True, - remat: bool = False, - ): - super().__init__(name=name) - assert data_first_axis in ['B', 'T'] - is_integer(i0, 'i0', allow_none=True) - is_float(t0, 't0', allow_none=True) - is_float(dt, 'dt', allow_none=True) - dt = share.dt if dt is None else dt - if shared_arg is None: - shared_arg = dict(dt=dt) - else: - assert isinstance(shared_arg, dict) - shared_arg['dt'] = dt - self.dt = dt - self._t0 = t0 - self._i0 = i0 - self.t0 = None if t0 is None else bm.Variable(bm.as_jax(t0)) - self.i0 = None if i0 is None else bm.Variable(bm.as_jax(i0)) - - self.jit = jit - self.remat = remat - self.shared_arg = shared_arg - self.data_first_axis = data_first_axis - self.target = 
target - if not isinstance(target, DynamicalSystem): - raise TypeError(f'Must be instance of {DynamicalSystem.__name__}, ' - f'but we got {type(target)}') - self.no_state = no_state - self.out_vars = out_vars - if out_vars is not None: - out_vars, _ = tree_flatten(out_vars, is_leaf=lambda s: isinstance(s, bm.Variable)) - for v in out_vars: - if not isinstance(v, bm.Variable): - raise TypeError('out_vars must be a PyTree of Variable.') - - def __call__( - self, - duration_or_xs: Union[float, PyTree], - ): - """Forward propagation along the time or inputs. + """Transform a single step :py:class:`~.DynamicalSystem` + into a multiple-step forward propagation :py:class:`~.BrainPyObject`. + + .. note:: + + This object transforms a :py:class:`~.DynamicalSystem` into a :py:class:`~.BrainPyObject`. + + If the `target` has a batching mode, before sending the data into the wrapped object, + reset the state (``.reset_state(batch_size)``) with the same batch size as in the given data. + + + For more flexible customization, we recommend users to use :py:func:`~.for_loop`, + or :py:class:`~.DSRunner`. + + Examples:: + + This model can be used for network training: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> + >>> n_time, n_batch, n_in = 30, 128, 100 + >>> model = bp.Sequential(l1=bp.layers.RNNCell(n_in, 20), + >>> l2=bm.relu, + >>> l3=bp.layers.RNNCell(20, 2)) + >>> over_time = bp.LoopOverTime(model, data_first_axis='T') + >>> over_time.reset_state(n_batch) + (30, 128, 2) + >>> + >>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in)) + >>> print(hist_l3.shape) + >>> + >>> # monitor the "l1" layer state + >>> over_time = bp.LoopOverTime(model, out_vars=model['l1'].state, data_first_axis='T') + >>> over_time.reset_state(n_batch) + >>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in)) + >>> print(hist_l3.shape) + (30, 128, 2) + >>> print(hist_l1.shape) + (30, 128, 20) + + It is also able to used in brain simulation models: + + .. 
plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> + >>> hh = bp.neurons.HH(1) + >>> over_time = bp.LoopOverTime(hh, out_vars=hh.V) + >>> + >>> # running with a given duration + >>> _, potentials = over_time(100.) + >>> plt.plot(bm.as_numpy(potentials), label='with given duration') + >>> + >>> # running with the given inputs + >>> _, potentials = over_time(bm.ones(1000) * 5) + >>> plt.plot(bm.as_numpy(potentials), label='with given inputs') + >>> plt.legend() + >>> plt.show() + Parameters:: - - duration_or_xs: float, PyTree - If `float`, it indicates a running duration. - If a PyTree, it is the given inputs. - - Returns:: - - out: PyTree - The accumulated outputs over time. + + target: DynamicalSystem + The target to transform. + no_state: bool + Denoting whether the `target` has the shared argument or not. + + - For ANN layers which are no_state, like :py:class:`~.Dense` or :py:class:`~.Conv2d`, + set `no_state=True` is high efficiently. This is because :math:`Y[t]` only relies on + :math:`X[t]`, and it is not necessary to calculate :math:`Y[t]` step-bt-step. + For this case, we reshape the input from `shape = [T, N, *]` to `shape = [TN, *]`, + send data to the object, and reshape output to `shape = [T, N, *]`. + In this way, the calculation over different time is parralelized. + + out_vars: PyTree + The variables to monitor over the time loop. + t0: float, optional + The start time to run the system. If None, ``t`` will be no longer generated in the loop. + i0: int, optional + The start index to run the system. If None, ``i`` will be no longer generated in the loop. + dt: float + The time step. + shared_arg: dict + The shared arguments across the nodes. + For instance, `shared_arg={'fit': False}` for the prediction phase. + data_first_axis: str + Denoting the type of the first axis of input data. + If ``'T'``, we treat the data as `(time, ...)`. 
+ If ``'B'``, we treat the data as `(batch, time, ...)` when the `target` is in Batching mode. + Default is ``'T'``. + name: str + The transformed object name. """ - # inputs - if isinstance(duration_or_xs, float): - shared = tools.DotDict() - if self.t0 is not None: - shared['t'] = jnp.arange(0, duration_or_xs, self.dt) + self.t0.value - if self.i0 is not None: - shared['i'] = jnp.arange(0, shared['t'].shape[0]) + self.i0.value - xs = None - if self.no_state: - raise ValueError('Under the `no_state=True` setting, input cannot be a duration.') - length = shared['t'].shape - - else: - inp_err_msg = ('\n' - 'Input should be a Array PyTree with the shape ' - 'of (B, T, ...) or (T, B, ...) with `data_first_axis="T"`, ' - 'where B the batch size and T the time length.') - xs, tree = tree_flatten(duration_or_xs, lambda a: isinstance(a, bm.BaseArray)) - if self.target.mode.is_child_of(bm.BatchingMode): - b_idx, t_idx = (1, 0) if self.data_first_axis == 'T' else (0, 1) - - try: - batch = tuple(set([x.shape[b_idx] for x in xs])) - except (AttributeError, IndexError) as e: - raise ValueError(inp_err_msg) from e - if len(batch) != 1: - raise ValueError('\n' - 'Input should be a Array PyTree with the same batch dimension. ' - f'but we got {tree_unflatten(tree, batch)}.') - try: - length = tuple(set([x.shape[t_idx] for x in xs])) - except (AttributeError, IndexError) as e: - raise ValueError(inp_err_msg) from e - if len(batch) != 1: - raise ValueError('\n' - 'Input should be a Array PyTree with the same batch size. ' - f'but we got {tree_unflatten(tree, batch)}.') - if len(length) != 1: - raise ValueError('\n' - 'Input should be a Array PyTree with the same time length. 
' - f'but we got {tree_unflatten(tree, length)}.') - - if self.no_state: - xs = [bm.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs] + + def __init__( + self, + target: DynamicalSystem, + out_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, + no_state: bool = False, + t0: Optional[float] = 0., + i0: Optional[int] = 0, + dt: Optional[float] = None, + shared_arg: Optional[Dict] = None, + data_first_axis: str = 'T', + name: str = None, + jit: bool = True, + remat: bool = False, + ): + super().__init__(name=name) + assert data_first_axis in ['B', 'T'] + is_integer(i0, 'i0', allow_none=True) + is_float(t0, 't0', allow_none=True) + is_float(dt, 'dt', allow_none=True) + dt = share.dt if dt is None else dt + if shared_arg is None: + shared_arg = dict(dt=dt) + else: + assert isinstance(shared_arg, dict) + shared_arg['dt'] = dt + self.dt = dt + self._t0 = t0 + self._i0 = i0 + self.t0 = None if t0 is None else bm.Variable(bm.as_jax(t0)) + self.i0 = None if i0 is None else bm.Variable(bm.as_jax(i0)) + + self.jit = jit + self.remat = remat + self.shared_arg = shared_arg + self.data_first_axis = data_first_axis + self.target = target + if not isinstance(target, DynamicalSystem): + raise TypeError(f'Must be instance of {DynamicalSystem.__name__}, ' + f'but we got {type(target)}') + self.no_state = no_state + self.out_vars = out_vars + if out_vars is not None: + out_vars, _ = tree_flatten(out_vars, is_leaf=lambda s: isinstance(s, bm.Variable)) + for v in out_vars: + if not isinstance(v, bm.Variable): + raise TypeError('out_vars must be a PyTree of Variable.') + + def __call__( + self, + duration_or_xs: Union[float, PyTree], + ): + """Forward propagation along the time or inputs. + + Parameters:: + + duration_or_xs: float, PyTree + If `float`, it indicates a running duration. + If a PyTree, it is the given inputs. + + Returns:: + + out: PyTree + The accumulated outputs over time. 
+ """ + # inputs + if isinstance(duration_or_xs, float): + shared = tools.DotDict() + if self.t0 is not None: + shared['t'] = jnp.arange(0, duration_or_xs, self.dt) + self.t0.value + if self.i0 is not None: + shared['i'] = jnp.arange(0, shared['t'].shape[0]) + self.i0.value + xs = None + if self.no_state: + raise ValueError('Under the `no_state=True` setting, input cannot be a duration.') + length = shared['t'].shape + else: - if self.data_first_axis == 'B': - xs = [jnp.moveaxis(x, 0, 1) for x in xs] - xs = tree_unflatten(tree, xs) - origin_shape = (length[0], batch[0]) if self.data_first_axis == 'T' else (batch[0], length[0]) - - else: - - try: - length = tuple(set([x.shape[0] for x in xs])) - except (AttributeError, IndexError) as e: - raise ValueError(inp_err_msg) from e - if len(length) != 1: - raise ValueError('\n' - 'Input should be a Array PyTree with the same time length. ' - f'but we got {tree_unflatten(tree, length)}.') - xs = tree_unflatten(tree, xs) - origin_shape = (length[0],) - - # computation - if self.no_state: - share.save(**self.shared_arg) - outputs = self._run(self.shared_arg, dict(), xs) - results = tree_map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs) + inp_err_msg = ('\n' + 'Input should be a Array PyTree with the shape ' + 'of (B, T, ...) or (T, B, ...) with `data_first_axis="T"`, ' + 'where B the batch size and T the time length.') + xs, tree = tree_flatten(duration_or_xs, lambda a: isinstance(a, bm.BaseArray)) + if self.target.mode.is_child_of(bm.BatchingMode): + b_idx, t_idx = (1, 0) if self.data_first_axis == 'T' else (0, 1) + + try: + batch = tuple(set([x.shape[b_idx] for x in xs])) + except (AttributeError, IndexError) as e: + raise ValueError(inp_err_msg) from e + if len(batch) != 1: + raise ValueError('\n' + 'Input should be a Array PyTree with the same batch dimension. 
' + f'but we got {tree_unflatten(tree, batch)}.') + try: + length = tuple(set([x.shape[t_idx] for x in xs])) + except (AttributeError, IndexError) as e: + raise ValueError(inp_err_msg) from e + if len(batch) != 1: + raise ValueError('\n' + 'Input should be a Array PyTree with the same batch size. ' + f'but we got {tree_unflatten(tree, batch)}.') + if len(length) != 1: + raise ValueError('\n' + 'Input should be a Array PyTree with the same time length. ' + f'but we got {tree_unflatten(tree, length)}.') + + if self.no_state: + xs = [bm.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs] + else: + if self.data_first_axis == 'B': + xs = [jnp.moveaxis(x, 0, 1) for x in xs] + xs = tree_unflatten(tree, xs) + origin_shape = (length[0], batch[0]) if self.data_first_axis == 'T' else (batch[0], length[0]) + + else: + + try: + length = tuple(set([x.shape[0] for x in xs])) + except (AttributeError, IndexError) as e: + raise ValueError(inp_err_msg) from e + if len(length) != 1: + raise ValueError('\n' + 'Input should be a Array PyTree with the same time length. 
' + f'but we got {tree_unflatten(tree, length)}.') + xs = tree_unflatten(tree, xs) + origin_shape = (length[0],) + + # computation + if self.no_state: + share.save(**self.shared_arg) + outputs = self._run(self.shared_arg, dict(), xs) + results = tree_map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs) + if self.i0 is not None: + self.i0 += length[0] + if self.t0 is not None: + self.t0 += length[0] * self.dt + return results + + else: + shared = tools.DotDict() + if self.t0 is not None: + shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value + if self.i0 is not None: + shared['i'] = jnp.arange(0, length[0]) + self.i0.value + + assert not self.no_state + results = bm.for_loop(functools.partial(self._run, self.shared_arg), + (shared, xs), + jit=self.jit, + remat=self.remat) if self.i0 is not None: - self.i0 += length[0] + self.i0 += length[0] if self.t0 is not None: - self.t0 += length[0] * self.dt + self.t0 += length[0] * self.dt return results - else: - shared = tools.DotDict() - if self.t0 is not None: - shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value + def reset_state(self, batch_size=None): if self.i0 is not None: - shared['i'] = jnp.arange(0, length[0]) + self.i0.value - - assert not self.no_state - results = bm.for_loop(functools.partial(self._run, self.shared_arg), - (shared, xs), - jit=self.jit, - remat=self.remat) - if self.i0 is not None: - self.i0 += length[0] - if self.t0 is not None: - self.t0 += length[0] * self.dt - return results - - def reset_state(self, batch_size=None): - if self.i0 is not None: - self.i0.value = bm.as_jax(self._i0) - if self.t0 is not None: - self.t0.value = bm.as_jax(self._t0) - - def _run(self, static_sh, dyn_sh, x): - share.save(**static_sh, **dyn_sh) - outs = self.target(x) - if self.out_vars is not None: - outs = (outs, tree_map(bm.as_jax, self.out_vars)) - clear_input(self.target) - return outs - + self.i0.value = bm.as_jax(self._i0) + if self.t0 is not None: + 
self.t0.value = bm.as_jax(self._t0) + + def _run(self, static_sh, dyn_sh, x): + share.save(**static_sh, **dyn_sh) + outs = self.target(x) + if self.out_vars is not None: + outs = (outs, tree_map(bm.as_jax, self.out_vars)) + clear_input(self.target) + return outs diff --git a/brainpy/_src/types.py b/brainpy/_src/types.py index 77e255f95..59542d1df 100644 --- a/brainpy/_src/types.py +++ b/brainpy/_src/types.py @@ -12,26 +12,24 @@ from brainpy._src.math.object_transform import Variable, TrainVar __all__ = [ - 'ArrayType', 'Parameter', 'PyTree', - 'Shape', 'Initializer', - 'Output', 'Monitor', 'Sharding', + 'ArrayType', 'Parameter', 'PyTree', + 'Shape', 'Initializer', + 'Output', 'Monitor', 'Sharding', ] - # data -Parameter = TypeVar('Parameter', numbers.Number, jax.Array, 'Array', 'Variable') # noqa -ArrayType = TypeVar('ArrayType', Array, Variable, TrainVar, jax.Array, np.ndarray) # noqa -Array = ArrayType # noqa -PyTree = TypeVar('PyTree') # noqa +Parameter = TypeVar('Parameter', numbers.Number, jax.Array, 'Array', 'Variable') # noqa +ArrayType = TypeVar('ArrayType', Array, Variable, TrainVar, jax.Array, np.ndarray) # noqa +Array = ArrayType # noqa +PyTree = TypeVar('PyTree') # noqa # shape -Shape = TypeVar('Shape', int, Tuple[int, ...]) # noqa +Shape = TypeVar('Shape', int, Tuple[int, ...]) # noqa # component -Output = TypeVar('Output') # noqa -Monitor = TypeVar('Monitor') # noqa +Output = TypeVar('Output') # noqa +Monitor = TypeVar('Monitor') # noqa Connector = Union[conn.Connector, Array, Variable, jax.Array, np.ndarray] Initializer = Union[init.Initializer, Callable, Array, Variable, jax.Array, np.ndarray] Sharding = Union[Sequence[str], jax.sharding.Sharding, jax.Device] - diff --git a/brainpy/_src/visualization.py b/brainpy/_src/visualization.py new file mode 100644 index 000000000..2b057bd74 --- /dev/null +++ b/brainpy/_src/visualization.py @@ -0,0 +1,25 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import braintools + +get_figure = braintools.visualize.get_figure +line_plot = braintools.visualize.line_plot +raster_plot = braintools.visualize.raster_plot +animate_1D = braintools.visualize.animate_1D +animate_2D = braintools.visualize.animate_2D +remove_axis = braintools.visualize.remove_axis +animator = braintools.visualize.animator + diff --git a/brainpy/_src/visualization/__init__.py b/brainpy/_src/visualization/__init__.py deleted file mode 100644 index 1d34cacd6..000000000 --- a/brainpy/_src/visualization/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Visualization toolkit. -""" - -from .base import * diff --git a/brainpy/_src/visualization/animation.py b/brainpy/_src/visualization/animation.py deleted file mode 100644 index 182664d33..000000000 --- a/brainpy/_src/visualization/animation.py +++ /dev/null @@ -1,121 +0,0 @@ -from collections import defaultdict -from typing import Dict, List - -import matplotlib.pyplot as plt -from matplotlib.animation import ArtistAnimation -from matplotlib.artist import Artist -from matplotlib.figure import Figure - -import brainpy.math as bm - -__all__ = [ - 'animator', -] - - -def animator(data, fig, ax, num_steps=False, interval=40, cmap="plasma"): - """Generate an animation by looping through the first dimension of a - sample of spiking data. 
- Time must be the first dimension of ``data``. - - Example:: - - import matplotlib.pyplot as plt - - # Index into a single sample from a minibatch - spike_data_sample = bm.random.rand(100, 28, 28) - print(spike_data_sample.shape) - >>> (100, 28, 28) - - # Plot - fig, ax = plt.subplots() - anim = splt.animator(spike_data_sample, fig, ax) - HTML(anim.to_html5_video()) - - # Save as a gif - anim.save("spike_mnist.gif") - - :param data: Data tensor for a single sample across time steps of - shape [num_steps x input_size] - :type data: torch.Tensor - - :param fig: Top level container for all plot elements - :type fig: matplotlib.figure.Figure - - :param ax: Contains additional figure elements and sets the coordinate - system. E.g.: - fig, ax = plt.subplots(facecolor='w', figsize=(12, 7)) - :type ax: matplotlib.axes._subplots.AxesSubplot - - :param num_steps: Number of time steps to plot. If not specified, - the number of entries in the first dimension - of ``data`` will automatically be used, defaults to ``False`` - :type num_steps: int, optional - - :param interval: Delay between frames in milliseconds, defaults to ``40`` - :type interval: int, optional - - :param cmap: color map, defaults to ``plasma`` - :type cmap: string, optional - - :return: animation to be displayed using ``matplotlib.pyplot.show()`` - :rtype: FuncAnimation - - """ - - data = bm.as_numpy(data) - if not num_steps: - num_steps = data.shape[0] - camera = Camera(fig) - plt.axis("off") - # iterate over time and take a snapshot with celluloid - for step in range( - num_steps - ): # im appears unused but is required by camera.snap() - im = ax.imshow(data[step], cmap=cmap) # noqa: F841 - camera.snap() - anim = camera.animate(interval=interval) - return anim - - -class Camera: - """Make animations easier.""" - - def __init__(self, figure: Figure) -> None: - """Create camera from matplotlib figure.""" - self._figure = figure - # need to keep track off artists for each axis - self._offsets: Dict[str, 
Dict[int, int]] = { - k: defaultdict(int) - for k in [ - "collections", - "patches", - "lines", - "texts", - "artists", - "images", - ] - } - self._photos: List[List[Artist]] = [] - - def snap(self) -> List[Artist]: - """Capture current state of the figure.""" - frame_artists: List[Artist] = [] - for i, axis in enumerate(self._figure.axes): - if axis.legend_ is not None: - axis.add_artist(axis.legend_) - for name in self._offsets: - new_artists = getattr(axis, name)[self._offsets[name][i]:] - frame_artists += new_artists - self._offsets[name][i] += len(new_artists) - self._photos.append(frame_artists) - return frame_artists - - def animate(self, *args, **kwargs) -> ArtistAnimation: - """Animate the snapshots taken. - Uses matplotlib.animation.ArtistAnimation - Returns:: - - ArtistAnimation - """ - return ArtistAnimation(self._figure, self._photos, *args, **kwargs) diff --git a/brainpy/_src/visualization/base.py b/brainpy/_src/visualization/base.py deleted file mode 100644 index efd33cdc8..000000000 --- a/brainpy/_src/visualization/base.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- - - -__all__ = [ - 'visualize' -] - - -class visualize(object): - @staticmethod - def get_figure(row_num, - col_num, - row_len=3, - col_len=6): - from .figures import get_figure - return get_figure(row_num, col_num, row_len, col_len) - - @staticmethod - def line_plot(ts, - val_matrix, - plot_ids=None, - ax=None, - xlim=None, - ylim=None, - xlabel='Time (ms)', - ylabel=None, - legend=None, - title=None, - show=False, - **kwargs): - from .plots import line_plot - line_plot(ts, val_matrix, plot_ids=plot_ids, ax=ax, xlim=xlim, ylim=ylim, - xlabel=xlabel, ylabel=ylabel, legend=legend, title=title, show=show, **kwargs) - - @staticmethod - def raster_plot(ts, - sp_matrix, - ax=None, - marker='.', - markersize=2, - color='k', - xlabel='Time (ms)', - ylabel='Neuron index', - xlim=None, - ylim=None, - title=None, - show=False, - **kwargs): - from .plots import raster_plot - 
raster_plot(ts, sp_matrix, ax=ax, marker=marker, markersize=markersize, color=color, - xlabel=xlabel, ylabel=ylabel, xlim=xlim, ylim=ylim, title=title, show=show, **kwargs) - - @staticmethod - def animate_1D(dynamical_vars, - static_vars=(), - dt=None, - xlim=None, - ylim=None, - xlabel=None, - ylabel=None, - frame_delay=50., - frame_step=1, - title_size=10, - figsize=None, - gif_dpi=None, - video_fps=None, - save_path=None, - show=True, - **kwargs): - from .plots import animate_1D - return animate_1D(dynamical_vars, static_vars=static_vars, dt=dt, xlim=xlim, ylim=ylim, - xlabel=xlabel, ylabel=ylabel, frame_delay=frame_delay, frame_step=frame_step, - title_size=title_size, figsize=figsize, gif_dpi=gif_dpi, video_fps=video_fps, - save_path=save_path, show=show, **kwargs) - - @staticmethod - def animate_2D(values, - net_size, - dt=None, - val_min=None, - val_max=None, - cmap=None, - frame_delay=10, - frame_step=1, - title_size=10, - figsize=None, - gif_dpi=None, - video_fps=None, - save_path=None, - show=True): - from .plots import animate_2D - return animate_2D(values, net_size, dt=dt, val_min=val_min, val_max=val_max, cmap=cmap, - frame_delay=frame_delay, frame_step=frame_step, title_size=title_size, - figsize=figsize, gif_dpi=gif_dpi, video_fps=video_fps, save_path=save_path, show=show) - - @staticmethod - def remove_axis(ax, *pos): - from .plots import remove_axis - return remove_axis(ax, *pos) - - @staticmethod - def plot_style1(fontsize=22, - axes_edgecolor='black', - figsize='5,4', - lw=1): - from .styles import plot_style1 - plot_style1(fontsize=fontsize, axes_edgecolor=axes_edgecolor, figsize=figsize, lw=lw) - - @staticmethod - def animator(data, fig, ax, num_steps=False, interval=40, cmap="plasma"): - from .animation import animator - return animator(data, fig, ax, num_steps=num_steps, interval=interval, cmap=cmap) diff --git a/brainpy/_src/visualization/figures.py b/brainpy/_src/visualization/figures.py deleted file mode 100644 index 0a616d35e..000000000 
--- a/brainpy/_src/visualization/figures.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- - - -import matplotlib.pyplot as plt -from matplotlib.gridspec import GridSpec - -__all__ = [ - 'get_figure', -] - - -def get_figure(row_num, col_num, row_len=3, col_len=6, name=None): - """Get the constrained_layout figure. - - Parameters:: - - row_num : int - The row number of the figure. - col_num : int - The column number of the figure. - row_len : int, float - The length of each row. - col_len : int, float - The length of each column. - - Returns:: - - fig_and_gs : tuple - Figure and GridSpec. - """ - if name is None: - fig = plt.figure(figsize=(col_num * col_len, row_num * row_len), constrained_layout=True) - else: - fig = plt.figure(name, figsize=(col_num * col_len, row_num * row_len), constrained_layout=True) - gs = GridSpec(row_num, col_num, figure=fig) - return fig, gs diff --git a/brainpy/_src/visualization/plots.py b/brainpy/_src/visualization/plots.py deleted file mode 100644 index e62a21f44..000000000 --- a/brainpy/_src/visualization/plots.py +++ /dev/null @@ -1,516 +0,0 @@ -# -*- coding: utf-8 -*- - -import logging - -import matplotlib.pyplot as plt -import numpy as np -from matplotlib import animation -from matplotlib.gridspec import GridSpec - -from brainpy import math, errors - -logger = logging.getLogger('brainpy.visualization') - - -__all__ = [ - 'line_plot', - 'raster_plot', - 'animate_2D', - 'animate_1D', - 'remove_axis', -] - - -def line_plot(ts, - val_matrix, - plot_ids=None, - ax=None, - xlim=None, - ylim=None, - xlabel='Time (ms)', - ylabel=None, - legend=None, - title=None, - show=False, - **kwargs): - """Show the specified value in the given object (Neurons or Synapses.) - - Parameters:: - - ts : np.ndarray - The time steps. - val_matrix : np.ndarray - The value matrix which record the history trajectory. 
- It can be easily accessed by specifying the ``monitors`` - of NeuGroup/SynConn by: - ``neu/syn = NeuGroup/SynConn(..., monitors=[k1, k2])`` - plot_ids : None, int, tuple, a_list - The index of the value to plot. - ax : None, Axes - The figure to plot. - xlim : list, tuple - The xlim. - ylim : list, tuple - The ylim. - xlabel : str - The xlabel. - ylabel : str - The ylabel. - legend : str - The prefix of legend for plot. - show : bool - Whether show the figure. - """ - # get plot_ids - if plot_ids is None: - plot_ids = [0] - elif isinstance(plot_ids, int): - plot_ids = [plot_ids] - if not isinstance(plot_ids, (list, tuple)) and \ - not (isinstance(plot_ids, np.ndarray) and np.ndim(plot_ids) == 1): - raise errors.BrainPyError(f'"plot_ids" specifies the value index to plot, it must ' - f'be a list/tuple/1D numpy.ndarray, not {type(plot_ids)}.') - - # get ax - if ax is None: - ax = plt - - val_matrix = val_matrix.reshape((val_matrix.shape[0], -1)) - # change data - val_matrix = np.asarray(val_matrix) - ts = np.asarray(ts) - - # plot - if legend: - for idx in plot_ids: - label = legend if len(plot_ids) == 1 else f'{legend}-{idx}' - ax.plot(ts, val_matrix[:, idx], label=label, **kwargs) - else: - for idx in plot_ids: - ax.plot(ts, val_matrix[:, idx], **kwargs) - - # legend - if legend: - ax.legend() - - # xlim - if xlim is not None: - plt.xlim(xlim[0], xlim[1]) - - # ylim - if ylim is not None: - plt.ylim(ylim[0], ylim[1]) - - # xlable - if xlabel: - plt.xlabel(xlabel) - - # ylabel - if ylabel: - plt.ylabel(ylabel) - - # title - if title: - plt.title(title) - - # show - if show: - plt.show() - - -def raster_plot(ts, - sp_matrix, - ax=None, - marker='.', - markersize=2, - color='k', - xlabel='Time (ms)', - ylabel='Neuron index', - xlim=None, - ylim=None, - title=None, - show=False, - **kwargs): - """Show the rater plot of the spikes. - - Parameters:: - - ts : np.ndarray - The run times. - sp_matrix : np.ndarray - The spike matrix which records the spike information. 
- It can be easily accessed by specifying the ``monitors`` - of NeuGroup by: ``neu = NeuGroup(..., monitors=['spike'])`` - ax : Axes - The figure. - markersize : int - The size of the marker. - color : str - The color of the marker. - xlim : list, tuple - The xlim. - ylim : list, tuple - The ylim. - xlabel : str - The xlabel. - ylabel : str - The ylabel. - show : bool - Show the figure. - """ - - sp_matrix = np.asarray(sp_matrix) - if ts is None: - raise errors.BrainPyError('Must provide "ts".') - ts = np.asarray(ts) - - # get index and time - elements = np.where(sp_matrix > 0.) - index = elements[1] - time = ts[elements[0]] - - # plot rater - if ax is None: - ax = plt - ax.plot(time, index, marker + color, markersize=markersize, **kwargs) - - # xlable - if xlabel: - plt.xlabel(xlabel) - - # ylabel - if ylabel: - plt.ylabel(ylabel) - - if xlim: - plt.xlim(xlim[0], xlim[1]) - - if ylim: - plt.ylim(ylim[0], ylim[1]) - - if title: - plt.title(title) - - if show: - plt.show() - - -def animate_2D(values, - net_size, - dt=None, - val_min=None, - val_max=None, - cmap=None, - frame_delay=10, - frame_step=1, - title_size=10, - figsize=None, - gif_dpi=None, - video_fps=None, - save_path=None, - show=True): - """Animate the potentials of the neuron group. - - Parameters:: - - values : np.ndarray - The membrane potentials of the neuron group. - net_size : tuple - The size of the neuron group. - dt : float - The time duration of each step. - val_min : float, int - The minimum of the potential. - val_max : float, int - The maximum of the potential. - cmap : str - The colormap. - frame_delay : int, float - The delay to show each frame. - frame_step : int - The step to show the potential. If `frame_step=3`, then each - frame shows one of the every three steps. - title_size : int - The size of the title. - figsize : None, tuple - The size of the figure. - gif_dpi : int - Controls the dots per inch for the movie frames. 
This combined with - the figure's size in inches controls the size of the movie. If - ``None``, use defaults in matplotlib. - video_fps : int - Frames per second in the movie. Defaults to ``None``, which will use - the animation's specified interval to set the frames per second. - save_path : None, str - The save path of the animation. - show : bool - Whether show the animation. - - Returns:: - - anim : animation.FuncAnimation - The created animation function. - """ - dt = math.get_dt() if dt is None else dt - num_step, num_neuron = values.shape - height, width = net_size - - values = np.asarray(values) - val_min = values.min() if val_min is None else val_min - val_max = values.max() if val_max is None else val_max - - figsize = figsize or (6, 6) - - fig = plt.figure(figsize=(figsize[0], figsize[1]), constrained_layout=True) - gs = GridSpec(1, 1, figure=fig) - fig.add_subplot(gs[0, 0]) - - def frame(t): - img = values[t] - fig.clf() - plt.pcolor(img, cmap=cmap, vmin=val_min, vmax=val_max) - plt.colorbar() - plt.axis('off') - fig.suptitle(t="Time: {:.2f} ms".format((t + 1) * dt), - fontsize=title_size, - fontweight='bold') - return [fig.gca()] - - values = values.reshape((num_step, height, width)) - anim = animation.FuncAnimation(fig=fig, - func=frame, - frames=list(range(1, num_step, frame_step)), - init_func=None, - interval=frame_delay, - repeat_delay=3000) - if save_path is None: - if show: - plt.show() - else: - logger.warning(f'Saving the animation into {save_path} ...') - if save_path[-3:] == 'gif': - anim.save(save_path, dpi=gif_dpi, writer='imagemagick') - elif save_path[-3:] == 'mp4': - anim.save(save_path, writer='ffmpeg', fps=video_fps, bitrate=3000) - else: - anim.save(save_path + '.mp4', writer='ffmpeg', fps=video_fps, bitrate=3000) - return anim - - -def animate_1D(dynamical_vars, - static_vars=(), - dt=None, - xlim=None, - ylim=None, - xlabel=None, - ylabel=None, - frame_delay=50., - frame_step=1, - title_size=10, - figsize=None, - gif_dpi=None, - 
video_fps=None, - save_path=None, - show=True, - **kwargs): - """Animation of one-dimensional data. - - Parameters:: - - dynamical_vars : dict, np.ndarray, list of np.ndarray, list of dict - The dynamical variables which will be animated. - static_vars : dict, np.ndarray, list of np.ndarray, list of dict - The static variables. - xticks : list, np.ndarray - The xticks. - dt : float - The numerical integration step. - xlim : tuple - The xlim. - ylim : tuple - The ylim. - xlabel : str - The xlabel. - ylabel : str - The ylabel. - frame_delay : int, float - The delay to show each frame. - frame_step : int - The step to show the potential. If `frame_step=3`, then each - frame shows one of the every three steps. - title_size : int - The size of the title. - figsize : None, tuple - The size of the figure. - gif_dpi : int - Controls the dots per inch for the movie frames. This combined with - the figure's size in inches controls the size of the movie. If - ``None``, use defaults in matplotlib. - video_fps : int - Frames per second in the movie. Defaults to ``None``, which will use - the animation's specified interval to set the frames per second. - save_path : None, str - The save path of the animation. - show : bool - Whether show the animation. - - Returns:: - - figure : plt.figure - The created figure instance. - """ - - # check dt - dt = math.get_dt() if dt is None else dt - - # check figure - fig = plt.figure(figsize=(figsize or (6, 6)), constrained_layout=True) - gs = GridSpec(1, 1, figure=fig) - fig.add_subplot(gs[0, 0]) - - # check dynamical variables - final_dynamic_vars = [] - lengths = [] - has_legend = False - if isinstance(dynamical_vars, (tuple, list)): - for var in dynamical_vars: - if isinstance(var, dict): - assert 'ys' in var, 'Must provide "ys" item.' 
- if 'legend' not in var: - var['legend'] = None - else: - has_legend = True - var['ys'] = np.asarray(var['ys']) - if 'xs' not in var: - var['xs'] = np.arange(var['ys'].shape[1]) - elif isinstance(var, (np.ndarray, math.ndarray)): - var = np.asarray(var) - var = {'ys': var, - 'xs': np.arange(var.shape[1]), - 'legend': None} - else: - raise ValueError(f'Unknown data type: {type(var)}') - assert np.ndim(var['ys']) == 2, "Dynamic variable must be 2D data." - lengths.append(var['ys'].shape[0]) - final_dynamic_vars.append(var) - elif isinstance(dynamical_vars, dict): - assert 'ys' in dynamical_vars, 'Must provide "ys" item.' - if 'legend' not in dynamical_vars: - dynamical_vars['legend'] = None - else: - has_legend = True - dynamical_vars['ys'] = np.asarray(dynamical_vars['ys']) - if 'xs' not in dynamical_vars: - dynamical_vars['xs'] = np.arange(dynamical_vars['ys'].shape[1]) - lengths.append(dynamical_vars['ys'].shape[0]) - final_dynamic_vars.append(dynamical_vars) - else: - assert np.ndim(dynamical_vars) == 2, "Dynamic variable must be 2D data." - dynamical_vars = np.asarray(dynamical_vars) - lengths.append(dynamical_vars.shape[0]) - final_dynamic_vars.append({'ys': dynamical_vars, - 'xs': np.arange(dynamical_vars.shape[1]), - 'legend': None}) - lengths = np.array(lengths) - assert np.all(lengths == lengths[0]), 'Dynamic variables must have equal length.' - - # check static variables - final_static_vars = [] - if isinstance(static_vars, (tuple, list)): - for var in static_vars: - if isinstance(var, dict): - assert 'data' in var, 'Must provide "ys" item.' - if 'legend' not in var: - var['legend'] = None - else: - has_legend = True - elif isinstance(var, np.ndarray): - var = {'data': var, 'legend': None} - else: - raise ValueError(f'Unknown data type: {type(var)}') - assert np.ndim(var['data']) == 1, "Static variable must be 1D data." 
- final_static_vars.append(var) - elif isinstance(static_vars, np.ndarray): - final_static_vars.append({'data': static_vars, - 'xs': np.arange(static_vars.shape[0]), - 'legend': None}) - elif isinstance(static_vars, dict): - assert 'ys' in static_vars, 'Must provide "ys" item.' - if 'legend' not in static_vars: - static_vars['legend'] = None - else: - has_legend = True - if 'xs' not in static_vars: - static_vars['xs'] = np.arange(static_vars['ys'].shape[0]) - final_static_vars.append(static_vars) - - else: - raise ValueError(f'Unknown static data type: {type(static_vars)}') - - # ylim - if ylim is None: - ylim_min = np.inf - ylim_max = -np.inf - for var in final_dynamic_vars + final_static_vars: - if var['ys'].max() > ylim_max: - ylim_max = var['ys'].max() - if var['ys'].min() < ylim_min: - ylim_min = var['ys'].min() - if ylim_min > 0: - ylim_min = ylim_min * 0.98 - else: - ylim_min = ylim_min * 1.02 - if ylim_max > 0: - ylim_max = ylim_max * 1.02 - else: - ylim_max = ylim_max * 0.98 - ylim = (ylim_min, ylim_max) - - def frame(t): - fig.clf() - for dvar in final_dynamic_vars: - plt.plot(dvar['xs'], dvar['ys'][t], label=dvar['legend'], **kwargs) - for svar in final_static_vars: - plt.plot(svar['xs'], svar['ys'], label=svar['legend'], **kwargs) - if xlim is not None: - plt.xlim(xlim[0], xlim[1]) - if has_legend: - plt.legend() - if xlabel: - plt.xlabel(xlabel) - if ylabel: - plt.ylabel(ylabel) - plt.ylim(ylim[0], ylim[1]) - fig.suptitle(t="Time: {:.2f} ms".format((t + 1) * dt), - fontsize=title_size, - fontweight='bold') - return [fig.gca()] - - anim_result = animation.FuncAnimation(fig=fig, - func=frame, - frames=range(1, lengths[0], frame_step), - init_func=None, - interval=frame_delay, - repeat_delay=3000) - - # save or show - if save_path is None: - if show: plt.show() - else: - logger.warning(f'Saving the animation into {save_path} ...') - if save_path[-3:] == 'gif': - anim_result.save(save_path, dpi=gif_dpi, writer='imagemagick') - elif save_path[-3:] == 'mp4': 
- anim_result.save(save_path, writer='ffmpeg', fps=video_fps, bitrate=3000) - else: - anim_result.save(save_path + '.mp4', writer='ffmpeg', fps=video_fps, bitrate=3000) - return fig - - -def remove_axis(ax, *pos): - for p in pos: - if p not in ['left', 'right', 'top', 'bottom']: - raise ValueError - ax.spine[p].set_visible(False) - - diff --git a/brainpy/_src/visualization/styles.py b/brainpy/_src/visualization/styles.py deleted file mode 100644 index fc442b912..000000000 --- a/brainpy/_src/visualization/styles.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- - - -from matplotlib import rcParams - -__all__ = [ - 'plot_style1', -] - - -def plot_style1(fontsize=22, axes_edgecolor='black', figsize='5,4', lw=1): - """Plot style for publication. - - Parameters:: - - fontsize : int - The font size. - axes_edgecolor : str - The exes edge color. - figsize : str, tuple - The figure size. - lw : int - Line width. - """ - rcParams['text.latex.preamble'] = [r"\usepackage{amsmath, lmodern}"] - params = { - 'text.usetex': True, - 'font.family': 'lmodern', - # 'text.latex.unicode': True, - 'text.color': 'black', - 'xtick.labelsize': fontsize - 2, - 'ytick.labelsize': fontsize - 2, - 'axes.labelsize': fontsize, - 'axes.labelweight': 'bold', - 'axes.edgecolor': axes_edgecolor, - 'axes.titlesize': fontsize, - 'axes.titleweight': 'bold', - 'pdf.fonttype': 42, - 'ps.fonttype': 42, - 'axes.grid': False, - 'axes.facecolor': 'white', - 'lines.linewidth': lw, - "figure.figsize": figsize, - } - rcParams.update(params) diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py index afa200ed0..501e182d0 100644 --- a/brainpy/algorithms/offline.py +++ b/brainpy/algorithms/offline.py @@ -2,8 +2,8 @@ import warnings -import numpy as np import jax.numpy as jnp +import numpy as np from jax.lax import while_loop import brainpy.math as bm @@ -18,185 +18,185 @@ normalize) __all__ = [ - # brainpy_object class for offline training algorithm - 'OfflineAlgorithm', - - # 
training methods - 'LinearRegression', 'linear_regression', - 'RidgeRegression', 'ridge_regression', - 'LassoRegression', - 'LogisticRegression', - 'PolynomialRegression', - 'PolynomialRidgeRegression', - 'ElasticNetRegression', - - # general supports - 'get_supported_offline_methods', - 'register_offline_method', + # brainpy_object class for offline training algorithm + 'OfflineAlgorithm', + + # training methods + 'LinearRegression', 'linear_regression', + 'RidgeRegression', 'ridge_regression', + 'LassoRegression', + 'LogisticRegression', + 'PolynomialRegression', + 'PolynomialRidgeRegression', + 'ElasticNetRegression', + + # general supports + 'get_supported_offline_methods', + 'register_offline_method', ] name2func = dict() class OfflineAlgorithm(BrainPyObject): - """Base class for offline training algorithm.""" + """Base class for offline training algorithm.""" - def __init__(self, name=None): - super(OfflineAlgorithm, self).__init__(name=name) + def __init__(self, name=None): + super(OfflineAlgorithm, self).__init__(name=name) - def __call__(self, targets, inputs, outputs=None): - """The training procedure. + def __call__(self, targets, inputs, outputs=None): + """The training procedure. - Parameters:: + Parameters:: - targets: ArrayType - The 2d target data with the shape of `(num_batch, num_output)`. - inputs: ArrayType - The 2d input data with the shape of `(num_batch, num_input)`. - outputs: ArrayType - The 2d output data with the shape of `(num_batch, num_output)`. + targets: ArrayType + The 2d target data with the shape of `(num_batch, num_output)`. + inputs: ArrayType + The 2d input data with the shape of `(num_batch, num_input)`. + outputs: ArrayType + The 2d output data with the shape of `(num_batch, num_output)`. - Returns:: + Returns:: - weight: ArrayType - The weights after fit. - """ - return self.call(targets, inputs, outputs) + weight: ArrayType + The weights after fit. 
+ """ + return self.call(targets, inputs, outputs) - def call(self, targets, inputs, outputs=None) -> ArrayType: - """The training procedure. + def call(self, targets, inputs, outputs=None) -> ArrayType: + """The training procedure. - Parameters:: + Parameters:: - inputs: ArrayType - The 3d input data with the shape of `(num_batch, num_time, num_input)`, - or, the 2d input data with the shape of `(num_time, num_input)`. + inputs: ArrayType + The 3d input data with the shape of `(num_batch, num_time, num_input)`, + or, the 2d input data with the shape of `(num_time, num_input)`. - targets: ArrayType - The 3d target data with the shape of `(num_batch, num_time, num_output)`, - or the 2d target data with the shape of `(num_time, num_output)`. + targets: ArrayType + The 3d target data with the shape of `(num_batch, num_time, num_output)`, + or the 2d target data with the shape of `(num_time, num_output)`. - outputs: ArrayType - The 3d output data with the shape of `(num_batch, num_time, num_output)`, - or the 2d output data with the shape of `(num_time, num_output)`. + outputs: ArrayType + The 3d output data with the shape of `(num_batch, num_time, num_output)`, + or the 2d output data with the shape of `(num_time, num_output)`. - Returns:: + Returns:: - weight: ArrayType - The weights after fit. - """ - raise NotImplementedError('Must implement the __call__ function by the subclass itself.') + weight: ArrayType + The weights after fit. + """ + raise NotImplementedError('Must implement the __call__ function by the subclass itself.') - def __repr__(self): - return self.__class__.__name__ + def __repr__(self): + return self.__class__.__name__ def _check_data_2d_atls(x): - if x.ndim < 2: - raise ValueError(f'Data must be a 2d tensor. But we got {x.ndim}d: {x.shape}.') - if x.ndim != 2: - return bm.flatten(x, end_dim=-2) - else: - return x + if x.ndim < 2: + raise ValueError(f'Data must be a 2d tensor. 
But we got {x.ndim}d: {x.shape}.') + if x.ndim != 2: + return bm.flatten(x, end_dim=-2) + else: + return x class RegressionAlgorithm(OfflineAlgorithm): - """ Base regression model. Models the relationship between a scalar dependent variable y and the independent - variables X. - - Parameters:: - - max_iter: int - The number of training iterations the algorithm will tune the weights for. - learning_rate: float - The step length that will be used when updating the weights. - """ - - def __init__( - self, - max_iter: int = None, - learning_rate: float = None, - regularizer: Regularization = None, - name: str = None - ): - super(RegressionAlgorithm, self).__init__(name=name) - self.max_iter = max_iter - self.learning_rate = learning_rate - self.regularizer = regularizer - - def initialize(self, *args, **kwargs): - pass - - def init_weights(self, n_features, n_out): - """ Initialize weights randomly [-1/N, 1/N] """ - limit = 1 / np.sqrt(n_features) - return bm.random.uniform(-limit, limit, (n_features, n_out)) - - def gradient_descent_solve(self, targets, inputs, outputs=None): - # checking - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - - # initialize weights - w = self.init_weights(inputs.shape[1], targets.shape[1]) - - def cond_fun(a): - i, par_old, par_new = a - return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), - i < self.max_iter).value - - def body_fun(a): - i, _, par_new = a - # Gradient of regularization loss w.r.t w - y_pred = inputs.dot(par_new) - grad_w = jnp.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new) - # Update the weights - par_new2 = par_new - self.learning_rate * grad_w - return i + 1, par_new, par_new2 - - # Tune parameters for n iterations - r = while_loop(cond_fun, body_fun, (0, w - 1e-8, w)) - return r[-1] - - def predict(self, W, X): - return jnp.dot(X, W) + """ Base regression model. 
Models the relationship between a scalar dependent variable y and the independent + variables X. + + Parameters:: + + max_iter: int + The number of training iterations the algorithm will tune the weights for. + learning_rate: float + The step length that will be used when updating the weights. + """ + + def __init__( + self, + max_iter: int = None, + learning_rate: float = None, + regularizer: Regularization = None, + name: str = None + ): + super(RegressionAlgorithm, self).__init__(name=name) + self.max_iter = max_iter + self.learning_rate = learning_rate + self.regularizer = regularizer + + def initialize(self, *args, **kwargs): + pass + + def init_weights(self, n_features, n_out): + """ Initialize weights randomly [-1/N, 1/N] """ + limit = 1 / np.sqrt(n_features) + return bm.random.uniform(-limit, limit, (n_features, n_out)) + + def gradient_descent_solve(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + + # initialize weights + w = self.init_weights(inputs.shape[1], targets.shape[1]) + + def cond_fun(a): + i, par_old, par_new = a + return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), + i < self.max_iter).value + + def body_fun(a): + i, _, par_new = a + # Gradient of regularization loss w.r.t w + y_pred = inputs.dot(par_new) + grad_w = jnp.dot(inputs.T, -(targets - y_pred)) + self.regularizer.grad(par_new) + # Update the weights + par_new2 = par_new - self.learning_rate * grad_w + return i + 1, par_new, par_new2 + + # Tune parameters for n iterations + r = while_loop(cond_fun, body_fun, (0, w - 1e-8, w)) + return r[-1] + + def predict(self, W, X): + return jnp.dot(X, W) class LinearRegression(RegressionAlgorithm): - """Training algorithm of least-square regression. - - Parameters:: - - name: str - The name of the algorithm. 
- """ - - def __init__( - self, - name: str = None, - - # parameters for using gradient descent - max_iter: int = 1000, - learning_rate: float = 0.001, - gradient_descent: bool = False, - ): - super(LinearRegression, self).__init__(name=name, - max_iter=max_iter, - learning_rate=learning_rate, - regularizer=Regularization(0.)) - self.gradient_descent = gradient_descent - - def call(self, targets, inputs, outputs=None): - # checking - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - - # solving - if self.gradient_descent: - return self.gradient_descent_solve(targets, inputs) - else: - weights = jnp.linalg.lstsq(inputs, targets) - return weights[0] + """Training algorithm of least-square regression. + + Parameters:: + + name: str + The name of the algorithm. + """ + + def __init__( + self, + name: str = None, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = False, + ): + super(LinearRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=Regularization(0.)) + self.gradient_descent = gradient_descent + + def call(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + + # solving + if self.gradient_descent: + return self.gradient_descent_solve(targets, inputs) + else: + weights = jnp.linalg.lstsq(inputs, targets) + return weights[0] linear_regression = LinearRegression() @@ -206,64 +206,64 @@ def call(self, targets, inputs, outputs=None): class RidgeRegression(RegressionAlgorithm): - """Training algorithm of ridge regression. - - Parameters:: - - alpha: float - The regularization coefficient. - - .. versionadded:: 2.2.0 - - beta: float - The regularization coefficient. - - .. deprecated:: 2.2.0 - Please use `alpha` to set regularization factor. - - name: str - The name of the algorithm. 
- """ - - def __init__( - self, - alpha: float = 1e-7, - beta: float = None, - name: str = None, - - # parameters for using gradient descent - max_iter: int = 1000, - learning_rate: float = 0.001, - gradient_descent: bool = False, - ): - if beta is not None: - warnings.warn(f"Please use 'alpha' to set regularization factor. " - f"'beta' has been deprecated since version 2.2.0.", - UserWarning) - alpha = beta - super(RidgeRegression, self).__init__(name=name, - max_iter=max_iter, - learning_rate=learning_rate, - regularizer=L2Regularization(alpha=alpha)) - self.gradient_descent = gradient_descent - - def call(self, targets, inputs, outputs=None): - # checking - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - - # solving - if self.gradient_descent: - return self.gradient_descent_solve(targets, inputs) - else: - temp = inputs.T @ inputs - if self.regularizer.alpha > 0.: - temp += self.regularizer.alpha * jnp.eye(inputs.shape[-1]) - weights = jnp.linalg.pinv(temp) @ (inputs.T @ targets) - return weights + """Training algorithm of ridge regression. + + Parameters:: + + alpha: float + The regularization coefficient. - def __repr__(self): - return f'{self.__class__.__name__}(beta={self.regularizer.alpha})' + .. versionadded:: 2.2.0 + + beta: float + The regularization coefficient. + + .. deprecated:: 2.2.0 + Please use `alpha` to set regularization factor. + + name: str + The name of the algorithm. + """ + + def __init__( + self, + alpha: float = 1e-7, + beta: float = None, + name: str = None, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = False, + ): + if beta is not None: + warnings.warn(f"Please use 'alpha' to set regularization factor. 
" + f"'beta' has been deprecated since version 2.2.0.", + UserWarning) + alpha = beta + super(RidgeRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=L2Regularization(alpha=alpha)) + self.gradient_descent = gradient_descent + + def call(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + + # solving + if self.gradient_descent: + return self.gradient_descent_solve(targets, inputs) + else: + temp = inputs.T @ inputs + if self.regularizer.alpha > 0.: + temp += self.regularizer.alpha * jnp.eye(inputs.shape[-1]) + weights = jnp.linalg.pinv(temp) @ (inputs.T @ targets) + return weights + + def __repr__(self): + return f'{self.__class__.__name__}(beta={self.regularizer.alpha})' ridge_regression = RidgeRegression() @@ -272,289 +272,289 @@ def __repr__(self): class LassoRegression(RegressionAlgorithm): - """Lasso regression method for offline training. - - Parameters:: - - alpha: float - Constant that multiplies the L1 term. Defaults to 1.0. - `alpha = 0` is equivalent to an ordinary least square. - max_iter: int - The maximum number of iterations. - degree: int - The degree of the polynomial that the independent variable X will be transformed to. - name: str - The name of the algorithm. 
- """ - - def __init__( - self, - alpha: float = 1.0, - degree: int = 2, - add_bias: bool = False, - name: str = None, - - # parameters for using gradient descent - max_iter: int = 1000, - learning_rate: float = 0.001, - gradient_descent: bool = True, - ): - super(LassoRegression, self).__init__(name=name, - max_iter=max_iter, - learning_rate=learning_rate, - regularizer=L1Regularization(alpha=alpha)) - self.gradient_descent = gradient_descent - self.add_bias = add_bias - assert gradient_descent - self.degree = degree - - def call(self, targets, inputs, outputs=None): - # checking - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - - # solving - inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)) - return super(LassoRegression, self).gradient_descent_solve(targets, inputs) - - def predict(self, W, X): - X = _check_data_2d_atls(bm.as_jax(X)) - X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) - return super(LassoRegression, self).predict(W, X) + """Lasso regression method for offline training. + + Parameters:: + + alpha: float + Constant that multiplies the L1 term. Defaults to 1.0. + `alpha = 0` is equivalent to an ordinary least square. + max_iter: int + The maximum number of iterations. + degree: int + The degree of the polynomial that the independent variable X will be transformed to. + name: str + The name of the algorithm. 
+ """ + + def __init__( + self, + alpha: float = 1.0, + degree: int = 2, + add_bias: bool = False, + name: str = None, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(LassoRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=L1Regularization(alpha=alpha)) + self.gradient_descent = gradient_descent + self.add_bias = add_bias + assert gradient_descent + self.degree = degree + + def call(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + + # solving + inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias)) + return super(LassoRegression, self).gradient_descent_solve(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.as_jax(X)) + X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) + return super(LassoRegression, self).predict(W, X) name2func['lasso'] = LassoRegression class LogisticRegression(RegressionAlgorithm): - """Logistic regression method for offline training. - - Parameters:: - - learning_rate: float - The step length that will be taken when following the negative gradient during - training. - gradient_descent: boolean - True or false depending on if gradient descent should be used when training. If - false then we use batch optimization by least squares. - max_iter: int - The number of iteration to optimize the parameters. - name: str - The name of the algorithm. 
- """ - - def __init__( - self, - learning_rate: float = .1, - gradient_descent: bool = True, - max_iter: int = 4000, - name: str = None, - ): - super(LogisticRegression, self).__init__(name=name, - max_iter=max_iter, - learning_rate=learning_rate) - self.gradient_descent = gradient_descent - self.sigmoid = Sigmoid() - - def call(self, targets, inputs, outputs=None) -> ArrayType: - # prepare data - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - if targets.shape[-1] != 1: - raise ValueError(f'Target must be a scalar, but got multiple variables: {targets.shape}. ') - targets = targets.flatten() - - # initialize parameters - param = self.init_weights(inputs.shape[1], targets.shape[1]) - - def cond_fun(a): - i, par_old, par_new = a - return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), - i < self.max_iter).value - - def body_fun(a): - i, par_old, par_new = a - # Make a new prediction - y_pred = self.sigmoid(inputs.dot(par_new)) - if self.gradient_descent: - # Move against the gradient of the loss function with - # respect to the parameters to minimize the loss - par_new2 = par_new - self.learning_rate * (y_pred - targets).dot(inputs) - else: - gradient = self.sigmoid.grad(inputs.dot(par_new)) - diag_grad = bm.zeros((gradient.size, gradient.size)) - diag = bm.arange(gradient.size) - diag_grad[diag, diag] = gradient - par_new2 = jnp.linalg.pinv(inputs.T.dot(diag_grad).dot(inputs)).dot(inputs.T).dot( - diag_grad.dot(inputs).dot(par_new) + targets - y_pred) - return i + 1, par_new, par_new2 - - # Tune parameters for n iterations - r = while_loop(cond_fun, body_fun, (0, param + 1., param)) - return r[-1] - - def predict(self, W, X): - return self.sigmoid(X @ W) + """Logistic regression method for offline training. + Parameters:: -name2func['logistic'] = LogisticRegression + learning_rate: float + The step length that will be taken when following the negative gradient during + training. 
+ gradient_descent: boolean + True or false depending on if gradient descent should be used when training. If + false then we use batch optimization by least squares. + max_iter: int + The number of iteration to optimize the parameters. + name: str + The name of the algorithm. + """ + def __init__( + self, + learning_rate: float = .1, + gradient_descent: bool = True, + max_iter: int = 4000, + name: str = None, + ): + super(LogisticRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate) + self.gradient_descent = gradient_descent + self.sigmoid = Sigmoid() + + def call(self, targets, inputs, outputs=None) -> ArrayType: + # prepare data + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + if targets.shape[-1] != 1: + raise ValueError(f'Target must be a scalar, but got multiple variables: {targets.shape}. ') + targets = targets.flatten() + + # initialize parameters + param = self.init_weights(inputs.shape[1], targets.shape[1]) + + def cond_fun(a): + i, par_old, par_new = a + return jnp.logical_and(jnp.logical_not(jnp.allclose(par_old, par_new)), + i < self.max_iter).value + + def body_fun(a): + i, par_old, par_new = a + # Make a new prediction + y_pred = self.sigmoid(inputs.dot(par_new)) + if self.gradient_descent: + # Move against the gradient of the loss function with + # respect to the parameters to minimize the loss + par_new2 = par_new - self.learning_rate * (y_pred - targets).dot(inputs) + else: + gradient = self.sigmoid.grad(inputs.dot(par_new)) + diag_grad = bm.zeros((gradient.size, gradient.size)) + diag = bm.arange(gradient.size) + diag_grad[diag, diag] = gradient + par_new2 = jnp.linalg.pinv(inputs.T.dot(diag_grad).dot(inputs)).dot(inputs.T).dot( + diag_grad.dot(inputs).dot(par_new) + targets - y_pred) + return i + 1, par_new, par_new2 + + # Tune parameters for n iterations + r = while_loop(cond_fun, body_fun, (0, param + 1., param)) + return r[-1] + + def predict(self, W, 
X): + return self.sigmoid(X @ W) -class PolynomialRegression(LinearRegression): - def __init__( - self, - degree: int = 2, - name: str = None, - add_bias: bool = False, - - # parameters for using gradient descent - max_iter: int = 1000, - learning_rate: float = 0.001, - gradient_descent: bool = True, - ): - super(PolynomialRegression, self).__init__(name=name, - max_iter=max_iter, - learning_rate=learning_rate, - gradient_descent=gradient_descent) - self.degree = degree - self.add_bias = add_bias - def call(self, targets, inputs, outputs=None): - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) - return super(PolynomialRegression, self).call(targets, inputs) +name2func['logistic'] = LogisticRegression - def predict(self, W, X): - X = _check_data_2d_atls(bm.as_jax(X)) - X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias) - return super(PolynomialRegression, self).predict(W, X) + +class PolynomialRegression(LinearRegression): + def __init__( + self, + degree: int = 2, + name: str = None, + add_bias: bool = False, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(PolynomialRegression, self).__init__(name=name, + max_iter=max_iter, + learning_rate=learning_rate, + gradient_descent=gradient_descent) + self.degree = degree + self.add_bias = add_bias + + def call(self, targets, inputs, outputs=None): + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRegression, self).call(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.as_jax(X)) + X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRegression, 
self).predict(W, X) name2func['polynomial'] = PolynomialRegression class PolynomialRidgeRegression(RidgeRegression): - def __init__( - self, - alpha: float = 1.0, - degree: int = 2, - name: str = None, - add_bias: bool = False, - - # parameters for using gradient descent - max_iter: int = 1000, - learning_rate: float = 0.001, - gradient_descent: bool = True, - ): - super(PolynomialRidgeRegression, self).__init__(alpha=alpha, - name=name, - max_iter=max_iter, - learning_rate=learning_rate, - gradient_descent=gradient_descent) - self.degree = degree - self.add_bias = add_bias - - def call(self, targets, inputs, outputs=None): - # checking - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) - return super(PolynomialRidgeRegression, self).call(targets, inputs) - - def predict(self, W, X): - X = _check_data_2d_atls(bm.as_jax(X)) - X = polynomial_features(X, degree=self.degree, add_bias=self.add_bias) - return super(PolynomialRidgeRegression, self).predict(W, X) + def __init__( + self, + alpha: float = 1.0, + degree: int = 2, + name: str = None, + add_bias: bool = False, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(PolynomialRidgeRegression, self).__init__(alpha=alpha, + name=name, + max_iter=max_iter, + learning_rate=learning_rate, + gradient_descent=gradient_descent) + self.degree = degree + self.add_bias = add_bias + + def call(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + inputs = polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRidgeRegression, self).call(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.as_jax(X)) + X = polynomial_features(X, 
degree=self.degree, add_bias=self.add_bias) + return super(PolynomialRidgeRegression, self).predict(W, X) name2func['polynomial_ridge'] = PolynomialRidgeRegression class ElasticNetRegression(RegressionAlgorithm): - """ + """ - Parameters: - ----------- - degree: int - The degree of the polynomial that the independent variable X will be transformed to. - reg_factor: float - The factor that will determine the amount of regularization and feature - shrinkage. - l1_ration: float - Weighs the contribution of l1 and l2 regularization. - n_iterations: float - The number of training iterations the algorithm will tune the weights for. - learning_rate: float - The step length that will be used when updating the weights. - """ - - def __init__( - self, - alpha: float = 1.0, - degree: int = 2, - l1_ratio: float = 0.5, - name: str = None, - add_bias: bool = False, - - # parameters for using gradient descent - max_iter: int = 1000, - learning_rate: float = 0.001, - gradient_descent: bool = True, - ): - super(ElasticNetRegression, self).__init__( - name=name, - max_iter=max_iter, - learning_rate=learning_rate, - regularizer=L1L2Regularization(alpha=alpha, l1_ratio=l1_ratio) - ) - self.degree = degree - self.add_bias = add_bias - self.gradient_descent = gradient_descent - assert gradient_descent - - def call(self, targets, inputs, outputs=None): - # checking - inputs = _check_data_2d_atls(bm.as_jax(inputs)) - targets = _check_data_2d_atls(bm.as_jax(targets)) - # solving - inputs = normalize(polynomial_features(inputs, degree=self.degree)) - return super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs) - - def predict(self, W, X): - X = _check_data_2d_atls(bm.as_jax(X)) - X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) - return super(ElasticNetRegression, self).predict(W, X) + Parameters: + ----------- + degree: int + The degree of the polynomial that the independent variable X will be transformed to. 
+ reg_factor: float + The factor that will determine the amount of regularization and feature + shrinkage. + l1_ration: float + Weighs the contribution of l1 and l2 regularization. + n_iterations: float + The number of training iterations the algorithm will tune the weights for. + learning_rate: float + The step length that will be used when updating the weights. + """ + + def __init__( + self, + alpha: float = 1.0, + degree: int = 2, + l1_ratio: float = 0.5, + name: str = None, + add_bias: bool = False, + + # parameters for using gradient descent + max_iter: int = 1000, + learning_rate: float = 0.001, + gradient_descent: bool = True, + ): + super(ElasticNetRegression, self).__init__( + name=name, + max_iter=max_iter, + learning_rate=learning_rate, + regularizer=L1L2Regularization(alpha=alpha, l1_ratio=l1_ratio) + ) + self.degree = degree + self.add_bias = add_bias + self.gradient_descent = gradient_descent + assert gradient_descent + + def call(self, targets, inputs, outputs=None): + # checking + inputs = _check_data_2d_atls(bm.as_jax(inputs)) + targets = _check_data_2d_atls(bm.as_jax(targets)) + # solving + inputs = normalize(polynomial_features(inputs, degree=self.degree)) + return super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs) + + def predict(self, W, X): + X = _check_data_2d_atls(bm.as_jax(X)) + X = normalize(polynomial_features(X, degree=self.degree, add_bias=self.add_bias)) + return super(ElasticNetRegression, self).predict(W, X) name2func['elastic_net'] = ElasticNetRegression def get_supported_offline_methods(): - """Get all supported offline training methods.""" - return tuple(name2func.keys()) + """Get all supported offline training methods.""" + return tuple(name2func.keys()) def register_offline_method(name: str, method: OfflineAlgorithm): - """Register a new offline learning method. + """Register a new offline learning method. - Parameters:: + Parameters:: - name: str - The method name. 
- method: OfflineAlgorithm - The function method. - """ - if name in name2func: - raise ValueError(f'"{name}" has been registered in offline training methods.') - if not isinstance(method, OfflineAlgorithm): - raise ValueError(f'"method" must be an instance {OfflineAlgorithm.__name__}, but we got {type(method)}') - name2func[name] = method + name: str + The method name. + method: OfflineAlgorithm + The function method. + """ + if name in name2func: + raise ValueError(f'"{name}" has been registered in offline training methods.') + if not isinstance(method, OfflineAlgorithm): + raise ValueError(f'"method" must be an instance {OfflineAlgorithm.__name__}, but we got {type(method)}') + name2func[name] = method def get(name: str) -> OfflineAlgorithm: - """Get the training function according to the training method name.""" - if name not in name2func: - raise ValueError(f'All offline methods are: {get_supported_offline_methods()}.\n' - f'But we got {name}.') - return name2func[name] + """Get the training function according to the training method name.""" + if name not in name2func: + raise ValueError(f'All offline methods are: {get_supported_offline_methods()}.\n' + f'But we got {name}.') + return name2func[name] diff --git a/brainpy/algorithms/online.py b/brainpy/algorithms/online.py index f5fa51aeb..5e23c8097 100644 --- a/brainpy/algorithms/online.py +++ b/brainpy/algorithms/online.py @@ -7,204 +7,204 @@ from brainpy._src.math.object_transform.base import BrainPyObject __all__ = [ - # brainpy_object class - 'OnlineAlgorithm', + # brainpy_object class + 'OnlineAlgorithm', - # online learning algorithms - 'RLS', - 'LMS', + # online learning algorithms + 'RLS', + 'LMS', - # generic methods - 'get_supported_online_methods', - 'register_online_method', + # generic methods + 'get_supported_online_methods', + 'register_online_method', ] name2func = dict() class OnlineAlgorithm(BrainPyObject): - """Base class for online training algorithm.""" + """Base class for online training 
algorithm.""" - def __init__(self, name=None): - super(OnlineAlgorithm, self).__init__(name=name) + def __init__(self, name=None): + super(OnlineAlgorithm, self).__init__(name=name) - def __call__(self, *args, **kwargs): - """The training procedure. + def __call__(self, *args, **kwargs): + """The training procedure. - Parameters:: + Parameters:: - identifier: str - The variable name. - target: ArrayType - The 2d target data with the shape of `(num_batch, num_output)`. - input: ArrayType - The 2d input data with the shape of `(num_batch, num_input)`. - output: ArrayType - The 2d output data with the shape of `(num_batch, num_output)`. + identifier: str + The variable name. + target: ArrayType + The 2d target data with the shape of `(num_batch, num_output)`. + input: ArrayType + The 2d input data with the shape of `(num_batch, num_input)`. + output: ArrayType + The 2d output data with the shape of `(num_batch, num_output)`. - Returns:: + Returns:: - weight: ArrayType - The weights after fit. - """ - return self.call(*args, **kwargs) + weight: ArrayType + The weights after fit. + """ + return self.call(*args, **kwargs) - def register_target(self, *args, **kwargs): - pass + def register_target(self, *args, **kwargs): + pass - def call(self, target, input, output, identifier: str=''): - """The training procedure. + def call(self, target, input, output, identifier: str = ''): + """The training procedure. - Parameters:: + Parameters:: - identifier: str - The variable name. - target: ArrayType - The 2d target data with the shape of `(num_batch, num_output)`. - input: ArrayType - The 2d input data with the shape of `(num_batch, num_input)`. - output: ArrayType - The 2d output data with the shape of `(num_batch, num_output)`. + identifier: str + The variable name. + target: ArrayType + The 2d target data with the shape of `(num_batch, num_output)`. + input: ArrayType + The 2d input data with the shape of `(num_batch, num_input)`. 
+ output: ArrayType + The 2d output data with the shape of `(num_batch, num_output)`. - Returns:: + Returns:: - weight: ArrayType - The weights after fit. - """ - raise NotImplementedError('Must implement the call() function by the subclass itself.') + weight: ArrayType + The weights after fit. + """ + raise NotImplementedError('Must implement the call() function by the subclass itself.') - def __repr__(self): - return self.__class__.__name__ + def __repr__(self): + return self.__class__.__name__ class RLS(OnlineAlgorithm): - """The recursive least squares (RLS) algorithm. - - RLS is an adaptive filter algorithm that recursively finds the - coefficients that minimize a weighted linear least squares cost - function relating to the input signals. This approach is in - contrast to other algorithms such as the least mean squares - (LMS) that aim to reduce the mean square error. - - See Also:: - - LMS, ForceLearning - - Parameters:: - - alpha: float - The learning rate. - name: str - The algorithm name. - - """ - - postfix = '.rls.P' - - def __init__(self, alpha=0.1, name=None): - super(RLS, self).__init__(name=name) - self.alpha = alpha - - def register_target( - self, - feature_in: int, - identifier: str = '', - ): - identifier = identifier + self.postfix - self.implicit_vars[identifier] = bm.Variable(jnp.eye(feature_in) * self.alpha) - - def call( - self, - target: jax.Array, - input: jax.Array, - output: jax.Array, - identifier: str = '', - ): - identifier = identifier + self.postfix - P = self.implicit_vars[identifier] - input = bm.as_jax(input) - output = bm.as_jax(output) - target = bm.as_jax(target) - if input.ndim == 1: input = jnp.expand_dims(input, 0) - if target.ndim == 1: target = jnp.expand_dims(target, 0) - if output.ndim == 1: output = jnp.expand_dims(output, 0) - assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}' - assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). 
Got {target.shape}' - assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {output.shape}' - k = jnp.dot(P.value, input.T) # (num_input, num_batch) - hPh = jnp.dot(input, k) # (num_batch, num_batch) - c = jnp.sum(1.0 / (1.0 + hPh)) # () - P -= c * jnp.dot(k, k.T) # (num_input, num_input) - e = output - target # (num_batch, num_output) - dw = -c * jnp.dot(k, e) # (num_input, num_output) - return dw + """The recursive least squares (RLS) algorithm. + + RLS is an adaptive filter algorithm that recursively finds the + coefficients that minimize a weighted linear least squares cost + function relating to the input signals. This approach is in + contrast to other algorithms such as the least mean squares + (LMS) that aim to reduce the mean square error. + + See Also:: + + LMS, ForceLearning + + Parameters:: + + alpha: float + The learning rate. + name: str + The algorithm name. + + """ + + postfix = '.rls.P' + + def __init__(self, alpha=0.1, name=None): + super(RLS, self).__init__(name=name) + self.alpha = alpha + + def register_target( + self, + feature_in: int, + identifier: str = '', + ): + identifier = identifier + self.postfix + self.implicit_vars[identifier] = bm.Variable(jnp.eye(feature_in) * self.alpha) + + def call( + self, + target: jax.Array, + input: jax.Array, + output: jax.Array, + identifier: str = '', + ): + identifier = identifier + self.postfix + P = self.implicit_vars[identifier] + input = bm.as_jax(input) + output = bm.as_jax(output) + target = bm.as_jax(target) + if input.ndim == 1: input = jnp.expand_dims(input, 0) + if target.ndim == 1: target = jnp.expand_dims(target, 0) + if output.ndim == 1: output = jnp.expand_dims(output, 0) + assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}' + assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}' + assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). 
Got {output.shape}' + k = jnp.dot(P.value, input.T) # (num_input, num_batch) + hPh = jnp.dot(input, k) # (num_batch, num_batch) + c = jnp.sum(1.0 / (1.0 + hPh)) # () + P -= c * jnp.dot(k, k.T) # (num_input, num_input) + e = output - target # (num_batch, num_output) + dw = -c * jnp.dot(k, e) # (num_input, num_output) + return dw name2func['rls'] = RLS class LMS(OnlineAlgorithm): - """The least mean squares (LMS). - - LMS algorithms are a class of adaptive filter used to mimic a desired filter - by finding the filter coefficients that relate to producing the least mean - square of the error signal (difference between the desired and the actual signal). - It is a stochastic gradient descent method in that the filter is only adapted - based on the error at the current time. It was invented in 1960 by - Stanford University professor Bernard Widrow and his first Ph.D. student, Ted Hoff. - - Parameters:: - - alpha: float - The learning rate. - name: str - The target name. - """ - - def __init__(self, alpha=0.1, name=None): - super(LMS, self).__init__(name=name) - self.alpha = alpha - - def call(self, target, input, output, identifier: str=''): - if input.ndim == 1: input = jnp.expand_dims(input, 0) - if target.ndim == 1: target = jnp.expand_dims(target, 0) - if output.ndim == 1: output = jnp.expand_dims(output, 0) - assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}' - assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}' - assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {output.shape}' - error = bm.as_jax(output - target) - input = bm.as_jax(input) - return -self.alpha * jnp.sum(vmap(jnp.outer)(input, error), axis=0) + """The least mean squares (LMS). 
+ + LMS algorithms are a class of adaptive filter used to mimic a desired filter + by finding the filter coefficients that relate to producing the least mean + square of the error signal (difference between the desired and the actual signal). + It is a stochastic gradient descent method in that the filter is only adapted + based on the error at the current time. It was invented in 1960 by + Stanford University professor Bernard Widrow and his first Ph.D. student, Ted Hoff. + + Parameters:: + + alpha: float + The learning rate. + name: str + The target name. + """ + + def __init__(self, alpha=0.1, name=None): + super(LMS, self).__init__(name=name) + self.alpha = alpha + + def call(self, target, input, output, identifier: str = ''): + if input.ndim == 1: input = jnp.expand_dims(input, 0) + if target.ndim == 1: target = jnp.expand_dims(target, 0) + if output.ndim == 1: output = jnp.expand_dims(output, 0) + assert input.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {input.shape}' + assert target.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {target.shape}' + assert output.ndim == 2, f'should be a 2D array with shape of (batch, feature). Got {output.shape}' + error = bm.as_jax(output - target) + input = bm.as_jax(input) + return -self.alpha * jnp.sum(vmap(jnp.outer)(input, error), axis=0) name2func['lms'] = LMS def get_supported_online_methods(): - """Get all supported online training methods.""" - return tuple(name2func.keys()) + """Get all supported online training methods.""" + return tuple(name2func.keys()) def register_online_method(name: str, method: OnlineAlgorithm): - """Register a new oneline learning method. + """Register a new oneline learning method. - Parameters:: + Parameters:: - name: str - The method name. - method: callable - The function method. - """ - if name in name2func: - raise ValueError(f'"{name}" has been registered in online training methods. 
Please change another name.') - if not isinstance(method, OnlineAlgorithm): - raise ValueError(f'"method" must be an instance of {OnlineAlgorithm.__name__}, but we got {type(method)}') - name2func[name] = method + name: str + The method name. + method: callable + The function method. + """ + if name in name2func: + raise ValueError(f'"{name}" has been registered in online training methods. Please change another name.') + if not isinstance(method, OnlineAlgorithm): + raise ValueError(f'"method" must be an instance of {OnlineAlgorithm.__name__}, but we got {type(method)}') + name2func[name] = method def get(name: str): - """Get the training function according to the training method name.""" - if name not in name2func: - raise ValueError(f'All online methods are: {get_supported_online_methods()}.\n' - f'But we got {name}.') - return name2func[name] + """Get the training function according to the training method name.""" + if name not in name2func: + raise ValueError(f'All online methods are: {get_supported_online_methods()}.\n' + f'But we got {name}.') + return name2func[name] diff --git a/brainpy/algorithms/utils.py b/brainpy/algorithms/utils.py index 2828db854..0265a3201 100644 --- a/brainpy/algorithms/utils.py +++ b/brainpy/algorithms/utils.py @@ -1,111 +1,111 @@ # -*- coding: utf-8 -*- -import brainpy.math as bm - from itertools import combinations_with_replacement +import brainpy.math as bm + __all__ = [ - 'Sigmoid', - 'Regularization', - 'L1Regularization', - 'L2Regularization', - 'L1L2Regularization', - - 'polynomial_features', - 'normalize', + 'Sigmoid', + 'Regularization', + 'L1Regularization', + 'L2Regularization', + 'L1L2Regularization', + + 'polynomial_features', + 'normalize', ] class Sigmoid(object): - def __call__(self, x): - return 1 / (1 + bm.exp(-x)) + def __call__(self, x): + return 1 / (1 + bm.exp(-x)) - def grad(self, x): - exp = bm.exp(-x) - return exp / (1 + exp) ** 2 + def grad(self, x): + exp = bm.exp(-x) + return exp / (1 + exp) ** 2 class 
Regularization(object): - def __init__(self, alpha): - self.alpha = alpha + def __init__(self, alpha): + self.alpha = alpha - def __call__(self, x): - return 0 + def __call__(self, x): + return 0 - def grad(self, x): - return 0 + def grad(self, x): + return 0 class L1Regularization(Regularization): - """L1 Regularization.""" + """L1 Regularization.""" - def __init__(self, alpha): - super(L1Regularization, self).__init__(alpha=alpha) + def __init__(self, alpha): + super(L1Regularization, self).__init__(alpha=alpha) - def __call__(self, w): - return self.alpha * bm.linalg.norm(w) + def __call__(self, w): + return self.alpha * bm.linalg.norm(w) - def grad(self, w): - return self.alpha * bm.sign(w) + def grad(self, w): + return self.alpha * bm.sign(w) class L2Regularization(Regularization): - """L2 Regularization.""" + """L2 Regularization.""" - def __init__(self, alpha): - super(L2Regularization, self).__init__(alpha=alpha) + def __init__(self, alpha): + super(L2Regularization, self).__init__(alpha=alpha) - def __call__(self, w): - return self.alpha * 0.5 * w.T.dot(w) + def __call__(self, w): + return self.alpha * 0.5 * w.T.dot(w) - def grad(self, w): - return self.alpha * w + def grad(self, w): + return self.alpha * w class L1L2Regularization(Regularization): - """L1 and L2 Regularization.""" + """L1 and L2 Regularization.""" - def __init__(self, alpha, l1_ratio=0.5): - super(L1L2Regularization, self).__init__(alpha=alpha) - self.l1_ratio = l1_ratio + def __init__(self, alpha, l1_ratio=0.5): + super(L1L2Regularization, self).__init__(alpha=alpha) + self.l1_ratio = l1_ratio - def __call__(self, w): - l1_contr = self.l1_ratio * bm.linalg.norm(w) - l2_contr = (1 - self.l1_ratio) * 0.5 * w.T.dot(w) - return self.alpha * (l1_contr + l2_contr) + def __call__(self, w): + l1_contr = self.l1_ratio * bm.linalg.norm(w) + l2_contr = (1 - self.l1_ratio) * 0.5 * w.T.dot(w) + return self.alpha * (l1_contr + l2_contr) - def grad(self, w): - l1_contr = self.l1_ratio * bm.sign(w) - 
l2_contr = (1 - self.l1_ratio) * w - return self.alpha * (l1_contr + l2_contr) + def grad(self, w): + l1_contr = self.l1_ratio * bm.sign(w) + l2_contr = (1 - self.l1_ratio) * w + return self.alpha * (l1_contr + l2_contr) def index_combinations(n_features, degree): - combs = [combinations_with_replacement(range(n_features), i) for i in range(2, degree + 1)] - flat_combs = [item for sublist in combs for item in sublist] - return flat_combs + combs = [combinations_with_replacement(range(n_features), i) for i in range(2, degree + 1)] + flat_combs = [item for sublist in combs for item in sublist] + return flat_combs def polynomial_features(X, degree: int, add_bias: bool = True): - n_samples, n_features = X.shape - combinations = index_combinations(n_features, degree) - if len(combinations) == 0: - return bm.insert(X, 0, 1, axis=1) if add_bias else X - if add_bias: - n_features += 1 - X_new = bm.zeros((n_samples, 1 + n_features + len(combinations))) - if add_bias: - X_new[:, 0] = 1 - X_new[:, 1:n_features] = X - else: - X_new[:, :n_features] = X - for i, index_combs in enumerate(combinations): - X_new[:, n_features + i] = bm.prod(X[:, index_combs], axis=1) - return X_new + n_samples, n_features = X.shape + combinations = index_combinations(n_features, degree) + if len(combinations) == 0: + return bm.insert(X, 0, 1, axis=1) if add_bias else X + if add_bias: + n_features += 1 + X_new = bm.zeros((n_samples, 1 + n_features + len(combinations))) + if add_bias: + X_new[:, 0] = 1 + X_new[:, 1:n_features] = X + else: + X_new[:, :n_features] = X + for i, index_combs in enumerate(combinations): + X_new[:, n_features + i] = bm.prod(X[:, index_combs], axis=1) + return X_new def normalize(X, axis=-1, order=2): - """ Normalize the dataset X """ - l2 = bm.atleast_1d(bm.linalg.norm(X, order, axis)) - l2 = bm.where(l2 == 0, 1, l2) - return X / bm.expand_dims(l2, axis) + """ Normalize the dataset X """ + l2 = bm.atleast_1d(bm.linalg.norm(X, order, axis)) + l2 = bm.where(l2 == 0, 1, l2) + 
return X / bm.expand_dims(l2, axis) diff --git a/brainpy/check.py b/brainpy/check.py index 07a3efbd3..5166760b0 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -17,32 +17,32 @@ BrainPyObject = None __all__ = [ - 'is_checking', - 'turn_on', - 'turn_off', - - 'is_shape_consistency', - 'is_shape_broadcastable', - 'check_shape_except_batch', - 'check_shape', - 'is_dict_data', - 'is_callable', - 'is_initializer', - 'is_connector', - 'is_float', - 'is_integer', - 'is_string', - 'is_sequence', - 'is_subclass', - 'is_instance', - 'is_elem_or_seq_or_dict', - 'is_all_vars', - 'is_all_objs', - 'jit_error', - 'jit_error_checking', - 'jit_error_checking_no_args', - - 'serialize_kwargs', + 'is_checking', + 'turn_on', + 'turn_off', + + 'is_shape_consistency', + 'is_shape_broadcastable', + 'check_shape_except_batch', + 'check_shape', + 'is_dict_data', + 'is_callable', + 'is_initializer', + 'is_connector', + 'is_float', + 'is_integer', + 'is_string', + 'is_sequence', + 'is_subclass', + 'is_instance', + 'is_elem_or_seq_or_dict', + 'is_all_vars', + 'is_all_objs', + 'jit_error', + 'jit_error_checking', + 'jit_error_checking_no_args', + + 'serialize_kwargs', ] _check = True @@ -50,145 +50,145 @@ def is_checking(): - """Whether the checking is turn on.""" - return _check + """Whether the checking is turn on.""" + return _check def turn_on(): - """Turn on the checking.""" - global _check - _check = True + """Turn on the checking.""" + global _check + _check = True def turn_off(): - """Turn off the checking.""" - global _check - _check = False + """Turn off the checking.""" + global _check + _check = False # def turn_off_name_check def is_shape_consistency(shapes, free_axes=None, return_format_shapes=False): - assert isinstance(shapes, (tuple, list)), f'Must be a sequence of shape. While we got {shapes}.' - for shape in shapes: - assert isinstance(shapes, (tuple, list)), (f'Must be a sequence of shape. 
While ' - f'we got one element is {shape}.') - dims = onp.unique([len(shape) for shape in shapes]) - if len(dims) > 1: - raise ValueError(f'The provided shape dimensions are not consistent. ') - if free_axes is None: - type_ = 'none' - free_axes = () - elif isinstance(free_axes, (tuple, list)): - type_ = 'seq' - free_axes = tuple(free_axes) - elif isinstance(free_axes, int): - type_ = 'int' - free_axes = (free_axes,) - else: - raise ValueError - free_axes = [(dims[0] + axis if axis < 0 else axis) for axis in free_axes] - all_shapes = [] - for shape in shapes: - assert isinstance(shapes, (tuple, list)), (f'Must be a sequence of shape. While ' - f'we got one element is {shape}.') - shape = tuple([sh for i, sh in enumerate(shape) if i not in free_axes]) - all_shapes.append(shape) - unique_shape = tuple(set(all_shapes)) - if len(unique_shape) > 1: - if len(free_axes): - raise ValueError(f'The provided shape (without axes of {free_axes}) are not consistent.') - else: - raise ValueError(f'The provided shape are not consistent.') - if return_format_shapes: - if type_ == 'int': - free_shapes = tuple([shape[free_axes[0]] for shape in shapes]) - elif type_ == 'seq': - free_shapes = tuple([tuple([shape[axis] for axis in free_axes]) for shape in shapes]) + assert isinstance(shapes, (tuple, list)), f'Must be a sequence of shape. While we got {shapes}.' + for shape in shapes: + assert isinstance(shapes, (tuple, list)), (f'Must be a sequence of shape. While ' + f'we got one element is {shape}.') + dims = onp.unique([len(shape) for shape in shapes]) + if len(dims) > 1: + raise ValueError(f'The provided shape dimensions are not consistent. 
') + if free_axes is None: + type_ = 'none' + free_axes = () + elif isinstance(free_axes, (tuple, list)): + type_ = 'seq' + free_axes = tuple(free_axes) + elif isinstance(free_axes, int): + type_ = 'int' + free_axes = (free_axes,) else: - free_shapes = None - return unique_shape[0], free_shapes + raise ValueError + free_axes = [(dims[0] + axis if axis < 0 else axis) for axis in free_axes] + all_shapes = [] + for shape in shapes: + assert isinstance(shapes, (tuple, list)), (f'Must be a sequence of shape. While ' + f'we got one element is {shape}.') + shape = tuple([sh for i, sh in enumerate(shape) if i not in free_axes]) + all_shapes.append(shape) + unique_shape = tuple(set(all_shapes)) + if len(unique_shape) > 1: + if len(free_axes): + raise ValueError(f'The provided shape (without axes of {free_axes}) are not consistent.') + else: + raise ValueError(f'The provided shape are not consistent.') + if return_format_shapes: + if type_ == 'int': + free_shapes = tuple([shape[free_axes[0]] for shape in shapes]) + elif type_ == 'seq': + free_shapes = tuple([tuple([shape[axis] for axis in free_axes]) for shape in shapes]) + else: + free_shapes = None + return unique_shape[0], free_shapes def is_shape_broadcastable(shapes, free_axes=(), return_format_shapes=False): - """Check whether the given shapes are broadcastable. + """Check whether the given shapes are broadcastable. - See https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html - for more details. + See https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html + for more details. 
- Parameters:: + Parameters:: - shapes - free_axes - return_format_shapes + shapes + free_axes + return_format_shapes - Returns:: + Returns:: - """ - max_dim = max([len(shape) for shape in shapes]) - shapes = [[1] * (max_dim - len(s)) + list(s) for s in shapes] - return is_shape_consistency(shapes, free_axes, return_format_shapes) + """ + max_dim = max([len(shape) for shape in shapes]) + shapes = [[1] * (max_dim - len(s)) + list(s) for s in shapes] + return is_shape_consistency(shapes, free_axes, return_format_shapes) def check_shape_except_batch(shape1, shape2, batch_idx=0, mode='raise'): - """Check whether two shapes are compatible except the batch size axis.""" - assert mode in ['raise', 'bool'] - if len(shape2) != len(shape1): - if mode == 'raise': - raise ValueError(f'Dimension mismatch between two shapes. ' - f'{shape1} != {shape2}') - else: - return False - new_shape1 = list(shape1) - new_shape2 = list(shape2) - new_shape1.pop(batch_idx) - new_shape2.pop(batch_idx) - if new_shape1 != new_shape2: - if mode == 'raise': - raise ValueError(f'Two shapes {new_shape1} and {new_shape2} are not ' - f'consistent when excluding the batch axis ' - f'{batch_idx}') - else: - return False - return True + """Check whether two shapes are compatible except the batch size axis.""" + assert mode in ['raise', 'bool'] + if len(shape2) != len(shape1): + if mode == 'raise': + raise ValueError(f'Dimension mismatch between two shapes. 
' + f'{shape1} != {shape2}') + else: + return False + new_shape1 = list(shape1) + new_shape2 = list(shape2) + new_shape1.pop(batch_idx) + new_shape2.pop(batch_idx) + if new_shape1 != new_shape2: + if mode == 'raise': + raise ValueError(f'Two shapes {new_shape1} and {new_shape2} are not ' + f'consistent when excluding the batch axis ' + f'{batch_idx}') + else: + return False + return True def check_shape(all_shapes, free_axes: Union[Sequence[int], int] = -1): - # check "all_shapes" - if isinstance(all_shapes, dict): - all_shapes = tuple(all_shapes.values()) - elif isinstance(all_shapes, (tuple, list)): - all_shapes = tuple(all_shapes) - else: - raise ValueError - # maximum number of dimension - max_dim = max([len(shape) for shape in all_shapes]) - all_shapes = [[1] * (max_dim - len(s)) + list(s) for s in all_shapes] - # check "free_axes" - type_ = 'seq' - if isinstance(free_axes, int): - free_axes = (free_axes,) - type_ = 'int' - elif isinstance(free_axes, (tuple, list)): - free_axes = tuple(free_axes) - assert isinstance(free_axes, tuple) - free_axes = [(axis + max_dim if axis < 0 else axis) for axis in free_axes] - fixed_axes = [i for i in range(max_dim) if i not in free_axes] - # get all free shapes - if type_ == 'int': - free_shape = [shape[free_axes[0]] for shape in all_shapes] - else: - free_shape = [[shape[axis] for axis in free_axes] for shape in all_shapes] - # get all assumed fixed shapes - fixed_shapes = [[shape[axis] for shape in all_shapes] for axis in fixed_axes] - max_fixed_shapes = [max(shape) for shape in fixed_shapes] - # check whether they can broadcast compatible - for i, shape in enumerate(fixed_shapes): - if len(set(shape) - {1, max_fixed_shapes[i]}): - raise ValueError(f'Shapes out of axes {free_axes} are not ' - f'broadcast compatible: \n' - f'{all_shapes}') - return free_shape, max_fixed_shapes + # check "all_shapes" + if isinstance(all_shapes, dict): + all_shapes = tuple(all_shapes.values()) + elif isinstance(all_shapes, (tuple, list)): + 
all_shapes = tuple(all_shapes) + else: + raise ValueError + # maximum number of dimension + max_dim = max([len(shape) for shape in all_shapes]) + all_shapes = [[1] * (max_dim - len(s)) + list(s) for s in all_shapes] + # check "free_axes" + type_ = 'seq' + if isinstance(free_axes, int): + free_axes = (free_axes,) + type_ = 'int' + elif isinstance(free_axes, (tuple, list)): + free_axes = tuple(free_axes) + assert isinstance(free_axes, tuple) + free_axes = [(axis + max_dim if axis < 0 else axis) for axis in free_axes] + fixed_axes = [i for i in range(max_dim) if i not in free_axes] + # get all free shapes + if type_ == 'int': + free_shape = [shape[free_axes[0]] for shape in all_shapes] + else: + free_shape = [[shape[axis] for axis in free_axes] for shape in all_shapes] + # get all assumed fixed shapes + fixed_shapes = [[shape[axis] for shape in all_shapes] for axis in fixed_axes] + max_fixed_shapes = [max(shape) for shape in fixed_shapes] + # check whether they can broadcast compatible + for i, shape in enumerate(fixed_shapes): + if len(set(shape) - {1, max_fixed_shapes[i]}): + raise ValueError(f'Shapes out of axes {free_axes} are not ' + f'broadcast compatible: \n' + f'{all_shapes}') + return free_shape, max_fixed_shapes def is_dict_data(a_dict: Dict, @@ -196,35 +196,35 @@ def is_dict_data(a_dict: Dict, val_type: Union[Type, Tuple[Type, ...]] = None, name: str = None, allow_none: bool = True): - """Check the dictionary data. 
- """ - if allow_none and a_dict is None: - return None - name = '' if (name is None) else f'"{name}"' - if not isinstance(a_dict, dict): - raise ValueError(f'{name} must be a dict, while we got {type(a_dict)}') - for key, value in a_dict.items(): - if (key_type is not None) and (not isinstance(key, key_type)): - raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' - f'while we got ({type(key)}, {type(value)})') - if (val_type is not None) and (not isinstance(value, val_type)): - raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' - f'while we got ({type(key)}, {type(value)})') - return a_dict + """Check the dictionary data. + """ + if allow_none and a_dict is None: + return None + name = '' if (name is None) else f'"{name}"' + if not isinstance(a_dict, dict): + raise ValueError(f'{name} must be a dict, while we got {type(a_dict)}') + for key, value in a_dict.items(): + if (key_type is not None) and (not isinstance(key, key_type)): + raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' + f'while we got ({type(key)}, {type(value)})') + if (val_type is not None) and (not isinstance(value, val_type)): + raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), ' + f'while we got ({type(key)}, {type(value)})') + return a_dict def is_callable(fun: Callable, name: str = None, allow_none: bool = False): - name = '' if name is None else name - if fun is None: - if allow_none: - return None - else: - raise ValueError(f'{name} must be a callable function, but we got None.') - if not callable(fun): - raise ValueError(f'{name} should be a callable function. While we got {type(fun)}') - return fun + name = '' if name is None else name + if fun is None: + if allow_none: + return None + else: + raise ValueError(f'{name} must be a callable function, but we got None.') + if not callable(fun): + raise ValueError(f'{name} should be a callable function. 
While we got {type(fun)}') + return fun def is_initializer( @@ -232,31 +232,31 @@ def is_initializer( name: str = None, allow_none: bool = False ): - """Check the initializer. - """ - global Array - if Array is None: from brainpy._src.math.ndarray import BaseArray as Array - - global init - if init is None: - from brainpy import initialize - init = initialize - - name = '' if name is None else name - if initializer is None: - if allow_none: - return + """Check the initializer. + """ + global Array + if Array is None: from brainpy._src.math.ndarray import BaseArray as Array + + global init + if init is None: + from brainpy import initialize + init = initialize + + name = '' if name is None else name + if initializer is None: + if allow_none: + return + else: + raise ValueError(f'{name} must be an initializer, but we got None.') + if isinstance(initializer, init.Initializer): + return initializer + elif isinstance(initializer, (Array, jax.Array)): + return initializer + elif callable(initializer): + return initializer else: - raise ValueError(f'{name} must be an initializer, but we got None.') - if isinstance(initializer, init.Initializer): - return initializer - elif isinstance(initializer, (Array, jax.Array)): - return initializer - elif callable(initializer): - return initializer - else: - raise ValueError(f'{name} should be an instance of brainpy.init.Initializer, ' - f'tensor or callable function. While we got {type(initializer)}') + raise ValueError(f'{name} should be an instance of brainpy.init.Initializer, ' + f'tensor or callable function. While we got {type(initializer)}') def is_connector( @@ -264,29 +264,29 @@ def is_connector( name: str = None, allow_none: bool = False ): - """Check the connector. 
- """ - global Array - if Array is None: - from brainpy._src.math.ndarray import BaseArray as Array - global conn - if conn is None: from brainpy import connect as conn - - name = '' if name is None else name - if connector is None: - if allow_none: - return None + """Check the connector. + """ + global Array + if Array is None: + from brainpy._src.math.ndarray import BaseArray as Array + global conn + if conn is None: from brainpy import connect as conn + + name = '' if name is None else name + if connector is None: + if allow_none: + return None + else: + raise ValueError(f'{name} must be an initializer, but we got None.') + if isinstance(connector, conn.Connector): + return connector + elif isinstance(connector, (Array, jax.Array)): + return connector + elif callable(connector): + return connector else: - raise ValueError(f'{name} must be an initializer, but we got None.') - if isinstance(connector, conn.Connector): - return connector - elif isinstance(connector, (Array, jax.Array)): - return connector - elif callable(connector): - return connector - else: - raise ValueError(f'{name} should be an instance of brainpy.conn.Connector, ' - f'tensor or callable function. While we got {type(connector)}') + raise ValueError(f'{name} should be an instance of brainpy.conn.Connector, ' + f'tensor or callable function. 
While we got {type(connector)}') def is_sequence( @@ -295,20 +295,20 @@ def is_sequence( elem_type: Union[type, Sequence[type]] = None, allow_none: bool = True ): - if name is None: name = '' - if value is None: - if allow_none: - return - else: - raise ValueError(f'{name} must be a sequence, but got None') - if not isinstance(value, (tuple, list)): - raise ValueError(f'{name} should be a sequence, but we got a {type(value)}') - if elem_type is not None: - for v in value: - if not isinstance(v, elem_type): - raise ValueError(f'Elements in {name} should be {elem_type}, ' - f'but we got {type(elem_type)}: {v}') - return value + if name is None: name = '' + if value is None: + if allow_none: + return + else: + raise ValueError(f'{name} must be a sequence, but got None') + if not isinstance(value, (tuple, list)): + raise ValueError(f'{name} should be a sequence, but we got a {type(value)}') + if elem_type is not None: + for v in value: + if not isinstance(v, elem_type): + raise ValueError(f'Elements in {name} should be {elem_type}, ' + f'but we got {type(elem_type)}: {v}') + return value def is_float( @@ -319,107 +319,107 @@ def is_float( allow_none: bool = False, allow_int: bool = True ) -> float: - """Check float type. - - Parameters:: - - value: Any - name: optional, str - min_bound: optional, float - The allowed minimum value. - max_bound: optional, float - The allowed maximum value. - allow_none: bool - Whether allow the value is None. - allow_int: bool - Whether allow the value be an integer. - """ - if name is None: name = '' - if value is None: - if allow_none: - return None + """Check float type. + + Parameters:: + + value: Any + name: optional, str + min_bound: optional, float + The allowed minimum value. + max_bound: optional, float + The allowed maximum value. + allow_none: bool + Whether allow the value is None. + allow_int: bool + Whether allow the value be an integer. 
+ """ + if name is None: name = '' + if value is None: + if allow_none: + return None + else: + raise ValueError(f'{name} must be a float, but got None') + if allow_int: + if not isinstance(value, (float, int, np.integer, np.floating)): + raise ValueError(f'{name} must be a float, but got {type(value)}') else: - raise ValueError(f'{name} must be a float, but got None') - if allow_int: - if not isinstance(value, (float, int, np.integer, np.floating)): - raise ValueError(f'{name} must be a float, but got {type(value)}') - else: - if not isinstance(value, (float, np.floating)): - raise ValueError(f'{name} must be a float, but got {type(value)}') - if min_bound is not None: - jit_error_checking_no_args(value < min_bound, - ValueError(f"{name} must be a float bigger than {min_bound}, " - f"while we got {value}")) - - if max_bound is not None: - jit_error_checking_no_args(value > max_bound, - ValueError(f"{name} must be a float smaller than {max_bound}, " - f"while we got {value}")) - return value + if not isinstance(value, (float, np.floating)): + raise ValueError(f'{name} must be a float, but got {type(value)}') + if min_bound is not None: + jit_error_checking_no_args(value < min_bound, + ValueError(f"{name} must be a float bigger than {min_bound}, " + f"while we got {value}")) + + if max_bound is not None: + jit_error_checking_no_args(value > max_bound, + ValueError(f"{name} must be a float smaller than {max_bound}, " + f"while we got {value}")) + return value def is_integer(value: int, name=None, min_bound=None, max_bound=None, allow_none=False): - """Check integer type. - - Parameters:: - - value: int, optional - name: optional, str - min_bound: optional, int - The allowed minimum value. - max_bound: optional, int - The allowed maximum value. - allow_none: bool - Whether allow the value is None. 
- """ - if name is None: name = '' - if value is None: - if allow_none: - return - else: - raise ValueError(f'{name} must be an int, but got None') - if not isinstance(value, (int, np.integer)): - if hasattr(value, '__array__'): - if not (np.issubdtype(value.dtype, np.integer) and value.ndim == 0 and value.size == 1): - raise ValueError(f'{name} must be an int, but got {value}') - else: - raise ValueError(f'{name} must be an int, but got {value}') - if min_bound is not None: - jit_error_checking_no_args(jnp.any(value < min_bound), - ValueError(f"{name} must be an int bigger than {min_bound}, " - f"while we got {value}")) - if max_bound is not None: - jit_error_checking_no_args(jnp.any(value > max_bound), - ValueError(f"{name} must be an int smaller than {max_bound}, " - f"while we got {value}")) - return value + """Check integer type. + + Parameters:: + + value: int, optional + name: optional, str + min_bound: optional, int + The allowed minimum value. + max_bound: optional, int + The allowed maximum value. + allow_none: bool + Whether allow the value is None. 
+ """ + if name is None: name = '' + if value is None: + if allow_none: + return + else: + raise ValueError(f'{name} must be an int, but got None') + if not isinstance(value, (int, np.integer)): + if hasattr(value, '__array__'): + if not (np.issubdtype(value.dtype, np.integer) and value.ndim == 0 and value.size == 1): + raise ValueError(f'{name} must be an int, but got {value}') + else: + raise ValueError(f'{name} must be an int, but got {value}') + if min_bound is not None: + jit_error_checking_no_args(jnp.any(value < min_bound), + ValueError(f"{name} must be an int bigger than {min_bound}, " + f"while we got {value}")) + if max_bound is not None: + jit_error_checking_no_args(jnp.any(value > max_bound), + ValueError(f"{name} must be an int smaller than {max_bound}, " + f"while we got {value}")) + return value def is_string(value: str, name: str = None, candidates: Sequence[str] = None, allow_none=False): - """Check string type. - """ - if name is None: name = '' - if value is None: - if allow_none: - return None - else: - raise ValueError(f'{name} must be a str, but got None') - if candidates is not None: - if value not in candidates: - raise ValueError(f'{name} must be a str in {candidates}, ' - f'but we got {value}') - return value + """Check string type. 
+ """ + if name is None: name = '' + if value is None: + if allow_none: + return None + else: + raise ValueError(f'{name} must be a str, but got None') + if candidates is not None: + if value not in candidates: + raise ValueError(f'{name} must be a str in {candidates}, ' + f'but we got {value}') + return value def serialize_kwargs(shared_kwargs: Optional[Dict]): - """Serialize kwargs.""" - shared_kwargs = dict() if shared_kwargs is None else shared_kwargs - is_dict_data(shared_kwargs, - key_type=str, - val_type=(bool, float, int, complex, str), - name='shared_kwargs') - shared_kwargs = {key: shared_kwargs[key] for key in sorted(shared_kwargs.keys())} - return str(shared_kwargs) + """Serialize kwargs.""" + shared_kwargs = dict() if shared_kwargs is None else shared_kwargs + is_dict_data(shared_kwargs, + key_type=str, + val_type=(bool, float, int, complex, str), + name='shared_kwargs') + shared_kwargs = {key: shared_kwargs[key] for key in sorted(shared_kwargs.keys())} + return str(shared_kwargs) def is_subclass( @@ -427,50 +427,50 @@ def is_subclass( supported_types: Union[Type, Sequence[Type]], name: str = '' ) -> None: - r"""Check whether the instance is in the inheritance tree of the supported types. - - This function is used to check whether the given ``instance`` is an instance of - parent types in the inheritance hierarchy of the given ``supported_types``. - - - Here we have the following inheritance hierarchy:: - - A - / \ - B C - / \ / \ - D E F G - - If ``supported_types`` is ``[E, F]``, then - - - the instance of ``D`` or ``G`` will fail to pass the check. - - the instance of ``E`` or ``F`` will success to pass the check. - - the instance of ``B`` or ``C`` will also success to pass the check. - - the instance of ``A`` will success to pass the check too. - - Parameters:: - - instance: Any - The instance in the inheritance hierarchy tree. - supported_types: type, list of type, tuple of type - All types that are supported. 
- name: str - The checking target name. - """ - mode_type = type(instance) - if isinstance(supported_types, type): - supported_types = (supported_types,) - if not isinstance(supported_types, (tuple, list)): - raise TypeError(f'supported_types must be a tuple/list of type. But wwe got {type(supported_types)}') - for smode in supported_types: - if not isinstance(smode, type): - raise TypeError(f'supported_types must be a tuple/list of type. But wwe got {smode}') - checking = [issubclass(smode, mode_type) for smode in supported_types] - if any(checking): - return instance - else: - raise NotImplementedError(f"{name} does not support {instance}. We only support " - f"{', '.join([mode.__name__ for mode in supported_types])}. ") + r"""Check whether the instance is in the inheritance tree of the supported types. + + This function is used to check whether the given ``instance`` is an instance of + parent types in the inheritance hierarchy of the given ``supported_types``. + + + Here we have the following inheritance hierarchy:: + + A + / \ + B C + / \ / \ + D E F G + + If ``supported_types`` is ``[E, F]``, then + + - the instance of ``D`` or ``G`` will fail to pass the check. + - the instance of ``E`` or ``F`` will success to pass the check. + - the instance of ``B`` or ``C`` will also success to pass the check. + - the instance of ``A`` will success to pass the check too. + + Parameters:: + + instance: Any + The instance in the inheritance hierarchy tree. + supported_types: type, list of type, tuple of type + All types that are supported. + name: str + The checking target name. + """ + mode_type = type(instance) + if isinstance(supported_types, type): + supported_types = (supported_types,) + if not isinstance(supported_types, (tuple, list)): + raise TypeError(f'supported_types must be a tuple/list of type. 
But wwe got {type(supported_types)}') + for smode in supported_types: + if not isinstance(smode, type): + raise TypeError(f'supported_types must be a tuple/list of type. But wwe got {smode}') + checking = [issubclass(smode, mode_type) for smode in supported_types] + if any(checking): + return instance + else: + raise NotImplementedError(f"{name} does not support {instance}. We only support " + f"{', '.join([mode.__name__ for mode in supported_types])}. ") def is_instance( @@ -478,159 +478,158 @@ def is_instance( supported_types: Union[Type, Sequence[Type]], name: str = '' ): - r"""Check whether the ``instance`` is the instance of the given types. - - This function is used to check whether the given ``instance`` is an instance of - the given ``supported_types``. - - Here we have the following inheritance hierarchy:: - - A - / \ - B C - / \ / \ - D E F G - - If ``supported_types`` is ``[B, F]``, then - - - the instance of ``A`` or ``C`` or ``G`` will fail to pass the check. - - the instance of ``B`` or ``D`` or ``E`` or ``F`` will success to pass the check. - - Parameters:: - - instance: Any - The instance in the inheritance hierarchy tree. - supported_types: type, list of type, tuple of type - All types that are supported. - name: str - The checking target name. - """ - if not name: - name = 'We' - if not isinstance(instance, supported_types): - raise NotImplementedError(f"{name} expect to get an instance of {supported_types}." - f"But we got {type(instance)}. ") - return instance + r"""Check whether the ``instance`` is the instance of the given types. + + This function is used to check whether the given ``instance`` is an instance of + the given ``supported_types``. + + Here we have the following inheritance hierarchy:: + + A + / \ + B C + / \ / \ + D E F G + + If ``supported_types`` is ``[B, F]``, then + + - the instance of ``A`` or ``C`` or ``G`` will fail to pass the check. + - the instance of ``B`` or ``D`` or ``E`` or ``F`` will success to pass the check. 
+ + Parameters:: + + instance: Any + The instance in the inheritance hierarchy tree. + supported_types: type, list of type, tuple of type + All types that are supported. + name: str + The checking target name. + """ + if not name: + name = 'We' + if not isinstance(instance, supported_types): + raise NotImplementedError(f"{name} expect to get an instance of {supported_types}." + f"But we got {type(instance)}. ") + return instance def is_elem_or_seq_or_dict(targets: Any, elem_type: Union[type, Tuple[type, ...]], out_as: str = 'tuple'): - assert out_as in ['tuple', 'list', 'dict', None], 'Only support to output as tuple/list/dict/None' - - if targets is None: - keys = [] - vals = [] - elif isinstance(targets, elem_type): - keys = [id(targets)] - vals = [targets] - elif isinstance(targets, (list, tuple)): - is_leaf = [isinstance(l, elem_type) for l in targets] - if not all(is_leaf): - raise ValueError(f'Only support {elem_type}, sequence of {elem_type}, or dict of {elem_type}.') - keys = [id(v) for v in targets] - vals = list(targets) - elif isinstance(targets, dict): - is_leaf = [isinstance(l, elem_type) for l in targets.values()] - if not all(is_leaf): - raise ValueError(f'Only support {elem_type}, sequence of {elem_type}, or dict of {elem_type}.') - keys = list(targets.keys()) - vals = list(targets.values()) - else: - raise ValueError(f'Only support {elem_type}, sequence of {elem_type}, or dict of {elem_type}.') - - if out_as is None: - return targets - elif out_as == 'list': - return vals - elif out_as == 'tuple': - return tuple(vals) - elif out_as == 'dict': - return dict(zip(keys, vals)) - else: - raise KeyError + assert out_as in ['tuple', 'list', 'dict', None], 'Only support to output as tuple/list/dict/None' + + if targets is None: + keys = [] + vals = [] + elif isinstance(targets, elem_type): + keys = [id(targets)] + vals = [targets] + elif isinstance(targets, (list, tuple)): + is_leaf = [isinstance(l, elem_type) for l in targets] + if not all(is_leaf): + 
raise ValueError(f'Only support {elem_type}, sequence of {elem_type}, or dict of {elem_type}.') + keys = [id(v) for v in targets] + vals = list(targets) + elif isinstance(targets, dict): + is_leaf = [isinstance(l, elem_type) for l in targets.values()] + if not all(is_leaf): + raise ValueError(f'Only support {elem_type}, sequence of {elem_type}, or dict of {elem_type}.') + keys = list(targets.keys()) + vals = list(targets.values()) + else: + raise ValueError(f'Only support {elem_type}, sequence of {elem_type}, or dict of {elem_type}.') + + if out_as is None: + return targets + elif out_as == 'list': + return vals + elif out_as == 'tuple': + return tuple(vals) + elif out_as == 'dict': + return dict(zip(keys, vals)) + else: + raise KeyError def is_all_vars(dyn_vars: Any, out_as: str = 'tuple'): - global var_obs - if var_obs is None: - from brainpy.math import Variable, VarList, VarDict - var_obs = (VarList, VarDict, Variable) + global var_obs + if var_obs is None: + from brainpy.math import Variable, VarList, VarDict + var_obs = (VarList, VarDict, Variable) - return is_elem_or_seq_or_dict(dyn_vars, var_obs, out_as) + return is_elem_or_seq_or_dict(dyn_vars, var_obs, out_as) def is_all_objs(targets: Any, out_as: str = 'tuple'): - global BrainPyObject - if BrainPyObject is None: - from brainpy._src.math.object_transform.base import BrainPyObject - return is_elem_or_seq_or_dict(targets, BrainPyObject, out_as) + global BrainPyObject + if BrainPyObject is None: + from brainpy._src.math.object_transform.base import BrainPyObject + return is_elem_or_seq_or_dict(targets, BrainPyObject, out_as) def _err_jit_true_branch(err_fun, x): - if isinstance(x, (tuple, list)): - x_shape_dtype = tuple(jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in x) - else: - x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype) - jax.pure_callback(err_fun, x_shape_dtype, x, vmap_method='sequential') - return + if isinstance(x, (tuple, list)): + x_shape_dtype = 
tuple(jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in x) + else: + x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype) + jax.pure_callback(err_fun, x_shape_dtype, x, vmap_method='sequential') + return def _err_jit_false_branch(x): - return + return def _cond(err_fun, pred, err_arg): - from brainpy._src.math.remove_vmap import remove_vmap + from brainpy._src.math.remove_vmap import remove_vmap - @wraps(err_fun) - def true_err_fun(arg, transforms): - err_fun(arg) + @wraps(err_fun) + def true_err_fun(arg, transforms): + err_fun(arg) - cond(remove_vmap(pred), - partial(_err_jit_true_branch, true_err_fun), - _err_jit_false_branch, - err_arg) + cond(remove_vmap(pred), + partial(_err_jit_true_branch, true_err_fun), + _err_jit_false_branch, + err_arg) def jit_error(pred, err_fun, err_arg=None): - """Check errors in a jit function. + """Check errors in a jit function. - Parameters:: + Parameters:: - pred: bool, Array - The boolean prediction. - err_fun: callable - The error function, which raise errors. - err_arg: any - The arguments which passed into `err_f`. - """ - from brainpy._src.math.interoperability import as_jax - partial(_cond, err_fun)(as_jax(pred), err_arg) + pred: bool, Array + The boolean prediction. + err_fun: callable + The error function, which raise errors. + err_arg: any + The arguments which passed into `err_f`. + """ + from brainpy._src.math.interoperability import as_jax + partial(_cond, err_fun)(as_jax(pred), err_arg) jit_error_checking = jit_error def jit_error_checking_no_args(pred: bool, err: Exception): - """Check errors in a jit function. - - Parameters:: + """Check errors in a jit function. - pred: bool - The boolean prediction. - err: Exception - The error. - """ - from brainpy._src.math.remove_vmap import remove_vmap - from brainpy._src.math.interoperability import as_jax + Parameters:: - assert isinstance(err, Exception), 'Must be instance of Exception.' + pred: bool + The boolean prediction. + err: Exception + The error. 
+ """ + from brainpy._src.math.remove_vmap import remove_vmap + from brainpy._src.math.interoperability import as_jax - def true_err_fun(arg, transforms): - raise err + assert isinstance(err, Exception), 'Must be instance of Exception.' - cond(remove_vmap(as_jax(pred)), - lambda: jax.pure_callback(true_err_fun, None), - lambda: None) + def true_err_fun(arg, transforms): + raise err + cond(remove_vmap(as_jax(pred)), + lambda: jax.pure_callback(true_err_fun, None), + lambda: None) diff --git a/brainpy/checkpoints.py b/brainpy/checkpoints.py index dbe4a9336..24b573c83 100644 --- a/brainpy/checkpoints.py +++ b/brainpy/checkpoints.py @@ -2,10 +2,7 @@ from brainpy._src.checkpoints.serialization import ( - save as save, - load as load, - save_pytree as save_pytree, - load_pytree as load_pytree, - AsyncManager as AsyncManager + save_pytree as save_pytree, + load_pytree as load_pytree, + AsyncManager as AsyncManager ) - diff --git a/brainpy/connect.py b/brainpy/connect.py index fe7a9f426..dd38034b4 100644 --- a/brainpy/connect.py +++ b/brainpy/connect.py @@ -1,60 +1,59 @@ # -*- coding: utf-8 -*- from brainpy._src.connect.base import ( - Connector as Connector, - TwoEndConnector as TwoEndConnector, - OneEndConnector as OneEndConnector, - mat2coo as mat2coo, - mat2csc as mat2csc, - mat2csr as mat2csr, - csr2csc as csr2csc, - csr2mat as csr2mat, - csr2coo as csr2coo, - coo2csr as coo2csr, - coo2csc as coo2csc, - coo2mat as coo2mat, - coo2mat_num as coo2mat_num, - mat2mat_num as mat2mat_num, - visualizeMat as visualizeMat, - set_default_dtype as set_default_dtype, + Connector as Connector, + TwoEndConnector as TwoEndConnector, + OneEndConnector as OneEndConnector, + mat2coo as mat2coo, + mat2csc as mat2csc, + mat2csr as mat2csr, + csr2csc as csr2csc, + csr2mat as csr2mat, + csr2coo as csr2coo, + coo2csr as coo2csr, + coo2csc as coo2csc, + coo2mat as coo2mat, + coo2mat_num as coo2mat_num, + mat2mat_num as mat2mat_num, + visualizeMat as visualizeMat, + set_default_dtype as 
set_default_dtype, - CONN_MAT, - PRE_IDS, POST_IDS, - PRE2POST, POST2PRE, - PRE2SYN, POST2SYN, - PRE_SLICE, POST_SLICE, - COO, CSR, CSC + CONN_MAT, + PRE_IDS, POST_IDS, + PRE2POST, POST2PRE, + PRE2SYN, POST2SYN, + PRE_SLICE, POST_SLICE, + COO, CSR, CSC ) from brainpy._src.connect.custom_conn import ( - MatConn as MatConn, - IJConn as IJConn, - CSRConn as CSRConn, - SparseMatConn as SparseMatConn, + MatConn as MatConn, + IJConn as IJConn, + CSRConn as CSRConn, + SparseMatConn as SparseMatConn, ) from brainpy._src.connect.random_conn import ( - FixedProb as FixedProb, - FixedPreNum as FixedPreNum, - FixedPostNum as FixedPostNum, - FixedTotalNum as FixedTotalNum, - GaussianProb as GaussianProb, - ProbDist as ProbDist, - SmallWorld as SmallWorld, - ScaleFreeBA as ScaleFreeBA, - ScaleFreeBADual as ScaleFreeBADual, - PowerLaw as PowerLaw, + FixedProb as FixedProb, + FixedPreNum as FixedPreNum, + FixedPostNum as FixedPostNum, + FixedTotalNum as FixedTotalNum, + GaussianProb as GaussianProb, + ProbDist as ProbDist, + SmallWorld as SmallWorld, + ScaleFreeBA as ScaleFreeBA, + ScaleFreeBADual as ScaleFreeBADual, + PowerLaw as PowerLaw, ) - from brainpy._src.connect.regular_conn import ( - One2One as One2One, - one2one as one2one, - All2All as All2All, - all2all as all2all, - GridFour as GridFour, - grid_four as grid_four, - GridEight as GridEight, - grid_eight as grid_eight, - GridN as GridN, + One2One as One2One, + one2one as one2one, + All2All as All2All, + all2all as all2all, + GridFour as GridFour, + grid_four as grid_four, + GridEight as GridEight, + grid_eight as grid_eight, + GridN as GridN, ) diff --git a/brainpy/dnn/activations.py b/brainpy/dnn/activations.py index 5e96f63b8..2ec61614a 100644 --- a/brainpy/dnn/activations.py +++ b/brainpy/dnn/activations.py @@ -1,38 +1,36 @@ - from brainpy._src.dnn.function import ( - Activation as Activation, - Flatten as Flatten, - FunAsLayer as FunAsLayer, + Activation as Activation, + Flatten as Flatten, + FunAsLayer as 
FunAsLayer, ) from brainpy._src.dnn.activations import ( - Threshold, - ReLU, - RReLU, - Hardtanh, - ReLU6, - Sigmoid, - Hardsigmoid, - Tanh, - SiLU, - Mish, - Hardswish, - ELU, - CELU, - SELU, - GLU, - GELU, - Hardshrink, - LeakyReLU, - LogSigmoid, - Softplus, - Softshrink, - PReLU, - Softsign, - Tanhshrink, - Softmin, - Softmax, - Softmax2d, - LogSoftmax, + Threshold, + ReLU, + RReLU, + Hardtanh, + ReLU6, + Sigmoid, + Hardsigmoid, + Tanh, + SiLU, + Mish, + Hardswish, + ELU, + CELU, + SELU, + GLU, + GELU, + Hardshrink, + LeakyReLU, + LogSigmoid, + Softplus, + Softshrink, + PReLU, + Softsign, + Tanhshrink, + Softmin, + Softmax, + Softmax2d, + LogSoftmax, ) - diff --git a/brainpy/dnn/conv.py b/brainpy/dnn/conv.py index a86ea90cd..d749bbe9f 100644 --- a/brainpy/dnn/conv.py +++ b/brainpy/dnn/conv.py @@ -1,12 +1,11 @@ - from brainpy._src.dnn.conv import ( - Conv1d as Conv1d, - Conv2d as Conv2d, - Conv3d as Conv3d, - Conv1D as Conv1D, - Conv2D as Conv2D, - Conv3D as Conv3D, - ConvTranspose1d as ConvTranspose1d, - ConvTranspose2d as ConvTranspose2d, - ConvTranspose3d as ConvTranspose3d, + Conv1d as Conv1d, + Conv2d as Conv2d, + Conv3d as Conv3d, + Conv1D as Conv1D, + Conv2D as Conv2D, + Conv3D as Conv3D, + ConvTranspose1d as ConvTranspose1d, + ConvTranspose2d as ConvTranspose2d, + ConvTranspose3d as ConvTranspose3d, ) diff --git a/brainpy/dnn/interoperation.py b/brainpy/dnn/interoperation.py index 8411fa8cd..b6fdae5e6 100644 --- a/brainpy/dnn/interoperation.py +++ b/brainpy/dnn/interoperation.py @@ -1,5 +1,4 @@ - from brainpy._src.dnn.interoperation_flax import ( - FromFlax, - ToFlaxRNNCell, ToFlax, + FromFlax, + ToFlaxRNNCell, ToFlax, ) diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index 762c3c282..6514d95ff 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -1,17 +1,16 @@ - from brainpy._src.dnn.linear import ( - Dense as Dense, - Linear as Linear, - Identity as Identity, - AllToAll as AllToAll, - OneToOne as OneToOne, - MaskedLinear as 
MaskedLinear, - CSRLinear as CSRLinear, - EventCSRLinear as EventCSRLinear, - JitFPHomoLinear as JitFPHomoLinear, - JitFPUniformLinear as JitFPUniformLinear, - JitFPNormalLinear as JitFPNormalLinear, - EventJitFPHomoLinear as EventJitFPHomoLinear, - EventJitFPNormalLinear as EventJitFPNormalLinear, - EventJitFPUniformLinear as EventJitFPUniformLinear, + Dense as Dense, + Linear as Linear, + Identity as Identity, + AllToAll as AllToAll, + OneToOne as OneToOne, + MaskedLinear as MaskedLinear, + CSRLinear as CSRLinear, + EventCSRLinear as EventCSRLinear, + JitFPHomoLinear as JitFPHomoLinear, + JitFPUniformLinear as JitFPUniformLinear, + JitFPNormalLinear as JitFPNormalLinear, + EventJitFPHomoLinear as EventJitFPHomoLinear, + EventJitFPNormalLinear as EventJitFPNormalLinear, + EventJitFPUniformLinear as EventJitFPUniformLinear, ) diff --git a/brainpy/dnn/normalization.py b/brainpy/dnn/normalization.py index cdecc3ae2..4b5273f6c 100644 --- a/brainpy/dnn/normalization.py +++ b/brainpy/dnn/normalization.py @@ -1,13 +1,11 @@ - - from brainpy._src.dnn.normalization import ( - BatchNorm1d as BatchNorm1d, - BatchNorm2d as BatchNorm2d, - BatchNorm3d as BatchNorm3d, - BatchNorm1D as BatchNorm1D, - BatchNorm2D as BatchNorm2D, - BatchNorm3D as BatchNorm3D, - LayerNorm as LayerNorm, - GroupNorm as GroupNorm, - InstanceNorm as InstanceNorm, + BatchNorm1d as BatchNorm1d, + BatchNorm2d as BatchNorm2d, + BatchNorm3d as BatchNorm3d, + BatchNorm1D as BatchNorm1D, + BatchNorm2D as BatchNorm2D, + BatchNorm3D as BatchNorm3D, + LayerNorm as LayerNorm, + GroupNorm as GroupNorm, + InstanceNorm as InstanceNorm, ) diff --git a/brainpy/dnn/others.py b/brainpy/dnn/others.py index 717dff569..9e2e26072 100644 --- a/brainpy/dnn/others.py +++ b/brainpy/dnn/others.py @@ -1,14 +1,12 @@ - - from brainpy._src.dnn.base import ( - Layer as Layer, + Layer as Layer, ) from brainpy._src.dnn.dropout import ( - Dropout as Dropout, + Dropout as Dropout, ) from brainpy._src.dnn.function import ( - Activation, - 
Flatten, - Unflatten, - FunAsLayer, + Activation, + Flatten, + Unflatten, + FunAsLayer, ) diff --git a/brainpy/dnn/pooling.py b/brainpy/dnn/pooling.py index 75d06c3f6..280a16934 100644 --- a/brainpy/dnn/pooling.py +++ b/brainpy/dnn/pooling.py @@ -1,21 +1,20 @@ - from brainpy._src.dnn.pooling import ( - MaxPool as MaxPool, - MaxPool1d as MaxPool1d, - MaxPool2d as MaxPool2d, - MaxPool3d as MaxPool3d, + MaxPool as MaxPool, + MaxPool1d as MaxPool1d, + MaxPool2d as MaxPool2d, + MaxPool3d as MaxPool3d, - MinPool as MinPool, + MinPool as MinPool, - AvgPool as AvgPool, - AvgPool1d as AvgPool1d, - AvgPool2d as AvgPool2d, - AvgPool3d as AvgPool3d, + AvgPool as AvgPool, + AvgPool1d as AvgPool1d, + AvgPool2d as AvgPool2d, + AvgPool3d as AvgPool3d, - AdaptiveAvgPool1d as AdaptiveAvgPool1d, - AdaptiveAvgPool2d as AdaptiveAvgPool2d, - AdaptiveAvgPool3d as AdaptiveAvgPool3d, - AdaptiveMaxPool1d as AdaptiveMaxPool1d, - AdaptiveMaxPool2d as AdaptiveMaxPool2d, - AdaptiveMaxPool3d as AdaptiveMaxPool3d, + AdaptiveAvgPool1d as AdaptiveAvgPool1d, + AdaptiveAvgPool2d as AdaptiveAvgPool2d, + AdaptiveAvgPool3d as AdaptiveAvgPool3d, + AdaptiveMaxPool1d as AdaptiveMaxPool1d, + AdaptiveMaxPool2d as AdaptiveMaxPool2d, + AdaptiveMaxPool3d as AdaptiveMaxPool3d, ) diff --git a/brainpy/encoding.py b/brainpy/encoding.py index b51f9d744..f7780303d 100644 --- a/brainpy/encoding.py +++ b/brainpy/encoding.py @@ -2,14 +2,13 @@ from brainpy._src.encoding.base import ( - Encoder as Encoder, + Encoder as Encoder, ) from brainpy._src.encoding.stateful_encoding import ( - LatencyEncoder as LatencyEncoder, - WeightedPhaseEncoder as WeightedPhaseEncoder, + LatencyEncoder as LatencyEncoder, + WeightedPhaseEncoder as WeightedPhaseEncoder, ) from brainpy._src.encoding.stateless_encoding import ( - PoissonEncoder as PoissonEncoder, - DiffEncoder as DiffEncoder, + PoissonEncoder as PoissonEncoder, + DiffEncoder as DiffEncoder, ) - diff --git a/brainpy/errors.py b/brainpy/errors.py index 453c9c818..b38414d94 100644 
--- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -2,239 +2,239 @@ class BrainPyError(Exception): - """General BrainPy error.""" - pass + """General BrainPy error.""" + pass class APIChangedError(BrainPyError): - pass + pass class RunningError(BrainPyError): - """The error occurred in the running function.""" - pass + """The error occurred in the running function.""" + pass class IntegratorError(BrainPyError): - pass + pass class DiffEqError(BrainPyError): - """The differential equation definition error.""" - pass + """The differential equation definition error.""" + pass class CodeError(BrainPyError): - """Code definition error. - """ - pass + """Code definition error. + """ + pass class AnalyzerError(BrainPyError): - """Error occurred in differential equation analyzer and dynamics analysis. - """ + """Error occurred in differential equation analyzer and dynamics analysis. + """ class PackageMissingError(BrainPyError): - """The package missing error. - """ + """The package missing error. + """ - @classmethod - def by_purpose(cls, name, purpose): - err = (f'"{name}" must be installed when the user wants to use {purpose}. \n' - f'Please install through "pip install {name}".') - return cls(err) + @classmethod + def by_purpose(cls, name, purpose): + err = (f'"{name}" must be installed when the user wants to use {purpose}. \n' + f'Please install through "pip install {name}".') + return cls(err) class BackendNotInstalled(BrainPyError): - def __init__(self, backend): - super(BackendNotInstalled, self).__init__( - '"{bk}" must be installed when the user wants to use {bk} backend. \n' - 'Please install {bk} through "pip install {bk}" ' - 'or "conda install {bk}".'.format(bk=backend)) + def __init__(self, backend): + super(BackendNotInstalled, self).__init__( + '"{bk}" must be installed when the user wants to use {bk} backend. 
\n' + 'Please install {bk} through "pip install {bk}" ' + 'or "conda install {bk}".'.format(bk=backend)) class UniqueNameError(BrainPyError): - def __init__(self, *args): - super(UniqueNameError, self).__init__(*args) + def __init__(self, *args): + super(UniqueNameError, self).__init__(*args) class UnsupportedError(BrainPyError): - pass + pass class NoImplementationError(BrainPyError): - pass + pass class NoLongerSupportError(BrainPyError): - pass + pass class ConnectorError(BrainPyError): - pass + pass class MonitorError(BrainPyError): - pass + pass class MathError(BrainPyError): - """Errors occurred in ``brainpy.math`` module.""" - pass + """Errors occurred in ``brainpy.math`` module.""" + pass class MPACheckpointingRequiredError(BrainPyError): - """To optimally save and restore a multiprocess array (GDA or jax Array outputted from pjit), use GlobalAsyncCheckpointManager. + """To optimally save and restore a multiprocess array (GDA or jax Array outputted from pjit), use GlobalAsyncCheckpointManager. 
- You can create an GlobalAsyncCheckpointManager at top-level and pass it as - argument:: + You can create an GlobalAsyncCheckpointManager at top-level and pass it as + argument:: - from jax.experimental.gda_serialization import serialization as gdas - gda_manager = gdas.GlobalAsyncCheckpointManager() - brainpy.checkpoints.save(..., gda_manager=gda_manager) - """ + from jax.experimental.gda_serialization import serialization as gdas + gda_manager = gdas.GlobalAsyncCheckpointManager() + brainpy.checkpoints.save(..., gda_manager=gda_manager) + """ - def __init__(self, path, step): - super().__init__( - f'Checkpoint failed at step: "{step}" and path: "{path}": Target ' - 'contains a multiprocess array should be saved/restored with a ' - 'GlobalAsyncCheckpointManager.') + def __init__(self, path, step): + super().__init__( + f'Checkpoint failed at step: "{step}" and path: "{path}": Target ' + 'contains a multiprocess array should be saved/restored with a ' + 'GlobalAsyncCheckpointManager.') class MPARestoreTargetRequiredError(BrainPyError): - """Provide a valid target when restoring a checkpoint with a multiprocess array. - - Multiprocess arrays need a sharding (global meshes and partition specs) to be - initialized. Therefore, to restore a checkpoint that contains a multiprocess - array, make sure the ``target`` you passed contains valid multiprocess arrays - at the corresponding tree structure location. If you cannot provide a full - valid ``target``, consider ``allow_partial_mpa_restoration=True``. - """ - - def __init__(self, path, step, key=None): - error_msg = ( - f'Restore checkpoint failed at step: "{step}" and path: "{path}": ' - 'Checkpoints containing a multiprocess array need to be restored with ' - 'a target with pre-created arrays. If you cannot provide a full valid ' - 'target, consider ``allow_partial_mpa_restoration=True``. ') - if key: - error_msg += f'This error fired when trying to restore array at {key}.' 
- super().__init__(error_msg) + """Provide a valid target when restoring a checkpoint with a multiprocess array. + + Multiprocess arrays need a sharding (global meshes and partition specs) to be + initialized. Therefore, to restore a checkpoint that contains a multiprocess + array, make sure the ``target`` you passed contains valid multiprocess arrays + at the corresponding tree structure location. If you cannot provide a full + valid ``target``, consider ``allow_partial_mpa_restoration=True``. + """ + + def __init__(self, path, step, key=None): + error_msg = ( + f'Restore checkpoint failed at step: "{step}" and path: "{path}": ' + 'Checkpoints containing a multiprocess array need to be restored with ' + 'a target with pre-created arrays. If you cannot provide a full valid ' + 'target, consider ``allow_partial_mpa_restoration=True``. ') + if key: + error_msg += f'This error fired when trying to restore array at {key}.' + super().__init__(error_msg) class MPARestoreDataCorruptedError(BrainPyError): - """A multiprocess array stored in Google Cloud Storage doesn't contain a "commit_success.txt" file, which should be written at the end of the save. + """A multiprocess array stored in Google Cloud Storage doesn't contain a "commit_success.txt" file, which should be written at the end of the save. - Failure of finding it could indicate a corruption of your saved GDA data. - """ + Failure of finding it could indicate a corruption of your saved GDA data. + """ - def __init__(self, step, path): - super().__init__( - f'Restore checkpoint failed at step: "{step}" on multiprocess array at ' - f' "{path}": No "commit_success.txt" found on this "_gda" directory. ' - 'Was its save halted before completion?') + def __init__(self, step, path): + super().__init__( + f'Restore checkpoint failed at step: "{step}" on multiprocess array at ' + f' "{path}": No "commit_success.txt" found on this "_gda" directory. 
' + 'Was its save halted before completion?') class MPARestoreTypeNotMatchError(BrainPyError): - """Make sure the multiprocess array type you use matches your configuration in jax.config.jax_array. + """Make sure the multiprocess array type you use matches your configuration in jax.config.jax_array. - If you turned `jax.config.jax_array` on, you should use - `jax.experimental.array.Array` everywhere, instead of using - `GlobalDeviceArray`. Otherwise, avoid using jax.experimental.array - to restore your checkpoint. - """ + If you turned `jax.config.jax_array` on, you should use + `jax.experimental.array.Array` everywhere, instead of using + `GlobalDeviceArray`. Otherwise, avoid using jax.experimental.array + to restore your checkpoint. + """ - def __init__(self, step, gda_path): - super().__init__( - f'Restore checkpoint failed at step: "{step}" on multiprocess array at ' - f' "{gda_path}": The array type provided by the target does not match ' - 'the JAX global configuration, namely the jax.config.jax_array.') + def __init__(self, step, gda_path): + super().__init__( + f'Restore checkpoint failed at step: "{step}" on multiprocess array at ' + f' "{gda_path}": The array type provided by the target does not match ' + 'the JAX global configuration, namely the jax.config.jax_array.') class AlreadyExistsError(BrainPyError): - """Attempting to overwrite a file via copy. + """Attempting to overwrite a file via copy. - You can pass ``overwrite=True`` to disable this behavior and overwrite - existing files in. - """ + You can pass ``overwrite=True`` to disable this behavior and overwrite + existing files in. 
+ """ - def __init__(self, path): - super().__init__(f'Trying overwrite an existing file: "{path}".') + def __init__(self, path): + super().__init__(f'Trying overwrite an existing file: "{path}".') class InvalidCheckpointError(BrainPyError): - """A checkpoint cannot be stored in a directory that already has + """A checkpoint cannot be stored in a directory that already has - a checkpoint at the current or a later step. + a checkpoint at the current or a later step. - You can pass ``overwrite=True`` to disable this behavior and - overwrite existing checkpoints in the target directory. - """ + You can pass ``overwrite=True`` to disable this behavior and + overwrite existing checkpoints in the target directory. + """ - def __init__(self, path, step): - super().__init__( - f'Trying to save an outdated checkpoint at step: "{step}" and path: "{path}".' - ) + def __init__(self, path, step): + super().__init__( + f'Trying to save an outdated checkpoint at step: "{step}" and path: "{path}".' + ) class InvalidCheckpointPath(BrainPyError): - """A checkpoint cannot be stored in a directory that already has + """A checkpoint cannot be stored in a directory that already has - a checkpoint at the current or a later step. + a checkpoint at the current or a later step. - You can pass ``overwrite=True`` to disable this behavior and - overwrite existing checkpoints in the target directory. - """ + You can pass ``overwrite=True`` to disable this behavior and + overwrite existing checkpoints in the target directory. + """ - def __init__(self, path): - super().__init__(f'Invalid checkpoint at "{path}".') + def __init__(self, path): + super().__init__(f'Invalid checkpoint at "{path}".') class JaxTracerError(MathError): - def __init__(self, variables=None): - msg = 'There is an unexpected tracer. 
\n\n' \ - 'In BrainPy, all the dynamically changed variables must be declared as ' \ - '"brainpy.math.Variable" and they should be provided ' \ - 'into the "dyn_vars" when calling the transformation functions, ' \ - 'like "jit()", "vmap()", "grad()", "make_loop()", etc. \n\n' - - if variables is None: - pass - elif isinstance(variables, dict): - msg += f'We detect all the provided dynamical variables are: ' \ - f'{variables.keys()}\n\n' - elif isinstance(variables, (list, tuple)): - msg += 'We detect all the provided dynamical variables are: \n' - for v in variables: - msg += f'\t{v.dtype}[{v.shape}]\n' - msg += '\n' - else: - raise ValueError - - # msg += 'While there are changed variables which are not wrapped into "dyn_vars". Please check!' - msg = 'While there are changed variables which are not wrapped into "dyn_vars". Please check!' - - super(JaxTracerError, self).__init__(msg) + def __init__(self, variables=None): + msg = 'There is an unexpected tracer. \n\n' \ + 'In BrainPy, all the dynamically changed variables must be declared as ' \ + '"brainpy.math.Variable" and they should be provided ' \ + 'into the "dyn_vars" when calling the transformation functions, ' \ + 'like "jit()", "vmap()", "grad()", "make_loop()", etc. \n\n' + + if variables is None: + pass + elif isinstance(variables, dict): + msg += f'We detect all the provided dynamical variables are: ' \ + f'{variables.keys()}\n\n' + elif isinstance(variables, (list, tuple)): + msg += 'We detect all the provided dynamical variables are: \n' + for v in variables: + msg += f'\t{v.dtype}[{v.shape}]\n' + msg += '\n' + else: + raise ValueError + + # msg += 'While there are changed variables which are not wrapped into "dyn_vars". Please check!' + msg = 'While there are changed variables which are not wrapped into "dyn_vars". Please check!' 
+ + super(JaxTracerError, self).__init__(msg) class ConcretizationTypeError(Exception): - def __init__(self): - super(ConcretizationTypeError, self).__init__( - 'This problem may be caused by several ways:\n' - '1. Your if-else conditional statement relies on instances of brainpy.math.Variable. \n' - '2. Your if-else conditional statement relies on functional arguments which do not ' - 'set in "static_argnames" when applying JIT compilation. More details please see ' - 'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n' - '3. The static variables which set in the "static_argnames" are provided ' - 'as arguments, not keyword arguments, like "jit_f(v1, v2)" [<- wrong]. ' - 'Please write it as "jit_f(static_k1=v1, static_k2=v2)" [<- right].' - ) + def __init__(self): + super(ConcretizationTypeError, self).__init__( + 'This problem may be caused by several ways:\n' + '1. Your if-else conditional statement relies on instances of brainpy.math.Variable. \n' + '2. Your if-else conditional statement relies on functional arguments which do not ' + 'set in "static_argnames" when applying JIT compilation. More details please see ' + 'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n' + '3. The static variables which set in the "static_argnames" are provided ' + 'as arguments, not keyword arguments, like "jit_f(v1, v2)" [<- wrong]. ' + 'Please write it as "jit_f(static_k1=v1, static_k2=v2)" [<- right].' + ) class GPUOperatorNotFound(Exception): - def __init__(self, name): - super(GPUOperatorNotFound, self).__init__(f''' + def __init__(self, name): + super(GPUOperatorNotFound, self).__init__(f''' GPU operator for "{name}" does not found. Please install brainpylib GPU operators with linux + CUDA environment. 
@@ -242,4 +242,4 @@ def __init__(self, name): class SharedArgError(BrainPyError): - pass + pass diff --git a/brainpy/experimental.py b/brainpy/experimental.py index c909fa633..eca5a2920 100644 --- a/brainpy/experimental.py +++ b/brainpy/experimental.py @@ -1,18 +1,16 @@ - from brainpy._src.dynold.experimental.syn_plasticity import ( - STD as STD, - STP as STP, + STD as STD, + STP as STP, ) from brainpy._src.dynold.experimental.syn_outs import ( - CUBA as CUBA, - COBA as COBA, + CUBA as CUBA, + COBA as COBA, ) from brainpy._src.dynold.experimental.abstract_synapses import ( - Exponential, - DualExponential, - Alpha, + Exponential, + DualExponential, + Alpha, ) from brainpy._src.dynold.experimental.others import ( - PoissonInput, + PoissonInput, ) - diff --git a/brainpy/initialize.py b/brainpy/initialize.py index 0c737bc0b..6ba02b184 100644 --- a/brainpy/initialize.py +++ b/brainpy/initialize.py @@ -5,55 +5,50 @@ You can access them through ``brainpy.init.XXX``. """ - from brainpy._src.initialize.base import ( - Initializer as Initializer, + Initializer as Initializer, ) from brainpy._src.initialize.decay_inits import ( - GaussianDecay as GaussianDecay, - DOGDecay as DOGDecay, + GaussianDecay as GaussianDecay, + DOGDecay as DOGDecay, ) - from brainpy._src.initialize.random_inits import ( - calculate_gain, + calculate_gain, ) from brainpy._src.initialize.random_inits import ( - Normal as Normal, - Uniform as Uniform, - TruncatedNormal as TruncatedNormal, - VarianceScaling as VarianceScaling, - KaimingUniform as KaimingUniform, - KaimingNormal as KaimingNormal, - XavierUniform as XavierUniform, - XavierNormal as XavierNormal, - LecunUniform as LecunUniform, - LecunNormal as LecunNormal, - Orthogonal as Orthogonal, - DeltaOrthogonal as DeltaOrthogonal, - Gamma, - Exponential, + Normal as Normal, + Uniform as Uniform, + TruncatedNormal as TruncatedNormal, + VarianceScaling as VarianceScaling, + KaimingUniform as KaimingUniform, + KaimingNormal as KaimingNormal, + 
XavierUniform as XavierUniform, + XavierNormal as XavierNormal, + LecunUniform as LecunUniform, + LecunNormal as LecunNormal, + Orthogonal as Orthogonal, + DeltaOrthogonal as DeltaOrthogonal, + Gamma, + Exponential, ) - from brainpy._src.initialize.regular_inits import ( - ZeroInit as ZeroInit, - Constant as Constant, - OneInit as OneInit, - Identity as Identity, + ZeroInit as ZeroInit, + Constant as Constant, + OneInit as OneInit, + Identity as Identity, ) - from brainpy._src.initialize.generic import ( - parameter as parameter, - variable as variable, - variable_ as variable_, - noise as noise, - delay as delay, + parameter as parameter, + variable as variable, + variable_ as variable_, + noise as noise, + delay as delay, ) from brainpy._src.initialize.others import ( - Clip as Clip, + Clip as Clip, ) - diff --git a/brainpy/losses.py b/brainpy/losses.py index f2506742c..ddf8903e8 100644 --- a/brainpy/losses.py +++ b/brainpy/losses.py @@ -1,39 +1,38 @@ # -*- coding: utf-8 -*- from brainpy._src.losses.comparison import ( - cross_entropy_loss as cross_entropy_loss, - cross_entropy_sparse as cross_entropy_sparse, - cross_entropy_sigmoid as cross_entropy_sigmoid, - nll_loss, - l1_loss as l1_loss, - l2_loss as l2_loss, - huber_loss as huber_loss, - mean_absolute_error as mean_absolute_error, - mean_squared_error as mean_squared_error, - mean_squared_log_error as mean_squared_log_error, - binary_logistic_loss as binary_logistic_loss, - multiclass_logistic_loss as multiclass_logistic_loss, - sigmoid_binary_cross_entropy as sigmoid_binary_cross_entropy, - softmax_cross_entropy as softmax_cross_entropy, - log_cosh_loss as log_cosh_loss, - ctc_loss_with_forward_probs as ctc_loss_with_forward_probs, - ctc_loss as ctc_loss, - multi_margin_loss as multi_margin_loss, + cross_entropy_loss as cross_entropy_loss, + cross_entropy_sparse as cross_entropy_sparse, + cross_entropy_sigmoid as cross_entropy_sigmoid, + nll_loss, + l1_loss as l1_loss, + l2_loss as l2_loss, + huber_loss as 
huber_loss, + mean_absolute_error as mean_absolute_error, + mean_squared_error as mean_squared_error, + mean_squared_log_error as mean_squared_log_error, + binary_logistic_loss as binary_logistic_loss, + multiclass_logistic_loss as multiclass_logistic_loss, + sigmoid_binary_cross_entropy as sigmoid_binary_cross_entropy, + softmax_cross_entropy as softmax_cross_entropy, + log_cosh_loss as log_cosh_loss, + ctc_loss_with_forward_probs as ctc_loss_with_forward_probs, + ctc_loss as ctc_loss, + multi_margin_loss as multi_margin_loss, ) from brainpy._src.losses.comparison import ( - CrossEntropyLoss, - NLLLoss, - L1Loss, - MAELoss, - MSELoss, + CrossEntropyLoss, + NLLLoss, + L1Loss, + MAELoss, + MSELoss, ) from brainpy._src.losses.regularization import ( - l2_norm as l2_norm, - mean_absolute as mean_absolute, - mean_square as mean_square, - log_cosh as log_cosh, - smooth_labels as smooth_labels, + l2_norm as l2_norm, + mean_absolute as mean_absolute, + mean_square as mean_square, + log_cosh as log_cosh, + smooth_labels as smooth_labels, ) - diff --git a/brainpy/measure.py b/brainpy/measure.py index 97eda269d..4e8b91c6a 100644 --- a/brainpy/measure.py +++ b/brainpy/measure.py @@ -1,19 +1,18 @@ # -*- coding: utf-8 -*- from brainpy._src.measure.correlation import ( - cross_correlation as cross_correlation, - voltage_fluctuation as voltage_fluctuation, - matrix_correlation as matrix_correlation, - weighted_correlation as weighted_correlation, - functional_connectivity as functional_connectivity, + cross_correlation as cross_correlation, + voltage_fluctuation as voltage_fluctuation, + matrix_correlation as matrix_correlation, + weighted_correlation as weighted_correlation, + functional_connectivity as functional_connectivity, ) from brainpy._src.measure.firings import ( - raster_plot as raster_plot, - firing_rate as firing_rate, + raster_plot as raster_plot, + firing_rate as firing_rate, ) from brainpy._src.measure.lfp import ( - unitary_LFP as unitary_LFP, + unitary_LFP as 
unitary_LFP, ) - diff --git a/brainpy/mixin.py b/brainpy/mixin.py index 3787e3cf5..137b17670 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -1,14 +1,13 @@ - from brainpy._src.mixin import ( - MixIn as MixIn, - AlignPost as AlignPost, - ParamDesc as ParamDesc, - ParamDescriber as ParamDescriber, - BindCondData as BindCondData, - Container as Container, - TreeNode as TreeNode, - JointType as JointType, - SupportAutoDelay as SupportAutoDelay, - SupportInputProj as SupportInputProj, - SupportSTDP as SupportSTDP, + MixIn as MixIn, + AlignPost as AlignPost, + ParamDesc as ParamDesc, + ParamDescriber as ParamDescriber, + BindCondData as BindCondData, + Container as Container, + TreeNode as TreeNode, + JointType as JointType, + SupportAutoDelay as SupportAutoDelay, + SupportInputProj as SupportInputProj, + SupportSTDP as SupportSTDP, ) diff --git a/brainpy/neurons.py b/brainpy/neurons.py index afb2dbadc..35722d2f2 100644 --- a/brainpy/neurons.py +++ b/brainpy/neurons.py @@ -5,38 +5,38 @@ """ from brainpy._src.dynold.neurons.biological_models import ( - HH as HH, - MorrisLecar as MorrisLecar, - PinskyRinzelModel as PinskyRinzelModel, - WangBuzsakiModel as WangBuzsakiModel, + HH as HH, + MorrisLecar as MorrisLecar, + PinskyRinzelModel as PinskyRinzelModel, + WangBuzsakiModel as WangBuzsakiModel, ) from brainpy._src.dynold.neurons.fractional_models import ( - FractionalNeuron as FractionalNeuron, - FractionalFHR as FractionalFHR, - FractionalIzhikevich as FractionalIzhikevich, + FractionalNeuron as FractionalNeuron, + FractionalFHR as FractionalFHR, + FractionalIzhikevich as FractionalIzhikevich, ) from brainpy._src.dynold.neurons.reduced_models import ( - LeakyIntegrator as LeakyIntegrator, - LIF as LIF, - ExpIF as ExpIF, - AdExIF as AdExIF, - QuaIF as QuaIF, - AdQuaIF as AdQuaIF, - GIF as GIF, - ALIFBellec2020 as ALIFBellec2020, - Izhikevich as Izhikevich, - HindmarshRose as HindmarshRose, - FHN as FHN, - LIF_SFA_Bellec2020, + LeakyIntegrator as LeakyIntegrator, + 
LIF as LIF, + ExpIF as ExpIF, + AdExIF as AdExIF, + QuaIF as QuaIF, + AdQuaIF as AdQuaIF, + GIF as GIF, + ALIFBellec2020 as ALIFBellec2020, + Izhikevich as Izhikevich, + HindmarshRose as HindmarshRose, + FHN as FHN, + LIF_SFA_Bellec2020, ) from .dyn.others import ( - InputGroup as InputGroup, - OutputGroup as OutputGroup, - SpikeTimeGroup as SpikeTimeGroup, - PoissonGroup as PoissonGroup, - Leaky as Leaky, - Integrator as Integrator, - OUProcess as OUProcess, + InputGroup as InputGroup, + OutputGroup as OutputGroup, + SpikeTimeGroup as SpikeTimeGroup, + PoissonGroup as PoissonGroup, + Leaky as Leaky, + Integrator as Integrator, + OUProcess as OUProcess, ) diff --git a/brainpy/optim.py b/brainpy/optim.py index de66e3700..d0c893086 100644 --- a/brainpy/optim.py +++ b/brainpy/optim.py @@ -6,46 +6,44 @@ from brainpy._src.optimizers.optimizer import ( - Optimizer as Optimizer, + Optimizer as Optimizer, ) from brainpy._src.optimizers.optimizer import ( - SGD as SGD, - Momentum as Momentum, - MomentumNesterov as MomentumNesterov, - Adagrad as Adagrad, - Adadelta as Adadelta, - RMSProp as RMSProp, - Adam as Adam, - LARS as LARS, - Adan as Adan, - AdamW as AdamW, + SGD as SGD, + Momentum as Momentum, + MomentumNesterov as MomentumNesterov, + Adagrad as Adagrad, + Adadelta as Adadelta, + RMSProp as RMSProp, + Adam as Adam, + LARS as LARS, + Adan as Adan, + AdamW as AdamW, ) - # schedulers # # ---------- # from brainpy._src.optimizers.scheduler import ( - make_schedule as make_schedule, - Scheduler as Scheduler, + make_schedule as make_schedule, + Scheduler as Scheduler, ) from brainpy._src.optimizers.scheduler import ( - Constant as Constant, - ExponentialDecay as ExponentialDecay, - InverseTimeDecay as InverseTimeDecay, - PolynomialDecay as PolynomialDecay, - PiecewiseConstant as PiecewiseConstant, - PiecewiseConstantLR as PiecewiseConstantLR, - PolynomialDecayLR as PolynomialDecayLR, - InverseTimeDecayLR as InverseTimeDecayLR, - ExponentialDecayLR as ExponentialDecayLR + 
Constant as Constant, + ExponentialDecay as ExponentialDecay, + InverseTimeDecay as InverseTimeDecay, + PolynomialDecay as PolynomialDecay, + PiecewiseConstant as PiecewiseConstant, + PiecewiseConstantLR as PiecewiseConstantLR, + PolynomialDecayLR as PolynomialDecayLR, + InverseTimeDecayLR as InverseTimeDecayLR, + ExponentialDecayLR as ExponentialDecayLR ) from brainpy._src.optimizers.scheduler import ( - StepLR as StepLR, - MultiStepLR as MultiStepLR, - ExponentialLR as ExponentialLR, - CosineAnnealingLR as CosineAnnealingLR, - CosineAnnealingWarmRestarts as CosineAnnealingWarmRestarts, + StepLR as StepLR, + MultiStepLR as MultiStepLR, + ExponentialLR as ExponentialLR, + CosineAnnealingLR as CosineAnnealingLR, + CosineAnnealingWarmRestarts as CosineAnnealingWarmRestarts, ) - diff --git a/brainpy/running.py b/brainpy/running.py index f11e0340c..406bdf7ee 100644 --- a/brainpy/running.py +++ b/brainpy/running.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- from brainpy._src.running.jax_multiprocessing import ( - jax_vectorize_map as jax_vectorize_map, - jax_parallelize_map as jax_parallelize_map, + jax_vectorize_map as jax_vectorize_map, + jax_parallelize_map as jax_parallelize_map, ) from brainpy._src.running.native_multiprocessing import ( - process_pool as process_pool, - process_pool_lock as process_pool_lock, + process_pool as process_pool, + process_pool_lock as process_pool_lock, ) from brainpy._src.running.pathos_multiprocessing import ( - cpu_ordered_parallel as cpu_ordered_parallel, - cpu_unordered_parallel as cpu_unordered_parallel, + cpu_ordered_parallel as cpu_ordered_parallel, + cpu_unordered_parallel as cpu_unordered_parallel, ) diff --git a/brainpy/synapses.py b/brainpy/synapses.py index 572ccfa3b..1d0e911d8 100644 --- a/brainpy/synapses.py +++ b/brainpy/synapses.py @@ -5,41 +5,39 @@ This module has been deprecated since brainpy>=2.4.0. Use ``brainpy.dyn`` module instead. 
""" - from brainpy._src.dynold.synapses.base import ( - _SynSTP as SynSTP, - _SynOut as SynOut, - TwoEndConn as TwoEndConn, + _SynSTP as SynSTP, + _SynOut as SynOut, + TwoEndConn as TwoEndConn, ) from brainpy._src.dynold.synapses.biological_models import ( - AMPA as AMPA, - GABAa as GABAa, - BioNMDA as BioNMDA, + AMPA as AMPA, + GABAa as GABAa, + BioNMDA as BioNMDA, ) from brainpy._src.dynold.synapses.abstract_models import ( - Delta as Delta, - Exponential as Exponential, - DualExponential as DualExponential, - Alpha as Alpha, - NMDA as NMDA, + Delta as Delta, + Exponential as Exponential, + DualExponential as DualExponential, + Alpha as Alpha, + NMDA as NMDA, ) from brainpy._src.dynold.synapses.compat import ( - DeltaSynapse as DeltaSynapse, - ExpCUBA as ExpCUBA, - ExpCOBA as ExpCOBA, - DualExpCUBA as DualExpCUBA, - DualExpCOBA as DualExpCOBA, - AlphaCUBA as AlphaCUBA, - AlphaCOBA as AlphaCOBA, + DeltaSynapse as DeltaSynapse, + ExpCUBA as ExpCUBA, + ExpCOBA as ExpCOBA, + DualExpCUBA as DualExpCUBA, + DualExpCOBA as DualExpCOBA, + AlphaCUBA as AlphaCUBA, + AlphaCOBA as AlphaCOBA, ) from brainpy._src.dynold.synapses.learning_rules import ( - STP as STP, + STP as STP, ) from brainpy._src.dyn.synapses.delay_couplings import ( - DiffusiveCoupling, - AdditiveCoupling, + DiffusiveCoupling, + AdditiveCoupling, ) from brainpy._src.dynold.synapses.gap_junction import ( - GapJunction + GapJunction ) - diff --git a/brainpy/synouts.py b/brainpy/synouts.py index b00bc62a3..c60ee89c9 100644 --- a/brainpy/synouts.py +++ b/brainpy/synouts.py @@ -5,12 +5,10 @@ This module has been deprecated since brainpy>=2.4.0. Use ``brainpy.dyn`` module instead. 
""" - from brainpy._src.dynold.synouts.conductances import ( - COBA as COBA, - CUBA as CUBA, + COBA as COBA, + CUBA as CUBA, ) from brainpy._src.dynold.synouts.ions import ( - MgBlock as MgBlock, + MgBlock as MgBlock, ) - diff --git a/brainpy/synplast.py b/brainpy/synplast.py index 53fd65ad3..1637062ae 100644 --- a/brainpy/synplast.py +++ b/brainpy/synplast.py @@ -5,6 +5,6 @@ """ from brainpy._src.dynold.synplast.short_term_plasticity import ( - STD as STD, - STP as STP, + STD as STD, + STP as STP, ) diff --git a/brainpy/tools.py b/brainpy/tools.py index 35e98f6d6..b4a24d4db 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -2,51 +2,48 @@ from brainpy._src.tools.codes import ( - repr_object as repr_object, - repr_dict as repr_dict, - repr_context as repr_context, - copy_doc as copy_doc, - code_lines_to_func as code_lines_to_func, - get_identifiers as get_identifiers, - indent as indent, - deindent as deindent, - word_replace as word_replace, - is_lambda_function as is_lambda_function, - get_main_code as get_main_code, - get_func_source as get_func_source, - change_func_name as change_func_name, + repr_object as repr_object, + repr_dict as repr_dict, + repr_context as repr_context, + copy_doc as copy_doc, + code_lines_to_func as code_lines_to_func, + get_identifiers as get_identifiers, + indent as indent, + deindent as deindent, + word_replace as word_replace, + is_lambda_function as is_lambda_function, + get_main_code as get_main_code, + get_func_source as get_func_source, + change_func_name as change_func_name, ) from brainpy._src.tools.dicts import ( - DotDict as DotDict, + DotDict as DotDict, ) from brainpy._src.tools.math_util import ( - format_seed as format_seed, + format_seed as format_seed, ) from brainpy._src.tools.package import ( - numba_jit as numba_jit, - numba_seed as numba_seed, - numba_range as numba_range, + numba_jit as numba_jit, + numba_seed as numba_seed, + numba_range as numba_range, ) from brainpy._src.tools.others import ( - replicate 
as replicate, - not_customized as not_customized, - to_size as to_size, - size2num as size2num, - timeout as timeout, - init_progress_bar as init_progress_bar, + replicate as replicate, + not_customized as not_customized, + to_size as to_size, + size2num as size2num, + timeout as timeout, + init_progress_bar as init_progress_bar, ) from brainpy._src.tools.install import ( - jaxlib_install_info, + jaxlib_install_info, ) from brainpy._src.tools.functions import ( - compose as compose, - pipe as pipe, + compose as compose, + pipe as pipe, ) - - - diff --git a/brainpy/types.py b/brainpy/types.py index e98b49cec..69b36aabe 100644 --- a/brainpy/types.py +++ b/brainpy/types.py @@ -1,14 +1,13 @@ # -*- coding: utf-8 -*- from brainpy._src.types import ( - ArrayType as ArrayType, - Array as Array, - Parameter as Parameter, - PyTree as PyTree, - Shape as Shape, - Initializer as Initializer, - Output as Output, - Monitor as Monitor, - Sharding as Sharding, + ArrayType as ArrayType, + Array as Array, + Parameter as Parameter, + PyTree as PyTree, + Shape as Shape, + Initializer as Initializer, + Output as Output, + Monitor as Monitor, + Sharding as Sharding, ) - diff --git a/docs/index.rst b/docs/index.rst index 3d7adcf14..180bc4220 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,13 +32,6 @@ Installation pip install -U brainpy[tpu] - .. tab-item:: Ecosystem - - .. code-block:: bash - - pip install -U BrainX - - ---- Learn more @@ -117,7 +110,7 @@ Learn more .. note:: - BrainPy now is rewritten based on [brainstate](https://github.com/chaobrain/brainstate). + BrainPy is rewritten based on `brainstate `_ since August 2025. 
diff --git a/pyproject.toml b/pyproject.toml index bcd40ff30..38cac9241 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,8 @@ dependencies = [ "tqdm", "brainstate", "brainunit", - "brainevent" + "brainevent", + "braintools", ] dynamic = ['version'] diff --git a/requirements.txt b/requirements.txt index f72ffeece..e7b897528 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,6 @@ numpy brainstate brainunit brainevent +braintools jax tqdm diff --git a/setup.py b/setup.py index b0d664eaf..2865a11e2 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.10', - install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainstate>=0.1.6', 'brainunit', 'brainevent'], + install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainstate>=0.1.6', 'brainunit', 'brainevent', 'braintools'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", @@ -38,9 +38,11 @@ "Source Code": "https://github.com/brainpy/BrainPy", }, extras_require={ - 'cpu': ['jax[cpu]', 'brainstate[cpu]', 'brainunit[cpu]', 'brainevent[cpu]'], - 'cuda12': ['jax[cuda12]', 'brainstate[cuda12]', 'brainunit[cuda12]', 'brainevent[cuda12]'], - 'tpu': ['jax[tpu]', 'brainstate[tpu]', 'brainunit[tpu]', 'brainevent[tpu]'], + 'cpu': ['jax[cpu]', 'brainstate[cpu]', 'brainunit[cpu]', 'brainevent[cpu]', 'braintools[cpu]'], + 'cuda12': [ + 'jax[cuda12]', 'brainstate[cuda12]', 'brainunit[cuda12]', 'brainevent[cuda12]', 'braintools[cuda12]' + ], + 'tpu': ['jax[tpu]', 'brainstate[tpu]', 'brainunit[tpu]', 'brainevent[tpu]', 'braintools[tpu]'], }, keywords=('computational neuroscience, ' 'brain-inspired computation, '