diff --git a/backprop_animated.svg b/backprop_animated.svg
new file mode 100644
index 00000000..2947138a
--- /dev/null
+++ b/backprop_animated.svg
@@ -0,0 +1,18320 @@
+
+
\ No newline at end of file
diff --git a/micrograd/engine.py b/micrograd/engine.py
index afd82cc5..b9970f66 100644
--- a/micrograd/engine.py
+++ b/micrograd/engine.py
@@ -1,57 +1,73 @@
+class Operator:
+ """ labels the operation that produced a Value; `name` is the text drawn for the op node """
+ def __init__(self, name):
+ self.name = name
+
class Value:
""" stores a single scalar value and its gradient """
- def __init__(self, data, _children=(), _op=''):
+ def __init__(self, data, _children=(), _op=None):
+ # _op is now an Operator instance (None when no op is given) instead of a string
self.data = data
self.grad = 0
# internal variables used for autograd graph construction
- self._backward = lambda: None
+ # default no-op accepts the same (root, callback) arguments that backward() passes to every node
+ self._backward = lambda root=None,callback=None: None
self._prev = set(_children)
self._op = _op # the op that produced this node, for graphviz / debugging / etc
def __add__(self, other):
+ # a fresh Operator is created per call, so every result node carries its own op object
+ operator = Operator('+')
other = other if isinstance(other, Value) else Value(other)
- out = Value(self.data + other.data, (self, other), '+')
+ out = Value(self.data + other.data, (self, other), operator)
- def _backward():
+ def _backward(root=None, callback=None):
self.grad += out.grad
other.grad += out.grad
+ # once the gradients are accumulated, report this step: the operands, out's op object, and root
+ if callback:
+ callback([self,other], out._op, root)
out._backward = _backward
return out
def __mul__(self, other):
+ # a fresh Operator is created per call, so every result node carries its own op object
+ operator = Operator('*')
other = other if isinstance(other, Value) else Value(other)
- out = Value(self.data * other.data, (self, other), '*')
+ out = Value(self.data * other.data, (self, other), operator)
- def _backward():
+ def _backward(root=None, callback=None):
self.grad += other.data * out.grad
other.grad += self.data * out.grad
+ # once the gradients are accumulated, report this step: the operands, out's op object, and root
+ if callback:
+ callback([self,other], out._op, root)
out._backward = _backward
return out
def __pow__(self, other):
+ # NOTE(review): the Operator is built before the type assert below; the f-string
+ # accepts any object, so an unsupported power still fails at the assert
+ operator = Operator(f'**{other}')
assert isinstance(other, (int, float)), "only supporting int/float powers for now"
- out = Value(self.data**other, (self,), f'**{other}')
+ out = Value(self.data**other, (self,), operator)
- def _backward():
+ def _backward(root=None, callback=None):
self.grad += (other * self.data**(other-1)) * out.grad
+ # once the gradient is accumulated, report this step: the single operand, out's op object, and root
+ if callback:
+ callback([self], out._op, root)
out._backward = _backward
return out
def relu(self):
- out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU')
+ # a fresh Operator is created per call, so every result node carries its own op object
+ operator = Operator('ReLU')
+ out = Value(0 if self.data < 0 else self.data, (self,), operator)
- def _backward():
+ def _backward(root=None,callback=None):
self.grad += (out.data > 0) * out.grad
+ # once the gradient is accumulated, report this step: the single operand, out's op object, and root
+ if callback:
+ callback([self], out._op, root)
out._backward = _backward
return out
- def backward(self):
+ def backward(self, callback=None):
# topological order all of the children in the graph
topo = []
@@ -66,8 +82,9 @@ def build_topo(v):
# go one variable at a time and apply the chain rule to get its gradient
self.grad = 1
+ root=self
for v in reversed(topo):
- v._backward()
+ v._backward(root, callback)
def __neg__(self): # -self
return self * -1
diff --git a/trace_graph.ipynb b/trace_graph.ipynb
index 055cf341..69d75a1e 100644
--- a/trace_graph.ipynb
+++ b/trace_graph.ipynb
@@ -2,27 +2,33 @@
"cells": [
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/bin/bash: /home/kostas/anaconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
+ "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
+ "Requirement already satisfied: graphviz in /home/kostas/anaconda3/lib/python3.11/site-packages (0.21)\n",
+ "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
+ "Requirement already satisfied: graphviz in /home/kostas/anaconda3/lib/python3.11/site-packages (0.21)\n",
+ "Note: you may need to restart the kernel to use updated packages.\n",
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
"source": [
"# brew install graphviz\n",
- "# pip install graphviz\n",
+ "%pip install graphviz\n",
+ "%reset -f\n",
"from graphviz import Digraph"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from micrograd.engine import Value"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -49,28 +55,1133 @@
" for n in nodes:\n",
" dot.node(name=str(id(n)), label = \"{ data %.4f | grad %.4f }\" % (n.data, n.grad), shape='record')\n",
" if n._op:\n",
- " dot.node(name=str(id(n)) + n._op, label=n._op)\n",
- " dot.edge(str(id(n)) + n._op, str(id(n)))\n",
+ " dot.node(name=str(id(n)) + n._op.name, label=n._op.name)\n",
+ " dot.edge(str(id(n)) + n._op.name, str(id(n)))\n",
" \n",
" for n1, n2 in edges:\n",
- " dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n",
+ " dot.edge(str(id(n1)), str(id(n2)) + n2._op.name)\n",
" \n",
" return dot"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
+ "from micrograd.engine import Value\n",
+ "\n",
"# a very simple example\n",
"x = Value(1.0)\n",
+ "z=x+3\n",
"y = (x * 2 + 1).relu()\n",
+ "draw_dot(y)\n",
"y.backward()\n",
"draw_dot(y)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.display import Image\n",
+ "\n",
+ "def draw_dot_highlight(root, highlight=None, parent_op=None, format='svg', rankdir='LR'):\n",
+ "    \"\"\"\n",
+ "    Draw the computation graph traced from `root`, optionally highlighting one step.\n",
+ "\n",
+ "    root: Value whose graph is collected with trace().\n",
+ "    highlight: collection of Values to fill yellow; None (the default) highlights nothing.\n",
+ "    parent_op: Operator object to draw in red, together with its outgoing edge and the\n",
+ "        incoming edges from highlighted nodes; None highlights no op.\n",
+ "    format, rankdir: passed through to the Digraph constructor.\n",
+ "    Returns the Digraph.\n",
+ "    \"\"\"\n",
+ "    # Normalise once: the edge loop used to evaluate `n1 in highlight` unguarded,\n",
+ "    # which raised TypeError whenever highlight was left at its None default.\n",
+ "    highlight = () if highlight is None else highlight\n",
+ "    nodes, edges = trace(root)\n",
+ "    dot = Digraph(format=format, graph_attr={'rankdir': rankdir, 'margin': '0.5'})\n",
+ "    for n in nodes:\n",
+ "        color = 'yellow' if n in highlight else 'white'\n",
+ "        dot.node(name=str(id(n)), label = \"{ data %.4f | grad %.4f }\" % (n.data, n.grad), shape='record',\n",
+ "                 style='filled', fillcolor=color)\n",
+ "        if n._op:\n",
+ "            op_id = str(id(n)) + n._op.name\n",
+ "            # ops are matched by identity: every operation call creates its own Operator\n",
+ "            if n._op is parent_op:\n",
+ "                dot.node(name=op_id, label=n._op.name, style='filled', fillcolor='red')\n",
+ "                dot.edge(op_id, str(id(n)), color='red', penwidth='2')\n",
+ "            else:\n",
+ "                dot.node(name=op_id, label=n._op.name)\n",
+ "                dot.edge(op_id, str(id(n)))\n",
+ "    for n1, n2 in edges:\n",
+ "        op_id = str(id(n2)) + n2._op.name\n",
+ "        if n1 in highlight and n2._op is parent_op:\n",
+ "            dot.edge(str(id(n1)), op_id, color='red', penwidth='2')\n",
+ "        else:\n",
+ "            dot.edge(str(id(n1)), op_id)\n",
+ "    return dot\n",
+ "\n",
+ "# Animate backward pass and save as SVG animation\n",
+ "frames = []\n",
+ "def collect_frame(nodes, parent_op, root):\n",
+ "    \"\"\"Render one highlighted frame and append its SVG text to `frames`.\n",
+ "\n",
+ "    Takes the same (nodes, op, root) arguments that Value's _backward closures\n",
+ "    pass to their callback.\n",
+ "    \"\"\"\n",
+ "    rendered = draw_dot_highlight(root, highlight=nodes, parent_op=parent_op).pipe(format='svg')\n",
+ "    if not rendered:\n",
+ "        return  # nothing was rendered, so no frame is recorded\n",
+ "    # frames are kept as text so they can be concatenated into a single SVG\n",
+ "    frames.append(rendered.decode('utf-8'))\n",
+ "\n",
+ "def save_animation(output_filename, frames):\n",
+ " if not frames:\n",
+ " print('No frames generated. Check if backward pass is working and callback is called.')\n",
+ " else:\n",
+ " # Save animation as a multi-frame SVG (using groups and for simple effect)\n",
+ " svg_header = '\\n'\n",
+ " svg_frames = [f[ f.find('