-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinit.py
More file actions
131 lines (117 loc) · 4.53 KB
/
init.py
File metadata and controls
131 lines (117 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import brainunit as u
import jax.numpy as jnp
import numpy as np
from brainstate import environ, random
from brainstate.typing import ArrayLike, SeedOrKey, DTypeLike
def _compute_fans(shape, in_axis=-2, out_axis=-1):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
class Orthogonal:
    """Initializer that produces (semi-)orthogonal weight tensors.

    The requested shape is viewed as a 2-D matrix with ``shape[axis]`` rows
    and the product of the remaining dimensions as columns. A Gaussian matrix
    is QR-decomposed, and its orthonormal factor is reshaped back to ``shape``
    and multiplied by ``scale``.

    Args:
        scale: Multiplier applied to the orthogonal matrix.
        axis: Axis of ``shape`` that indexes the rows of the matrix view.
        seed: Seed or key passed to ``random.default_rng`` to build this
            initializer's generator.
        unit: Unit attached to the returned values via ``u.Quantity``.
    """

    def __init__(
        self,
        scale: ArrayLike = 1.,
        axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__()
        self.scale = scale
        self.axis = axis
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, dtype: DTypeLike = None, ):
        """Sample an orthogonal tensor.

        Args:
            shape: Shape of the tensor to produce.
            dtype: Floating dtype of the samples; ``None`` selects
                ``environ.dftype()``.

        Returns:
            The scaled orthogonal tensor wrapped as
            ``u.Quantity(..., unit=self.unit)`` and passed through
            ``u.maybe_decimal``.
        """
        # ``dtype or default`` would discard an explicit ``np.dtype``: NumPy
        # dtypes define ``__len__`` (0 for non-structured dtypes) and can
        # therefore be falsy. Only fall back when no dtype was given.
        if dtype is None:
            dtype = environ.dftype()
        n_rows = shape[self.axis]
        n_cols = np.prod(shape) // n_rows
        # Reduced QR of a tall (or square) matrix yields orthonormal columns;
        # sample the transposed shape for wide matrices and transpose back.
        matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
        norm_dst = self.rng.normal(size=matrix_shape, dtype=dtype)
        q_mat, r_mat = jnp.linalg.qr(norm_dst)
        # Fix QR's per-column sign ambiguity so Q is uniformly distributed.
        q_mat *= jnp.sign(jnp.diag(r_mat))
        if n_rows < n_cols:
            q_mat = q_mat.T
        # The matrix rows become the leading axis, followed by the remaining
        # dimensions in their original order; then move rows to ``self.axis``.
        q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
        q_mat = jnp.moveaxis(q_mat, 0, self.axis)
        r = jnp.asarray(self.scale, dtype=dtype) * q_mat
        return u.maybe_decimal(u.Quantity(r, unit=self.unit))
class VarianceScaling:
    """Initializer whose sample variance scales inversely with the fans.

    Samples are drawn with variance ``scale / n``, where ``n`` depends on
    ``mode``:

    - ``"fan_in"``: the fan-in computed from ``in_axis``,
    - ``"fan_out"``: the fan-out computed from ``out_axis``,
    - ``"fan_avg"``: the mean of the two.

    Args:
        scale: Numerator of the target variance.
        mode: One of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``.
        distribution: One of ``"truncated_normal"``, ``"normal"``,
            ``"uniform"``.
        in_axis: Axis of the requested shape holding the input dimension.
        out_axis: Axis of the requested shape holding the output dimension.
        seed: Seed or key passed to ``random.default_rng`` to build this
            initializer's generator.
        unit: Unit attached to the returned values via ``u.Quantity``.
    """

    def __init__(
        self,
        scale: ArrayLike,
        mode: str,
        distribution: str,
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        assert mode in ['fan_in', 'fan_out', 'fan_avg'], (
            f"invalid mode for variance scaling initializer: {mode}"
        )
        assert distribution in ['truncated_normal', 'normal', 'uniform'], (
            f"invalid distribution for variance scaling initializer: {distribution}"
        )
        self.scale = scale
        self.mode = mode
        self.in_axis = in_axis
        self.out_axis = out_axis
        self.distribution = distribution
        self.rng = random.default_rng(seed)
        self.unit = unit

    def __call__(self, shape, dtype: DTypeLike = None, ):
        """Sample a tensor of the given shape.

        Args:
            shape: Shape of the tensor to produce.
            dtype: Floating dtype of the samples; ``None`` selects
                ``environ.dftype()``.

        Returns:
            The samples wrapped as ``u.Quantity(..., unit=self.unit)`` and
            passed through ``u.maybe_decimal``.

        Raises:
            ValueError: If ``mode`` or ``distribution`` is invalid. This is
                only reachable when the constructor asserts are stripped,
                e.g. under ``python -O``.
        """
        # ``dtype or default`` would discard an explicit ``np.dtype``: NumPy
        # dtypes define ``__len__`` (0 for non-structured dtypes) and can
        # therefore be falsy. Only fall back when no dtype was given.
        if dtype is None:
            dtype = environ.dftype()
        fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
        if self.mode == "fan_in":
            denominator = fan_in
        elif self.mode == "fan_out":
            denominator = fan_out
        elif self.mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
        variance = (self.scale / denominator).astype(dtype)
        if self.distribution == "truncated_normal":
            # .8796... is the standard deviation of a standard normal truncated
            # to [-2, 2]; dividing by it restores the target variance.
            stddev = (jnp.sqrt(variance) / .87962566103423978).astype(dtype)
            res = self.rng.truncated_normal(-2, 2, shape, dtype=dtype) * stddev
        elif self.distribution == "normal":
            res = self.rng.randn(*shape, dtype=dtype) * jnp.sqrt(variance).astype(dtype)
        elif self.distribution == "uniform":
            # U(-1, 1) has variance 1/3, so scaling by sqrt(3 * variance)
            # yields the target variance.
            res = (
                self.rng.uniform(low=-1, high=1, size=shape, dtype=dtype) *
                jnp.sqrt(3 * variance).astype(dtype)
            )
        else:
            raise ValueError("invalid distribution for variance scaling initializer")
        return u.maybe_decimal(u.Quantity(res, unit=self.unit))
class KaimingUniform(VarianceScaling):
    """Kaiming (He) uniform initializer.

    A :class:`VarianceScaling` preset whose defaults are ``scale=2.0``,
    ``mode="fan_in"`` and ``distribution="uniform"``. Every argument is
    forwarded unchanged to :class:`VarianceScaling`, so any default can be
    overridden.
    """

    def __init__(
        self,
        scale: float = 2.0,
        mode: str = "fan_in",
        distribution: str = "uniform",
        in_axis: int = -2,
        out_axis: int = -1,
        seed: SeedOrKey = None,
        unit: u.Unit = u.UNITLESS,
    ):
        super().__init__(
            scale=scale,
            mode=mode,
            distribution=distribution,
            in_axis=in_axis,
            out_axis=out_axis,
            seed=seed,
            unit=unit,
        )