diff --git a/micrograd/nn.py b/micrograd/nn.py index 30d5d777..d612cd51 100644 --- a/micrograd/nn.py +++ b/micrograd/nn.py @@ -16,8 +16,10 @@ def __init__(self, nin, nonlin=True): self.w = [Value(random.uniform(-1,1)) for _ in range(nin)] self.b = Value(0) self.nonlin = nonlin + self.nin = nin def __call__(self, x): + assert len(x) == self.nin, "Shape mismatch between input and given nin value" act = sum((wi*xi for wi,xi in zip(self.w, x)), self.b) return act.relu() if self.nonlin else act