├── .gitignore
├── README.md
├── docs
│   ├── graph.html
│   ├── main.html
│   ├── ops.html
│   ├── public
│   │   ├── fonts
│   │   │   ├── aller-bold.eot
│   │   │   ├── aller-bold.ttf
│   │   │   ├── aller-bold.woff
│   │   │   ├── aller-light.eot
│   │   │   ├── aller-light.ttf
│   │   │   ├── aller-light.woff
│   │   │   ├── roboto-black.eot
│   │   │   ├── roboto-black.ttf
│   │   │   └── roboto-black.woff
│   │   └── stylesheets
│   │       └── normalize.css
│   ├── pycco.css
│   ├── session.html
│   ├── tensor.html
│   └── tf_test.html
├── graph.py
├── main.py
├── ops.py
├── session.py
├── tensor.py
├── tests
│   ├── __init__.py
│   ├── test_gradients.py
│   └── test_ops.py
└── tf_test.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.env/
*.pyc
.ipynb_checkpoints/
data/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Implementing (parts of) TensorFlow (almost) from Scratch
## A Walkthrough of Symbolic Differentiation

This [literate programming](https://en.wikipedia.org/wiki/Literate_programming)
exercise will construct a simple 2-layer feed-forward neural network to compute
the [exclusive or](https://en.wikipedia.org/wiki/Exclusive_or), using [symbolic
differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) to
compute the gradients automatically. In total, about 500 lines of code,
including comments. The only functional dependency is numpy. I highly recommend
reading Chris Olah's [Calculus on Computational Graphs:
Backpropagation](http://colah.github.io/posts/2015-08-Backprop/) for more
background on what this code is doing.

--------------------------------------------------------------------------------
/docs/graph.html:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from tensor import Tensor
from ops import AddOp, SubOp, MulOp, DivOp, \
               DotOp, TransposeOp, SquareOp, NegOp, \
               MeanOp, SigmoidOp, AssignOp, GroupOp
`Graph` represents a computation to be evaluated by a `Session`. With the
exception of `Graph#tensor`, `Graph#convert`, and `Graph#gradients`, most
methods simply create an operation and return the output tensor of the
operation.

class Graph(object):
The `tensor` method defines a new tensor with the given initial value
and operation.

    def tensor(self, initial_value=None, op=None):
        return Tensor(initial_value=initial_value, graph=self, op=op)
The `convert` method returns the given value if it is already a `Tensor`,
otherwise it converts the value to one.

    def convert(self, value):
        if isinstance(value, Tensor):
            return value
        return self.tensor(initial_value=value)
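`convert` is what lets the operation classes accept plain Python numbers and
numpy arrays alongside tensors. A minimal sketch of the behavior (illustrative
only, using the `Graph` defined above):

    import numpy as np
    from graph import Graph

    graph = Graph()
    a = graph.tensor(initial_value=np.array(2.0))
    assert graph.convert(a) is a          # already a Tensor: returned unchanged
    b = graph.convert(3.0)                # raw value: wrapped in a new Tensor
    assert b.initial_value == 3.0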
The `gradients` method performs backpropagation using reverse accumulation and
the chain rule. It traverses the graph from `y` to each `x` in `xs`,
accumulating gradients, and returning the partial gradients for each of the
`xs`. We use a queue to keep track of the next tensor for which to compute the
gradient and keep a dictionary of the gradients computed thus far. Iteration
starts from the target output `y` with an output gradient of 1.

    def gradients(self, y, xs):
        queue = []
        queue.append((y, 1))

        grads = {}
        while len(queue) > 0:
            y, grad_y = queue.pop(0)
            grad_y = self.convert(grad_y)

            gradients = y.op.gradient(grad_y)
            assert len(gradients) == len(y.op.inputs)

            for tensor, gradient in zip(y.op.inputs, gradients):
                if tensor in grads:
                    grads[tensor] += gradient
                else:
                    grads[tensor] = gradient

                if tensor.op:
                    queue.append((tensor, gradient))

        return [grads[x] for x in xs]
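To make the traversal concrete, here is a small end-to-end sketch (illustrative
only, not part of the library) that differentiates y = x^2 + 3x at x = 2, where
the exact derivative is 2x + 3 = 7:

    import numpy as np
    from graph import Graph
    from session import Session

    graph = Graph()
    x = graph.tensor(np.array(2.0))
    y = graph.square(x) + 3 * x        # y = x^2 + 3x, built from SquareOp, MulOp, AddOp

    dy_dx, = graph.gradients(y, [x])   # symbolic: the gradient is itself a tensor in the graph
    sess = Session(graph)
    y_val, dy_dx_val = sess.run([y, dy_dx])
    assert np.isclose(y_val, 10.0) and np.isclose(dy_dx_val, 7.0)

Note that the returned gradient is just another tensor; nothing is evaluated
until a `Session` runs it.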
Each operation method defines a new operation with the provided input tensors
and returns the operation's output.

    def add(self, a, b):
        op = AddOp([a, b], graph=self)
        return op.output

    def sub(self, a, b):
        op = SubOp([a, b], graph=self)
        return op.output

    def mul(self, a, b):
        op = MulOp([a, b], graph=self)
        return op.output

    def div(self, a, b):
        op = DivOp([a, b], graph=self)
        return op.output

    def neg(self, x):
        op = NegOp([x], graph=self)
        return op.output

    def square(self, x):
        op = SquareOp([x], graph=self)
        return op.output

    def sigmoid(self, x):
        op = SigmoidOp([x], graph=self)
        return op.output

    def dot(self, a, b):
        op = DotOp([a, b], graph=self)
        return op.output

    def transpose(self, x):
        op = TransposeOp([x], graph=self)
        return op.output

    def mean(self, x):
        op = MeanOp([x], graph=self)
        return op.output

    def assign(self, a, b):
        op = AssignOp([a, b], graph=self)
        return op.output

    def group(self, inputs):
        op = GroupOp(inputs, graph=self)
        return op.output

--------------------------------------------------------------------------------
/docs/main.html:
--------------------------------------------------------------------------------
This literate programming exercise will construct a simple 2-layer
feed-forward neural network to compute the exclusive or, using symbolic
differentiation to compute the gradients automatically. In total, about 500
lines of code, including comments. The only functional dependency is numpy. I
highly recommend reading Chris Olah's Calculus on Computational Graphs:
Backpropagation for more background on what this code is doing.

The XOR task is convenient for a number of reasons: it's very fast to compute;
it is not linearly separable, thus requiring at least two layers and making the
gradient calculation more interesting; and it doesn't require more complicated
matrix-matrix features such as broadcasting.

(I'm also working on a more involved example for MNIST, but as soon as I added
support for matrix shapes and broadcasting the code ballooned by 5x and it was
no longer a simple example.)
Let's start by going over the architecture. We're going to use four main
components:

- `Graph`, composed of `Tensor` nodes and `Op` nodes that together represent
  the computation we want to differentiate.
- `Tensor` represents a value in the graph. Tensors keep a reference to the
  operation that produced them, if any.
- `BaseOp` represents a computation to perform and its differentiable
  components. Operations hold references to their input tensors and an output
  tensor.
- `Session` is used to evaluate tensors in the graph.

Note the return from a graph operation is actually a tensor, representing the
output of the operation.
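Before diving into the code, here is a minimal sketch (illustrative only) of
how the four components fit together for a trivial computation:

    import numpy as np
    from graph import Graph
    from session import Session

    graph = Graph()                      # holds the tensors and ops
    a = graph.tensor(np.array(2.0))      # a Tensor with an initial value
    b = graph.tensor(np.array(5.0))
    c = graph.add(a, b)                  # an AddOp is created; `c` is its output tensor
    assert c.op is not None              # the tensor remembers the op that produced it

    sess = Session(graph)                # a Session evaluates tensors
    result, = sess.run([c])
    assert np.isclose(result, 7.0)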
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np
np.random.seed(67)

from tqdm import trange

from graph import Graph
from session import Session
93 |def main():
Define a new graph
117 | graph = Graph()
Initialize the training data (XOR truth table)
129 | X = graph.tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]))
132 | y = graph.tensor(np.array([[0, 1, 1, 0]]))
Initialize the model's parameters (weights for each layer)
142 | weights0 = graph.tensor(np.random.normal(size=(2, 4)))
145 | weights1 = graph.tensor(np.random.normal(size=(4, 1)))
Define the model's activations
155 | activations0 = graph.sigmoid(graph.dot(X, weights0))
158 | activations1 = graph.sigmoid(graph.dot(activations0, weights1))
Define operation for computing the loss 168 | (mean squared error)
169 | loss_op = graph.mean(graph.square(graph.transpose(y) - activations1))
Define operations for the gradients w.r.t. the loss and an update 181 | operation to apply the gradients to the model's parameters.
182 | parameters = [weights0, weights1]
185 | gradients = graph.gradients(loss_op, parameters)
186 |
187 | update_op = graph.group([
188 | graph.assign(param, param - grad) \
189 | for param, grad in zip(parameters, gradients)
190 | ])
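Note that `param - grad` applies the raw gradient, i.e. an implicit learning
rate of 1.0, which happens to be enough for this tiny problem. A variant with
an explicit step size might look like the following sketch (the
`learning_rate` name is only for illustration):

    learning_rate = 0.1
    update_op = graph.group([
        graph.assign(param, param - learning_rate * grad)
        for param, grad in zip(parameters, gradients)
    ])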
Begin training... We iterate for a number of epochs, calling the session run
method each time to compute the update operation and the current loss. The
progress bar's description is updated to display the loss.

    sess = Session(graph)

    with trange(10000) as pbar_epoch:
        for _ in pbar_epoch:
            _, loss = sess.run([update_op, loss_op])
            pbar_epoch.set_description('loss: {:.8f}'.format(loss))

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/docs/ops.html:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np
`BaseOp` represents an operation that performs computation on tensors.
Every operation consists of the following:

- A list of `inputs`, each converted to ensure they're all tensors.
- An `output` tensor to represent the result of the operation (whose value may
  be `None`).
- A reference to the `graph`, so that each operation can create further
  operations when constructing its gradients.

class BaseOp(object):

    def __init__(self, inputs, graph):
        self.inputs = [graph.convert(input_) for input_ in inputs]
        self.output = graph.tensor(op=self)
        self.graph = graph
The `compute` method receives as input the evaluated input tensors and returns
the result of performing its operation on the inputs.

    def compute(self, sess, *args):
        raise NotImplementedError()
The `gradient` method computes the partial derivative w.r.t. each input to the
operation. (Most of the derivatives come from Wikipedia.)

    def gradient(self, grad):
        raise NotImplementedError()
`AddOp` adds a tensor to another tensor. Uses the sum rule to compute the
partial derivatives.

class AddOp(BaseOp):

    def compute(self, sess, a, b):
        return a + b

    def gradient(self, grad):
        return [grad, grad]
`SubOp` subtracts a tensor from another tensor. Also uses the sum rule to
compute the partial derivatives.

class SubOp(BaseOp):

    def compute(self, sess, a, b):
        return a - b

    def gradient(self, grad):
        return [grad, -grad]
`MulOp` multiplies a tensor by another tensor. Uses the product rule to compute
the partial derivatives.

class MulOp(BaseOp):

    def compute(self, sess, a, b):
        return a * b

    def gradient(self, grad):
        a, b = self.inputs
        return [grad * b, grad * a]
`DivOp` divides a tensor by another tensor. Uses the quotient rule to compute
the partial derivatives.

class DivOp(BaseOp):

    def compute(self, sess, a, b):
        return a / b

    def gradient(self, grad):
        a, b = self.inputs
        return [grad / b, grad * (-a / self.graph.square(b))]
`NegOp` negates a tensor.

class NegOp(BaseOp):

    def compute(self, sess, x):
        return -x

    def gradient(self, grad):
        return [-grad]
`DotOp` computes the dot product between two tensors. Uses the product rule to
compute the partial derivatives. Note that here we need to transpose the terms
and perform a dot product, assuming matrices rather than scalars.

class DotOp(BaseOp):

    def compute(self, sess, a, b):
        return np.dot(a, b)

    def gradient(self, grad):
        a, b = self.inputs
        return [
            self.graph.dot(grad, self.graph.transpose(b)),
            self.graph.dot(self.graph.transpose(a), grad),
        ]
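Those two expressions are the standard matrix-calculus results: for
C = np.dot(A, B), the input gradients are np.dot(G, B.T) and np.dot(A.T, G),
where G is the upstream gradient w.r.t. C. A standalone numpy sanity check
(illustrative only, not part of the library), using the loss sum(np.dot(A, B))
so that G is all ones:

    import numpy as np

    A = np.random.randn(3, 4)
    B = np.random.randn(4, 2)
    G = np.ones((3, 2))                        # upstream gradient of sum(np.dot(A, B))

    grad_A = np.dot(G, B.T)                    # what DotOp returns for input `a`
    grad_B = np.dot(A.T, G)                    # what DotOp returns for input `b`

    # Finite-difference check on the (0, 0) entry of A.
    eps = 1e-6
    A_eps = A.copy()
    A_eps[0, 0] += eps
    numeric = (np.sum(np.dot(A_eps, B)) - np.sum(np.dot(A, B))) / eps
    assert np.allclose(numeric, grad_A[0, 0], atol=1e-4)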
`SquareOp` squares a tensor.

class SquareOp(BaseOp):

    def compute(self, sess, x):
        return np.square(x)

    def gradient(self, grad):
        x = self.inputs[0]
        return [grad * (2 * x)]
`TransposeOp` transposes a tensor.

class TransposeOp(BaseOp):

    def compute(self, sess, x):
        return np.transpose(x)

    def gradient(self, grad):
        return [self.graph.transpose(grad)]
`SigmoidOp` implements the sigmoid function and its derivative. Notice that the
derivative uses the output of the operation, which saves recomputation.

class SigmoidOp(BaseOp):

    def compute(self, sess, x):
        return 1 / (1 + np.exp(-x))

    def gradient(self, grad):
        y = self.output
        return [grad * (y * (1 - y))]
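The identity used here, sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), is what
lets the gradient reuse the already-computed output instead of recomputing the
exponential. A quick standalone numpy check of that identity (not part of the
library):

    import numpy as np

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    x = np.linspace(-5, 5, 11)
    y = sigmoid(x)

    eps = 1e-6
    numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2 * eps)  # central difference
    assert np.allclose(numeric, y * (1 - y), atol=1e-6)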
`MeanOp` computes the mean of a tensor. Note the gradient here is intentionally
incorrect because computing it requires knowing the shape of the input and
output tensors. Fortunately, gradients are fairly malleable in optimization.

class MeanOp(BaseOp):

    def compute(self, sess, x):
        return np.mean(x)

    def gradient(self, grad):
        return [grad]
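For reference, the exact gradient of a mean spreads the upstream gradient
evenly over the input: each element contributes 1/N to the output, so each
element receives grad/N. A standalone numpy check of that fact (illustrative
only, not how the library computes it):

    import numpy as np

    x = np.random.randn(4, 1)
    upstream = 1.0                                     # gradient of the loss w.r.t. mean(x)
    exact_grad = upstream * np.ones_like(x) / x.size   # each element gets 1/N

    eps = 1e-6
    numeric = np.zeros_like(x)
    for i in range(x.shape[0]):
        bumped = x.copy()
        bumped[i, 0] += eps
        numeric[i, 0] = (np.mean(bumped) - np.mean(x)) / eps

    assert np.allclose(numeric, exact_grad, atol=1e-4)

In the XOR example the mean is taken over four scalar errors, so the missing
1/4 factor only rescales the effective step size.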
`GroupOp` exploits the fact that each input to the operation is automatically
evaluated before computing the operation's output, allowing us to group
together the evaluation of multiple operations. Its input gradients come from
simply broadcasting the output gradient.

class GroupOp(BaseOp):

    def compute(self, sess, *args):
        return None

    def gradient(self, grad):
        return [grad] * len(self.inputs)
`AssignOp` updates the session's current state for a tensor. It is not
differentiable in this implementation.

class AssignOp(BaseOp):

    def compute(self, sess, a, b):
        assert a.shape == b.shape, \
            'shapes must match to assign: {} != {}' \
            .format(a.shape, b.shape)
        sess.state[self.inputs[0]] = b
        return b

--------------------------------------------------------------------------------
/docs/session.html:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np
`Session` performs computation on a graph.

class Session(object):

Initializing a session with a graph and a state dictionary to hold tensor
values.

    def __init__(self, graph):
        self.graph = graph
        self.state = {}
`run_op` takes as input an operation to run and a context used to fetch
pre-evaluated tensors.

    def run_op(self, op, context):
        args = [self.eval_tensor(tensor, context) for tensor in op.inputs]
        return op.compute(self, *args)
`eval_tensor` takes as input a tensor to evaluate and a context to fetch
pre-evaluated tensors. If the tensor is not already in the context there are
three possibilities for evaluating the tensor:

- The tensor was produced by an operation: run the operation and cache its
  result in the context.
- The tensor already has a value in the session state: use that value.
- The tensor has an initial value: store it in the session state and use it.

    def eval_tensor(self, tensor, context):
        if tensor not in context:
            if tensor.op is not None:
                context[tensor] = self.run_op(tensor.op, context)
            elif tensor in self.state and self.state[tensor] is not None:
                context[tensor] = self.state[tensor]
            elif tensor not in self.state and tensor.initial_value is not None:
                context[tensor] = self.state[tensor] = tensor.initial_value

        return context[tensor]
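Because results are cached in `context`, a tensor that is used in several
places is only evaluated once per `run` call. A small standalone sketch of
that behavior (illustrative only):

    import numpy as np
    from graph import Graph
    from session import Session

    graph = Graph()
    x = graph.tensor(np.array(3.0))
    shared = graph.square(x)              # used twice below
    y = graph.add(shared, shared)

    sess = Session(graph)
    result, = sess.run([y])
    assert np.isclose(result, 18.0)       # `shared` was computed once and reused from the context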
`run` takes a list of tensors to evaluate and a feed dictionary that can be
used to override tensors.

    def run(self, tensors, feed_dict=None):
        context = {}

        if feed_dict:
            context.update(feed_dict)

        return [self.eval_tensor(tensor, context) for tensor in tensors]
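Overriding a tensor through `feed_dict` simply pre-populates the context, so
anything downstream of it is recomputed from the fed value. For example
(illustrative only, not part of the library):

    import numpy as np
    from graph import Graph
    from session import Session

    graph = Graph()
    x = graph.tensor(np.array(2.0))
    y = graph.square(x)

    sess = Session(graph)
    assert np.isclose(sess.run([y])[0], 4.0)
    assert np.isclose(sess.run([y], feed_dict={x: np.array(5.0)})[0], 25.0)  # x overridden for this call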

--------------------------------------------------------------------------------
/docs/tensor.html:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np
`Tensor` represents a value in the graph. It's just a data container with
methods for operator overloading (each of which delegates to the graph). It
includes:

- An `initial_value` for the tensor, if any.
- The `op` that produced the tensor, if any.
- A reference to the `graph` the tensor belongs to.

class Tensor(object):

    def __init__(self, initial_value, op, graph):
        self.initial_value = initial_value
        self.graph = graph
        self.op = op
    def __add__(self, other):
        return self.graph.add(self, other)

    def __sub__(self, other):
        return self.graph.sub(self, other)

    def __mul__(self, other):
        return self.graph.mul(self, other)

    def __truediv__(self, other):
        return self.graph.div(self, other)

    def __neg__(self):
        return self.graph.neg(self)

    def __radd__(self, other):
        return self.graph.add(other, self)

    def __rsub__(self, other):
        return self.graph.sub(other, self)

    def __rmul__(self, other):
        return self.graph.mul(other, self)

    def __rtruediv__(self, other):
        return self.graph.div(other, self)
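These overloads are why expressions like `graph.transpose(y) - activations1`
in main.py read naturally: the arithmetic operators build new ops on the
tensor's graph and return their output tensors. A quick standalone sketch
(illustrative only):

    import numpy as np
    from graph import Graph
    from session import Session

    graph = Graph()
    x = graph.tensor(np.array(4.0))
    y = 1 - x / 2              # __truediv__ and __rsub__ build DivOp and SubOp nodes

    sess = Session(graph)
    result, = sess.run([y])
    assert np.isclose(result, -1.0)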

--------------------------------------------------------------------------------
/docs/tf_test.html:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np
np.random.seed(67)

import tensorflow as tf

from tqdm import trange
def main():
    # XOR truth table inputs and targets, mirroring main.py.
    X = tf.constant([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.float32)
    y = tf.constant([[0, 1, 1, 0]], dtype=tf.float32)

    # Weights for the two layers, drawn from the same seeded normal as main.py.
    weights0 = tf.Variable(np.random.normal(size=(2, 4)), dtype=tf.float32)
    weights1 = tf.Variable(np.random.normal(size=(4, 1)), dtype=tf.float32)

    # Two sigmoid layers.
    activations0 = tf.sigmoid(tf.matmul(X, weights0))
    activations1 = tf.sigmoid(tf.matmul(activations0, weights1))

    # Mean squared error loss.
    loss_op = tf.reduce_mean(tf.square(tf.transpose(y) - activations1))

    # Gradients and a grouped parameter update, as in main.py.
    parameters = [weights0, weights1]
    gradients = tf.gradients(loss_op, parameters)

    update_op = tf.group(*[
        tf.assign(param, param - grad) \
        for param, grad in zip(parameters, gradients)
    ])

    tf.set_random_seed(67)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        with trange(10000) as pbar_epoch:
            for _ in pbar_epoch:
                _, loss = sess.run([update_op, loss_op])
                pbar_epoch.set_description('loss: {:.8f}'.format(loss))

if __name__ == '__main__':
    main()