From 86f2ed7f76140a721843d8d186c67b0cb4a41921 Mon Sep 17 00:00:00 2001 From: aladinsane77 Date: Sun, 21 Sep 2025 18:07:44 +0300 Subject: [PATCH] added support for animated backprop graph --- backprop_animated.svg | 18320 ++++++++++++++++++++++++++++++++++++++++ micrograd/engine.py | 41 +- trace_graph.ipynb | 1151 ++- 3 files changed, 19480 insertions(+), 32 deletions(-) create mode 100644 backprop_animated.svg diff --git a/backprop_animated.svg b/backprop_animated.svg new file mode 100644 index 00000000..2947138a --- /dev/null +++ b/backprop_animated.svg @@ -0,0 +1,18320 @@ + + + +%3 + + + +132783679983696 + +data -1.0000 + +grad 1.0000 + + + +132783679994960+ + ++ + + + +132783679983696->132783679994960+ + + + + + +132783679983696+ + ++ + + + +132783679983696+->132783679983696 + + + + + +132783679983760 + +data -1.0000 + +grad 0.0000 + + + +132783679984336* + +* + + + +132783679983760->132783679984336* + + + + + +132783679979728 + +data -2.0000 + +grad 0.0000 + + + +132783679992272+ + ++ + + + +132783679979728->132783679992272+ + + + + + +132783679993168+ + ++ + + + +132783679979728->132783679993168+ + + + + + +132783679979728+ + ++ + + + +132783679979728+->132783679979728 + + + + + +132783679985936 + +data 6.0000 + +grad 0.0000 + + + +132783679980560+ + ++ + + + +132783679985936->132783679980560+ + + + + + +132783679985936ReLU + +ReLU + + + +132783679985936ReLU->132783679985936 + + + + + +132783679613200 + +data -4.0000 + +grad 0.0000 + + + +132783679613200->132783679979728+ + + + + + +132783679613200->132783679984336* + + + + + +132783679992592* + +* + + + +132783679613200->132783679992592* + + + + + +132783679980432* + +* + + + +132783679613200->132783679980432* + + + + + +132783679982992+ + ++ + + + +132783679613200->132783679982992+ + + + + + +132783679986064 + +data 6.0000 + +grad 0.0000 + + + +132783679986064->132783679985936ReLU + + + + + +132783679986064+ + ++ + + + +132783679986064+->132783679986064 + + + + + +132783679992208 + +data 0.0000 + +grad 0.0000 + + + +132783679990416+ + ++ + + + +132783679992208->132783679990416+ + + + + + +132783679982224* + +* + + + +132783679992208->132783679982224* + + + + + +132783679992208+ + ++ + + + +132783679992208+->132783679992208 + + + + + +132783679992272 + +data -1.0000 + +grad 0.0000 + + + +132783679992272->132783679993168+ + + + + + +132783679992272+->132783679992272 + + + + + +132783679990416 + +data 0.0000 + +grad 0.0000 + + + +132783679995664+ + ++ + + + +132783679990416->132783679995664+ + + + + + +132783679983440* + +* + + + +132783679990416->132783679983440* + + + + + +132783679990416+->132783679990416 + + + + + +132783679982224 + +data 0.0000 + +grad 0.0000 + + + +132783679983504+ + ++ + + + +132783679982224->132783679983504+ + + + + + +132783679982224*->132783679982224 + + + + + +132783679984336 + +data 4.0000 + +grad 0.0000 + + + +132783679984336->132783679986064+ + + + + + +132783679984336*->132783679984336 + + + + + +132783679992592 + +data -8.0000 + +grad 0.0000 + + + +132783679992592->132783679992208+ + + + + + +132783679992592*->132783679992592 + + + + + +132783679980368 + +data -1.0000 + +grad 0.0000 + + + +132783679980368->132783679980432* + + + + + +132783679980432 + +data 4.0000 + +grad 0.0000 + + + +132783679984656+ + ++ + + + +132783679980432->132783679984656+ + + + + + +132783679980432*->132783679980432 + + + + + +132783679992720 + +data 1.0000 + +grad 0.0000 + + + +132783679993040+ + ++ + + + +132783679992720->132783679993040+ + + + + + +132783679980560 + +data 6.0000 + +grad 0.0000 + + + +132783679980560->132783679995664+ + + + + + +132783679980560+->132783679980560 + + + + + +132783679613968 + +data 2.0000 + +grad 0.0000 + + + +132783679613968->132783679979728+ + + + + + +132783679613968->132783679986064+ + + + + + +132783679613968->132783679992592* + + + + + +132783679613968->132783679982992+ + + + + + +132783679991760**3 + +**3 + + + +132783679613968->132783679991760**3 + + + + + +132783679982608 + +data 3.0000 + +grad 0.0000 + + + +132783679982608->132783679983440* + + + + + +132783679984656 + +data 2.0000 + +grad 0.0000 + + + +132783679984656->132783679983696+ + + + + + +132783679984656+->132783679984656 + + + + + +132783679994960 + +data -7.0000 + +grad 1.0000 + + + +132783679994960+->132783679994960 + + + + + +132783679993040 + +data -2.0000 + +grad 0.0000 + + + +132783679993040->132783679984656+ + + + + + +132783679993040+->132783679993040 + + + + + +132783679991696 + +data -1.0000 + +grad 0.0000 + + + +132783679995792* + +* + + + +132783679991696->132783679995792* + + + + + +132783679993168 + +data -3.0000 + +grad 0.0000 + + + +132783679993168->132783679983696+ + + + + + +132783679993168->132783679993040+ + + + + + +132783679993168+->132783679993168 + + + + + +132783679982992 + +data -2.0000 + +grad 0.0000 + + + +132783679983120ReLU + +ReLU + + + +132783679982992->132783679983120ReLU + + + + + +132783679982992+->132783679982992 + + + + + +132783679993232 + +data 1.0000 + +grad 0.0000 + + + +132783679993232->132783679992272+ + + + + + +132783679995792 + +data -6.0000 + +grad 1.0000 + + + +132783679995792->132783679994960+ + + + + + +132783679995792*->132783679995792 + + + + + +132783679983120 + +data 0.0000 + +grad 0.0000 + + + +132783679983120->132783679983504+ + + + + + +132783679983120ReLU->132783679983120 + + + + + +132783679995664 + +data 6.0000 + +grad 0.0000 + + + +132783679995664->132783679995792* + + + + + +132783679995664+->132783679995664 + + + + + +132783679983376 + +data 2.0000 + +grad 0.0000 + + + +132783679983376->132783679982224* + + + + + +132783679983440 + +data 0.0000 + +grad 0.0000 + + + +132783679983440->132783679980560+ + + + + + +132783679983440*->132783679983440 + + + + + +132783679983504 + +data 0.0000 + +grad 0.0000 + + + +132783679983504->132783679990416+ + + + + + +132783679983504+->132783679983504 + + + + + +132783679991760 + +data 8.0000 + +grad 0.0000 + + + +132783679991760->132783679992208+ + + + + + +132783679991760**3->132783679991760 + + + + + + + + + + + + \ No newline at end of file diff --git a/micrograd/engine.py b/micrograd/engine.py index afd82cc5..b9970f66 100644 --- a/micrograd/engine.py +++ b/micrograd/engine.py @@ -1,57 +1,73 @@ +class Operator: + def __init__(self, name): + self.name = name + class Value: """ stores a single scalar value and its gradient """ - def __init__(self, data, _children=(), _op=''): + def __init__(self, data, _children=(), _op=None): self.data = data self.grad = 0 # internal variables used for autograd graph construction - self._backward = lambda: None + self._backward = lambda root=None,callback=None: None self._prev = set(_children) self._op = _op # the op that produced this node, for graphviz / debugging / etc def __add__(self, other): + operator = Operator('+') other = other if isinstance(other, Value) else Value(other) - out = Value(self.data + other.data, (self, other), '+') + out = Value(self.data + other.data, (self, other), operator) - def _backward(): + def _backward(root=None, callback=None): self.grad += out.grad other.grad += out.grad + if callback: + callback([self,other], out._op, root) out._backward = _backward return out def __mul__(self, other): + operator = Operator('*') other = other if isinstance(other, Value) else Value(other) - out = Value(self.data * other.data, (self, other), '*') + out = Value(self.data * other.data, (self, other), operator) - def _backward(): + def _backward(root=None, callback=None): self.grad += other.data * out.grad other.grad += self.data * out.grad + if callback: + callback([self,other], out._op, root) out._backward = _backward return out def __pow__(self, other): + operator = Operator(f'**{other}') assert isinstance(other, (int, float)), "only supporting int/float powers for now" - out = Value(self.data**other, (self,), f'**{other}') + out = Value(self.data**other, (self,), operator) - def _backward(): + def _backward(root=None, callback=None): self.grad += (other * self.data**(other-1)) * out.grad + if callback: + callback([self], out._op, root) out._backward = _backward return out def relu(self): - out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU') + operator = Operator('ReLU') + out = Value(0 if self.data < 0 else self.data, (self,), operator) - def _backward(): + def _backward(root=None,callback=None): self.grad += (out.data > 0) * out.grad + if callback: + callback([self], out._op, root) out._backward = _backward return out - def backward(self): + def backward(self, callback=None): # topological order all of the children in the graph topo = [] @@ -66,8 +82,9 @@ def build_topo(v): # go one variable at a time and apply the chain rule to get its gradient self.grad = 1 + root=self for v in reversed(topo): - v._backward() + v._backward(root, callback) def __neg__(self): # -self return self * -1 diff --git a/trace_graph.ipynb b/trace_graph.ipynb index 055cf341..69d75a1e 100644 --- a/trace_graph.ipynb +++ b/trace_graph.ipynb @@ -2,27 +2,33 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/bin/bash: /home/kostas/anaconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n", + "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", + "Requirement already satisfied: graphviz in /home/kostas/anaconda3/lib/python3.11/site-packages (0.21)\n", + "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", + "Requirement already satisfied: graphviz in /home/kostas/anaconda3/lib/python3.11/site-packages (0.21)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "# brew install graphviz\n", - "# pip install graphviz\n", + "%pip install graphviz\n", + "%reset -f\n", "from graphviz import Digraph" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from micrograd.engine import Value" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -49,28 +55,1133 @@ " for n in nodes:\n", " dot.node(name=str(id(n)), label = \"{ data %.4f | grad %.4f }\" % (n.data, n.grad), shape='record')\n", " if n._op:\n", - " dot.node(name=str(id(n)) + n._op, label=n._op)\n", - " dot.edge(str(id(n)) + n._op, str(id(n)))\n", + " dot.node(name=str(id(n)) + n._op.name, label=n._op.name)\n", + " dot.edge(str(id(n)) + n._op.name, str(id(n)))\n", " \n", " for n1, n2 in edges:\n", - " dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n", + " dot.edge(str(id(n1)), str(id(n2)) + n2._op.name)\n", " \n", " return dot" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "132783679917072\n", + "\n", + "data 2.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679924688*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "132783679917072->132783679924688*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679924688\n", + "\n", + "data 2.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679902160+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679924688->132783679902160+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679924688*->132783679924688\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679902032\n", + "\n", + "data 3.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679902032ReLU\n", + "\n", + "ReLU\n", + "\n", + "\n", + "\n", + "132783679902032ReLU->132783679902032\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679926096\n", + "\n", + "data 1.0000\n", + "\n", + "grad 2.0000\n", + "\n", + "\n", + "\n", + "132783679926096->132783679924688*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679902160\n", + "\n", + "data 3.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679902160->132783679902032ReLU\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679902160+->132783679902160\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679963088\n", + "\n", + "data 1.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679963088->132783679902160+\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "from micrograd.engine import Value\n", + "\n", "# a very simple example\n", "x = Value(1.0)\n", + "z=x+3\n", "y = (x * 2 + 1).relu()\n", + "draw_dot(y)\n", "y.backward()\n", "draw_dot(y)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image\n", + "\n", + "def draw_dot_highlight(root, highlight=None, parent_op=None, format='svg', rankdir='LR'):\n", + " \"\"\"\n", + " Draws the computation graph, highlighting the node if provided.\n", + " \"\"\"\n", + " nodes, edges = trace(root)\n", + " dot = Digraph(format=format, graph_attr={'rankdir': rankdir, 'margin': '0.5'})\n", + " for n in nodes:\n", + " color = 'white'\n", + " if highlight is not None and n in highlight:\n", + " color = 'yellow'\n", + " dot.node(name=str(id(n)), label = \"{ data %.4f | grad %.4f }\" % (n.data, n.grad), shape='record', \n", + " style='filled', fillcolor=color)\n", + " if n._op:\n", + " if n._op is not parent_op:\n", + " dot.node(name=str(id(n)) + n._op.name, label=n._op.name)\n", + " dot.edge(str(id(n)) + n._op.name, str(id(n)))\n", + " else:\n", + " dot.node(name=str(id(n)) + n._op.name, label=n._op.name, style='filled', fillcolor='red')\n", + " dot.edge(str(id(n)) + n._op.name, str(id(n)), color='red', penwidth='2')\n", + " for n1, n2 in edges:\n", + " if n1 in highlight and n2._op == parent_op:\n", + " dot.edge(str(id(n1)), str(id(n2)) + n2._op.name, color='red', penwidth='2')\n", + " else:\n", + " dot.edge(str(id(n1)), str(id(n2)) + n2._op.name)\n", + " return dot\n", + "\n", + "# Animate backward pass and save as SVG animation\n", + "frames = []\n", + "def collect_frame(nodes, parent_op, root):\n", + " dot = draw_dot_highlight(root, highlight=nodes, parent_op=parent_op)\n", + " svg_bytes = dot.pipe(format='svg')\n", + " if svg_bytes:\n", + " frames.append(svg_bytes.decode('utf-8')) # decode to string for SVG concatenation\n", + "\n", + "def save_animation(output_filename, frames):\n", + " if not frames:\n", + " print('No frames generated. Check if backward pass is working and callback is called.')\n", + " else:\n", + " # Save animation as a multi-frame SVG (using groups and for simple effect)\n", + " svg_header = '\\n'\n", + " svg_frames = [f[ f.find('')+6 ] for f in frames if '' in f]\n", + "\n", + " if not svg_frames:\n", + " print('No valid SVG frames found.')\n", + " else:\n", + " # Extract width/height from first frame\n", + " import re\n", + " m = re.search(r']*width=\"([^\"]+)\"[^>]*height=\"([^\"]+)\"', svg_frames[0])\n", + " width = m.group(1) if m else '800pt'\n", + " height = m.group(2) if m else '600pt'\n", + " # Add extra height for the slider\n", + " height_num = int(float(height.replace('pt','')))\n", + " slider_height = 60\n", + " new_height = f\"{height_num + slider_height}pt\"\n", + "\n", + " # Compose SVG with all frames as elements, only one visible at a time\n", + " svg_out = [svg_header, f'']\n", + " for i, frame in enumerate(svg_frames):\n", + " # Remove outer tags\n", + " content = re.sub(r']*>|', '', frame, flags=re.DOTALL)\n", + " display_val = 'inline' if i == 0 else 'none'\n", + " svg_out.append(f'{content}')\n", + "\n", + " # Add slider and JS for speed control at the bottom\n", + " svg_out.append(f'''\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " ''')\n", + " svg_out.append('')\n", + " with open(output_filename, 'w') as f:\n", + " f.write(''.join(svg_out))\n", + "\n", + " print(f'SVG animation saved as {output_filename}')\n", + " # display image using url and specific width set at 500\n", + " # Image(url=output_filename, width=500)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SVG animation saved as backprop_animated.svg\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "132783679983696\n", + "\n", + "data -1.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679994960+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679983696->132783679994960+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983696+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679983696+->132783679983696\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983760\n", + "\n", + "data -1.0000\n", + "\n", + "grad 4.0000\n", + "\n", + "\n", + "\n", + "132783679984336*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "132783679983760->132783679984336*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679979728\n", + "\n", + "data -2.0000\n", + "\n", + "grad 4.0000\n", + "\n", + "\n", + "\n", + "132783679992272+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679979728->132783679992272+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679993168+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679979728->132783679993168+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679979728+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679979728+->132783679979728\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679985936\n", + "\n", + "data 6.0000\n", + "\n", + "grad -1.0000\n", + "\n", + "\n", + "\n", + "132783679980560+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679985936->132783679980560+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679985936ReLU\n", + "\n", + "ReLU\n", + "\n", + "\n", + "\n", + "132783679985936ReLU->132783679985936\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679613200\n", + "\n", + "data -4.0000\n", + "\n", + "grad -20.0000\n", + "\n", + "\n", + "\n", + "132783679613200->132783679979728+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679613200->132783679984336*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992592*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "132783679613200->132783679992592*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679980432*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "132783679613200->132783679980432*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679982992+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679613200->132783679982992+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679986064\n", + "\n", + "data 6.0000\n", + "\n", + "grad -1.0000\n", + "\n", + "\n", + "\n", + "132783679986064->132783679985936ReLU\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679986064+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679986064+->132783679986064\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992208\n", + "\n", + "data 0.0000\n", + "\n", + "grad -12.0000\n", + "\n", + "\n", + "\n", + "132783679990416+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679992208->132783679990416+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679982224*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "132783679992208->132783679982224*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992208+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679992208+->132783679992208\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992272\n", + "\n", + "data -1.0000\n", + "\n", + "grad 2.0000\n", + "\n", + "\n", + "\n", + "132783679992272->132783679993168+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992272+->132783679992272\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679990416\n", + "\n", + "data 0.0000\n", + "\n", + "grad -4.0000\n", + "\n", + "\n", + "\n", + "132783679995664+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679990416->132783679995664+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983440*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "132783679990416->132783679983440*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679990416+->132783679990416\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679982224\n", + "\n", + "data 0.0000\n", + "\n", + "grad -4.0000\n", + "\n", + "\n", + "\n", + "132783679983504+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679982224->132783679983504+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679982224*->132783679982224\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679984336\n", + "\n", + "data 4.0000\n", + "\n", + "grad -1.0000\n", + "\n", + "\n", + "\n", + "132783679984336->132783679986064+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679984336*->132783679984336\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992592\n", + "\n", + "data -8.0000\n", + "\n", + "grad -12.0000\n", + "\n", + "\n", + "\n", + "132783679992592->132783679992208+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992592*->132783679992592\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679980368\n", + "\n", + "data -1.0000\n", + "\n", + "grad -4.0000\n", + "\n", + "\n", + "\n", + "132783679980368->132783679980432*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679980432\n", + "\n", + "data 4.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679984656+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679980432->132783679984656+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679980432*->132783679980432\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679992720\n", + "\n", + "data 1.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679993040+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "132783679992720->132783679993040+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679980560\n", + "\n", + "data 6.0000\n", + "\n", + "grad -1.0000\n", + "\n", + "\n", + "\n", + "132783679980560->132783679995664+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679980560+->132783679980560\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679613968\n", + "\n", + "data 2.0000\n", + "\n", + "grad -93.0000\n", + "\n", + "\n", + "\n", + "132783679613968->132783679979728+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679613968->132783679986064+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679613968->132783679992592*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679613968->132783679982992+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679991760**3\n", + "\n", + "**3\n", + "\n", + "\n", + "\n", + "132783679613968->132783679991760**3\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679982608\n", + "\n", + "data 3.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "132783679982608->132783679983440*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679984656\n", + "\n", + "data 2.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679984656->132783679983696+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679984656+->132783679984656\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679994960\n", + "\n", + "data -7.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679994960+->132783679994960\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679993040\n", + "\n", + "data -2.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679993040->132783679984656+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679993040+->132783679993040\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679991696\n", + "\n", + "data -1.0000\n", + "\n", + "grad 6.0000\n", + "\n", + "\n", + "\n", + "132783679995792*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "132783679991696->132783679995792*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679993168\n", + "\n", + "data -3.0000\n", + "\n", + "grad 2.0000\n", + "\n", + "\n", + "\n", + "132783679993168->132783679983696+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679993168->132783679993040+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679993168+->132783679993168\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679982992\n", + "\n", + "data -2.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "132783679983120ReLU\n", + "\n", + "ReLU\n", + "\n", + "\n", + "\n", + "132783679982992->132783679983120ReLU\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679982992+->132783679982992\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679993232\n", + "\n", + "data 1.0000\n", + "\n", + "grad 2.0000\n", + "\n", + "\n", + "\n", + "132783679993232->132783679992272+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679995792\n", + "\n", + "data -6.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "132783679995792->132783679994960+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679995792*->132783679995792\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983120\n", + "\n", + "data 0.0000\n", + "\n", + "grad -4.0000\n", + "\n", + "\n", + "\n", + "132783679983120->132783679983504+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983120ReLU->132783679983120\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679995664\n", + "\n", + "data 6.0000\n", + "\n", + "grad -1.0000\n", + "\n", + "\n", + "\n", + "132783679995664->132783679995792*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679995664+->132783679995664\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983376\n", + "\n", + "data 2.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "132783679983376->132783679982224*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983440\n", + "\n", + "data 0.0000\n", + "\n", + "grad -1.0000\n", + "\n", + "\n", + "\n", + "132783679983440->132783679980560+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983440*->132783679983440\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983504\n", + "\n", + "data 0.0000\n", + "\n", + "grad -4.0000\n", + "\n", + "\n", + "\n", + "132783679983504->132783679990416+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679983504+->132783679983504\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679991760\n", + "\n", + "data 8.0000\n", + "\n", + "grad -12.0000\n", + "\n", + "\n", + "\n", + "132783679991760->132783679992208+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "132783679991760**3->132783679991760\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = Value(-4.0)\n", + "b = Value(2.0)\n", + "c = a + b\n", + "d = a * b + b**3\n", + "c += c + 1\n", + "c += 1 + c + (-a)\n", + "d += d * 2 + (b + a).relu()\n", + "d += 3 * d + (b - a).relu()\n", + "e = c - d\n", + "\n", + "frames.clear()\n", + "e.backward(callback=collect_frame)\n", + "\n", + "save_animation('backprop_animated.svg', frames)\n", + "draw_dot(e)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -103,7 +1214,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "base", "language": "python", "name": "python3" }, @@ -117,7 +1228,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.11.4" } }, "nbformat": 4,