├── .gitignore
├── LICENSE
├── README.md
├── README_raw.md
├── autograd
│   ├── tf_two_layer_net.py
│   ├── two_layer_net_autograd.py
│   └── two_layer_net_custom_function.py
├── build_readme.py
├── nn
│   ├── dynamic_net.py
│   ├── two_layer_net_module.py
│   ├── two_layer_net_nn.py
│   └── two_layer_net_optim.py
└── tensor
    ├── two_layer_net_numpy.py
    └── two_layer_net_tensor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Justin Johnson
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This repository introduces the fundamental concepts of
2 | [PyTorch](https://github.com/pytorch/pytorch)
3 | through self-contained examples.
4 |
5 | At its core, PyTorch provides two main features:
6 | - An n-dimensional Tensor, similar to a numpy array but able to run on GPUs
7 | - Automatic differentiation for building and training neural networks
8 |
9 | We will use a fully-connected ReLU network as our running example. The network
10 | will have a single hidden layer, and will be trained with gradient descent to
11 | fit random data by minimizing the Euclidean distance between the network output
12 | and the true output.
13 |
14 | **NOTE:** These examples have been updated for PyTorch 0.4, which made several
15 | major changes to the core PyTorch API. Most notably, prior to 0.4 Tensors had
16 | to be wrapped in Variable objects to use autograd; this functionality has now
17 | been added directly to Tensors, and Variables are now deprecated.
18 |
19 | ### Table of Contents
20 | - Warm-up: numpy
21 | - PyTorch: Tensors
22 | - PyTorch: Autograd
23 | - PyTorch: Defining new autograd functions
24 | - TensorFlow: Static Graphs
25 | - PyTorch: nn
26 | - PyTorch: optim
27 | - PyTorch: Custom nn Modules
28 | - PyTorch: Control Flow and Weight Sharing
29 |
30 | ## Warm-up: numpy
31 |
32 | Before introducing PyTorch, we will first implement the network using numpy.
33 |
34 | Numpy provides an n-dimensional array object, and many functions for manipulating
35 | these arrays. Numpy is a generic framework for scientific computing; it does not
36 | know anything about computation graphs, or deep learning, or gradients. However
37 | we can easily use numpy to fit a two-layer network to random data by manually
38 | implementing the forward and backward passes through the network using numpy
39 | operations:
40 |
41 | ```python
42 | # Code in file tensor/two_layer_net_numpy.py
43 | import numpy as np
44 |
45 | # N is batch size; D_in is input dimension;
46 | # H is hidden dimension; D_out is output dimension.
47 | N, D_in, H, D_out = 64, 1000, 100, 10
48 |
49 | # Create random input and output data
50 | x = np.random.randn(N, D_in)
51 | y = np.random.randn(N, D_out)
52 |
53 | # Randomly initialize weights
54 | w1 = np.random.randn(D_in, H)
55 | w2 = np.random.randn(H, D_out)
56 |
57 | learning_rate = 1e-6
58 | for t in range(500):
59 | # Forward pass: compute predicted y
60 | h = x.dot(w1)
61 | h_relu = np.maximum(h, 0)
62 | y_pred = h_relu.dot(w2)
63 |
64 | # Compute and print loss
65 | loss = np.square(y_pred - y).sum()
66 | print(t, loss)
67 |
68 | # Backprop to compute gradients of w1 and w2 with respect to loss
69 | grad_y_pred = 2.0 * (y_pred - y)
70 | grad_w2 = h_relu.T.dot(grad_y_pred)
71 | grad_h_relu = grad_y_pred.dot(w2.T)
72 | grad_h = grad_h_relu.copy()
73 | grad_h[h < 0] = 0
74 | grad_w1 = x.T.dot(grad_h)
75 |
76 | # Update weights
77 | w1 -= learning_rate * grad_w1
78 | w2 -= learning_rate * grad_w2
79 | ```
80 |
81 | ## PyTorch: Tensors
82 |
83 | Numpy is a great framework, but it cannot utilize GPUs to accelerate its
84 | numerical computations. For modern deep neural networks, GPUs often provide
85 | speedups of [50x or greater](https://github.com/jcjohnson/cnn-benchmarks), so
86 | unfortunately numpy won't be enough for modern deep learning.
87 |
88 | Here we introduce the most fundamental PyTorch concept: the **Tensor**. A PyTorch
89 | Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional
90 | array, and PyTorch provides many functions for operating on these Tensors.
91 | Any computation you might want to perform with numpy can also be accomplished
92 | with PyTorch Tensors; you should think of them as a generic tool for scientific
93 | computing.
94 |
95 | However, unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their
96 | numeric computations. To run a PyTorch Tensor on a GPU, simply pass the
97 | appropriate `device` argument when constructing the Tensor.
98 |
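As a quick illustration (a sketch, not one of the repository files), a Tensor can be created
directly on a device or moved there afterwards with `.to()`:

```python
import torch

# Fall back to the CPU when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

a = torch.randn(3, 4, device=device)  # created directly on the chosen device
b = torch.randn(3, 4).to(device)      # created on the CPU, then moved
c = a + b                             # the computation runs on that device
print(c.device)
```
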
99 | Here we use PyTorch Tensors to fit a two-layer network to random data. Like the
100 | numpy example above we manually implement the forward and backward
101 | passes through the network, using operations on PyTorch Tensors:
102 |
103 | ```python
104 | # Code in file tensor/two_layer_net_tensor.py
105 | import torch
106 |
107 | device = torch.device('cpu')
108 | # device = torch.device('cuda') # Uncomment this to run on GPU
109 |
110 | # N is batch size; D_in is input dimension;
111 | # H is hidden dimension; D_out is output dimension.
112 | N, D_in, H, D_out = 64, 1000, 100, 10
113 |
114 | # Create random input and output data
115 | x = torch.randn(N, D_in, device=device)
116 | y = torch.randn(N, D_out, device=device)
117 |
118 | # Randomly initialize weights
119 | w1 = torch.randn(D_in, H, device=device)
120 | w2 = torch.randn(H, D_out, device=device)
121 |
122 | learning_rate = 1e-6
123 | for t in range(500):
124 | # Forward pass: compute predicted y
125 | h = x.mm(w1)
126 | h_relu = h.clamp(min=0)
127 | y_pred = h_relu.mm(w2)
128 |
129 | # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
130 | # of shape (); we can get its value as a Python number with loss.item().
131 | loss = (y_pred - y).pow(2).sum()
132 | print(t, loss.item())
133 |
134 | # Backprop to compute gradients of w1 and w2 with respect to loss
135 | grad_y_pred = 2.0 * (y_pred - y)
136 | grad_w2 = h_relu.t().mm(grad_y_pred)
137 | grad_h_relu = grad_y_pred.mm(w2.t())
138 | grad_h = grad_h_relu.clone()
139 | grad_h[h < 0] = 0
140 | grad_w1 = x.t().mm(grad_h)
141 |
142 | # Update weights using gradient descent
143 | w1 -= learning_rate * grad_w1
144 | w2 -= learning_rate * grad_w2
145 | ```
146 |
147 | ## PyTorch: Autograd
148 |
149 | In the above examples, we had to manually implement both the forward and
150 | backward passes of our neural network. Manually implementing the backward pass
151 | is not a big deal for a small two-layer network, but can quickly get very hairy
152 | for large complex networks.
153 |
154 | Thankfully, we can use
155 | [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
156 | to automate the computation of backward passes in neural networks.
157 | The **autograd** package in PyTorch provides exactly this functionality.
158 | When using autograd, the forward pass of your network will define a
159 | **computational graph**; nodes in the graph will be Tensors, and edges will be
160 | functions that produce output Tensors from input Tensors. Backpropagating through
161 | this graph then allows you to easily compute gradients.
162 |
163 | This sounds complicated, but it's pretty simple to use in practice. If we want to
164 | compute gradients with respect to some Tensor, then we set `requires_grad=True`
165 | when constructing that Tensor. Any PyTorch operations on that Tensor will cause
166 | a computational graph to be constructed, allowing us to later perform backpropagation
167 | through the graph. If `x` is a Tensor with `requires_grad=True`, then after
168 | backpropagation `x.grad` will be another Tensor holding the gradient of `x` with
169 | respect to some scalar value.
170 |
171 | Sometimes you may wish to prevent PyTorch from building computational graphs when
172 | performing certain operations on Tensors with `requires_grad=True`; for example
173 | we usually don't want to backpropagate through the weight update steps when
174 | training a neural network. In such scenarios we can use the `torch.no_grad()`
175 | context manager to prevent the construction of a computational graph.
176 |
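Before the full network example, here is a minimal sketch (not one of the repository files)
showing the mechanics of `requires_grad`, `.grad`, and `torch.no_grad()` on a tiny computation:

```python
import torch

x = torch.ones(3)
w = torch.randn(3, requires_grad=True)

loss = (w * x).sum()   # building this expression records a computational graph
loss.backward()        # populates w.grad with d(loss)/dw
print(w.grad)          # equal to x, i.e. tensor([1., 1., 1.])

with torch.no_grad():
    w -= 0.1 * w.grad  # update the weights without extending the graph
w.grad.zero_()         # clear the gradient before the next backward pass
```
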
177 | Here we use PyTorch Tensors and autograd to implement our two-layer network;
178 | now we no longer need to manually implement the backward pass through the
179 | network:
180 |
181 | ```python
182 | # Code in file autograd/two_layer_net_autograd.py
183 | import torch
184 |
185 | device = torch.device('cpu')
186 | # device = torch.device('cuda') # Uncomment this to run on GPU
187 |
188 | # N is batch size; D_in is input dimension;
189 | # H is hidden dimension; D_out is output dimension.
190 | N, D_in, H, D_out = 64, 1000, 100, 10
191 |
192 | # Create random Tensors to hold input and outputs
193 | x = torch.randn(N, D_in, device=device)
194 | y = torch.randn(N, D_out, device=device)
195 |
196 | # Create random Tensors for weights; setting requires_grad=True means that we
197 | # want to compute gradients for these Tensors during the backward pass.
198 | w1 = torch.randn(D_in, H, device=device, requires_grad=True)
199 | w2 = torch.randn(H, D_out, device=device, requires_grad=True)
200 |
201 | learning_rate = 1e-6
202 | for t in range(500):
203 | # Forward pass: compute predicted y using operations on Tensors. Since w1 and
204 | # w2 have requires_grad=True, operations involving these Tensors will cause
205 | # PyTorch to build a computational graph, allowing automatic computation of
206 | # gradients. Since we are no longer implementing the backward pass by hand we
207 | # don't need to keep references to intermediate values.
208 | y_pred = x.mm(w1).clamp(min=0).mm(w2)
209 |
210 | # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
211 | # is a Python number giving its value.
212 | loss = (y_pred - y).pow(2).sum()
213 | print(t, loss.item())
214 |
215 | # Use autograd to compute the backward pass. This call will compute the
216 | # gradient of loss with respect to all Tensors with requires_grad=True.
217 | # After this call w1.grad and w2.grad will be Tensors holding the gradient
218 | # of the loss with respect to w1 and w2 respectively.
219 | loss.backward()
220 |
221 | # Update weights using gradient descent. For this step we just want to mutate
222 | # the values of w1 and w2 in-place; we don't want to build up a computational
223 | # graph for the update steps, so we use the torch.no_grad() context manager
224 | # to prevent PyTorch from building a computational graph for the updates
225 | with torch.no_grad():
226 | w1 -= learning_rate * w1.grad
227 | w2 -= learning_rate * w2.grad
228 |
229 | # Manually zero the gradients after running the backward pass
230 | w1.grad.zero_()
231 | w2.grad.zero_()
232 | ```
233 |
234 | ## PyTorch: Defining new autograd functions
235 | Under the hood, each primitive autograd operator is really two functions that
236 | operate on Tensors. The **forward** function computes output Tensors from input
237 | Tensors. The **backward** function receives the gradient of the output Tensors
238 | with respect to some scalar value, and computes the gradient of the input Tensors
239 | with respect to that same scalar value.
240 |
241 | In PyTorch we can easily define our own autograd operator by defining a subclass
242 | of `torch.autograd.Function` and implementing the `forward` and `backward` functions.
243 | We can then use our new autograd operator by calling the `apply` method of our
244 | Function subclass, passing Tensors containing input data.
245 |
246 | In this example we define our own custom autograd function for performing the ReLU
247 | nonlinearity, and use it to implement our two-layer network:
248 |
249 | ```python
250 | # Code in file autograd/two_layer_net_custom_function.py
251 | import torch
252 |
253 | class MyReLU(torch.autograd.Function):
254 | """
255 | We can implement our own custom autograd Functions by subclassing
256 | torch.autograd.Function and implementing the forward and backward passes
257 | which operate on Tensors.
258 | """
259 | @staticmethod
260 | def forward(ctx, x):
261 | """
262 | In the forward pass we receive a context object and a Tensor containing the
263 | input; we must return a Tensor containing the output, and we can use the
264 | context object to cache objects for use in the backward pass.
265 | """
266 | ctx.save_for_backward(x)
267 | return x.clamp(min=0)
268 |
269 | @staticmethod
270 | def backward(ctx, grad_output):
271 | """
272 | In the backward pass we receive the context object and a Tensor containing
273 | the gradient of the loss with respect to the output produced during the
274 | forward pass. We can retrieve cached data from the context object, and must
275 | compute and return the gradient of the loss with respect to the input to the
276 | forward function.
277 | """
278 | x, = ctx.saved_tensors
279 | grad_x = grad_output.clone()
280 | grad_x[x < 0] = 0
281 | return grad_x
282 |
283 |
284 | device = torch.device('cpu')
285 | # device = torch.device('cuda') # Uncomment this to run on GPU
286 |
287 | # N is batch size; D_in is input dimension;
288 | # H is hidden dimension; D_out is output dimension.
289 | N, D_in, H, D_out = 64, 1000, 100, 10
290 |
291 | # Create random Tensors to hold input and output
292 | x = torch.randn(N, D_in, device=device)
293 | y = torch.randn(N, D_out, device=device)
294 |
295 | # Create random Tensors for weights.
296 | w1 = torch.randn(D_in, H, device=device, requires_grad=True)
297 | w2 = torch.randn(H, D_out, device=device, requires_grad=True)
298 |
299 | learning_rate = 1e-6
300 | for t in range(500):
301 | # Forward pass: compute predicted y using operations on Tensors; we call our
302 | # custom ReLU implementation using the MyReLU.apply function
303 | y_pred = MyReLU.apply(x.mm(w1)).mm(w2)
304 |
305 | # Compute and print loss
306 | loss = (y_pred - y).pow(2).sum()
307 | print(t, loss.item())
308 |
309 | # Use autograd to compute the backward pass.
310 | loss.backward()
311 |
312 | with torch.no_grad():
313 | # Update weights using gradient descent
314 | w1 -= learning_rate * w1.grad
315 | w2 -= learning_rate * w2.grad
316 |
317 | # Manually zero the gradients after running the backward pass
318 | w1.grad.zero_()
319 | w2.grad.zero_()
320 |
321 | ```
322 |
323 | ## TensorFlow: Static Graphs
324 | PyTorch autograd looks a lot like TensorFlow: in both frameworks we define
325 | a computational graph, and use automatic differentiation to compute gradients.
326 | The biggest difference between the two is that TensorFlow's computational graphs
327 | are **static** and PyTorch uses **dynamic** computational graphs.
328 |
329 | In TensorFlow, we define the computational graph once and then execute the same
330 | graph over and over again, possibly feeding different input data to the graph.
331 | In PyTorch, each forward pass defines a new computational graph.
332 |
333 | Static graphs are nice because you can optimize the graph up front; for example
334 | a framework might decide to fuse some graph operations for efficiency, or to
335 | come up with a strategy for distributing the graph across many GPUs or many
336 | machines. If you are reusing the same graph over and over, then this potentially
337 | costly up-front optimization can be amortized as the same graph is rerun over
338 | and over.
339 |
340 | One aspect where static and dynamic graphs differ is control flow. For some models
341 | we may wish to perform different computation for each data point; for example a
342 | recurrent network might be unrolled for different numbers of time steps for each
343 | data point; this unrolling can be implemented as a loop. With a static graph the
344 | loop construct needs to be a part of the graph; for this reason TensorFlow
345 | provides operators such as `tf.scan` for embedding loops into the graph. With
346 | dynamic graphs the situation is simpler: since we build graphs on-the-fly for
347 | each example, we can use normal imperative flow control to perform computation
348 | that differs for each input.
349 |
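For instance, with PyTorch autograd a loop whose trip count depends on the data is just
ordinary Python (a sketch, not one of the repository files):

```python
import torch

x = torch.randn(10, requires_grad=True)
y = x
while y.norm() < 100:   # ordinary Python control flow shapes the graph
    y = y * 2
loss = y.sum()
loss.backward()         # gradients flow through however many doublings ran
print(x.grad)
```
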
350 | To contrast with the PyTorch autograd example above, here we use TensorFlow to
351 | fit a simple two-layer net:
352 |
353 | ```python
354 | # Code in file autograd/tf_two_layer_net.py
355 | import tensorflow as tf
356 | import numpy as np
357 |
358 | # First we set up the computational graph:
359 |
360 | # N is batch size; D_in is input dimension;
361 | # H is hidden dimension; D_out is output dimension.
362 | N, D_in, H, D_out = 64, 1000, 100, 10
363 |
364 | # Create placeholders for the input and target data; these will be filled
365 | # with real data when we execute the graph.
366 | x = tf.placeholder(tf.float32, shape=(None, D_in))
367 | y = tf.placeholder(tf.float32, shape=(None, D_out))
368 |
369 | # Create Variables for the weights and initialize them with random data.
370 | # A TensorFlow Variable persists its value across executions of the graph.
371 | w1 = tf.Variable(tf.random_normal((D_in, H)))
372 | w2 = tf.Variable(tf.random_normal((H, D_out)))
373 |
374 | # Forward pass: Compute the predicted y using operations on TensorFlow Tensors.
375 | # Note that this code does not actually perform any numeric operations; it
376 | # merely sets up the computational graph that we will later execute.
377 | h = tf.matmul(x, w1)
378 | h_relu = tf.maximum(h, tf.zeros(1))
379 | y_pred = tf.matmul(h_relu, w2)
380 |
381 | # Compute loss using operations on TensorFlow Tensors
382 | loss = tf.reduce_sum((y - y_pred) ** 2.0)
383 |
384 | # Compute gradient of the loss with respect to w1 and w2.
385 | grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])
386 |
387 | # Update the weights using gradient descent. To actually update the weights
388 | # we need to evaluate new_w1 and new_w2 when executing the graph. Note that
389 | # in TensorFlow the act of updating the value of the weights is part of
390 | # the computational graph; in PyTorch this happens outside the computational
391 | # graph.
392 | learning_rate = 1e-6
393 | new_w1 = w1.assign(w1 - learning_rate * grad_w1)
394 | new_w2 = w2.assign(w2 - learning_rate * grad_w2)
395 |
396 | # Now we have built our computational graph, so we enter a TensorFlow session to
397 | # actually execute the graph.
398 | with tf.Session() as sess:
399 | # Run the graph once to initialize the Variables w1 and w2.
400 | sess.run(tf.global_variables_initializer())
401 |
402 | # Create numpy arrays holding the actual data for the inputs x and targets y
403 | x_value = np.random.randn(N, D_in)
404 | y_value = np.random.randn(N, D_out)
405 | for _ in range(500):
406 | # Execute the graph many times. Each time it executes we want to bind
407 | # x_value to x and y_value to y, specified with the feed_dict argument.
408 | # Each time we execute the graph we want to compute the values for loss,
409 | # new_w1, and new_w2; the values of these Tensors are returned as numpy
410 | # arrays.
411 | loss_value, _, _ = sess.run([loss, new_w1, new_w2],
412 | feed_dict={x: x_value, y: y_value})
413 | print(loss_value)
414 | ```
415 |
416 |
417 | ## PyTorch: nn
418 | Computational graphs and autograd are a very powerful paradigm for defining
419 | complex operators and automatically taking derivatives; however for large
420 | neural networks raw autograd can be a bit too low-level.
421 |
422 | When building neural networks we frequently think of arranging the computation
423 | into **layers**, some of which have **learnable parameters** which will be
424 | optimized during learning.
425 |
426 | In TensorFlow, packages like [Keras](https://github.com/fchollet/keras),
427 | [TensorFlow-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim),
428 | and [TFLearn](http://tflearn.org/) provide higher-level abstractions over
429 | raw computational graphs that are useful for building neural networks.
430 |
431 | In PyTorch, the `nn` package serves this same purpose. The `nn` package defines a set of
432 | **Modules**, which are roughly equivalent to neural network layers. A Module receives
433 | input Tensors and computes output Tensors, but may also hold internal state such as
434 | Tensors containing learnable parameters. The `nn` package also defines a set of useful
435 | loss functions that are commonly used when training neural networks.
436 |
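As a quick illustration (a sketch, not one of the repository files), a single `nn.Linear`
Module maps input Tensors to output Tensors and carries its own learnable parameters:

```python
import torch

linear = torch.nn.Linear(3, 2)  # holds a weight of shape (2, 3) and a bias of shape (2,)
x = torch.randn(4, 3)
y = linear(x)                   # Modules are called like functions; y has shape (4, 2)
print([p.shape for p in linear.parameters()])
```
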
437 | In this example we use the `nn` package to implement our two-layer network:
438 |
439 | ```python
440 | # Code in file nn/two_layer_net_nn.py
441 | import torch
442 |
443 | device = torch.device('cpu')
444 | # device = torch.device('cuda') # Uncomment this to run on GPU
445 |
446 | # N is batch size; D_in is input dimension;
447 | # H is hidden dimension; D_out is output dimension.
448 | N, D_in, H, D_out = 64, 1000, 100, 10
449 |
450 | # Create random Tensors to hold inputs and outputs
451 | x = torch.randn(N, D_in, device=device)
452 | y = torch.randn(N, D_out, device=device)
453 |
454 | # Use the nn package to define our model as a sequence of layers. nn.Sequential
455 | # is a Module which contains other Modules, and applies them in sequence to
456 | # produce its output. Each Linear Module computes output from input using a
457 | # linear function, and holds internal Tensors for its weight and bias.
458 | # After constructing the model we use the .to() method to move it to the
459 | # desired device.
460 | model = torch.nn.Sequential(
461 | torch.nn.Linear(D_in, H),
462 | torch.nn.ReLU(),
463 | torch.nn.Linear(H, D_out),
464 | ).to(device)
465 |
466 | # The nn package also contains definitions of popular loss functions; in this
467 | # case we will use Mean Squared Error (MSE) as our loss function. Setting
468 | # reduction='sum' means that we are computing the *sum* of squared errors rather
469 | # than the mean; this is for consistency with the examples above where we
470 | # manually compute the loss, but in practice it is more common to use mean
471 | # squared error as a loss by setting reduction='elementwise_mean'.
472 | loss_fn = torch.nn.MSELoss(reduction='sum')
473 |
474 | learning_rate = 1e-4
475 | for t in range(500):
476 | # Forward pass: compute predicted y by passing x to the model. Module objects
477 | # override the __call__ operator so you can call them like functions. When
478 | # doing so you pass a Tensor of input data to the Module and it produces
479 | # a Tensor of output data.
480 | y_pred = model(x)
481 |
482 | # Compute and print loss. We pass Tensors containing the predicted and true
483 | # values of y, and the loss function returns a Tensor containing the loss.
484 | loss = loss_fn(y_pred, y)
485 | print(t, loss.item())
486 |
487 | # Zero the gradients before running the backward pass.
488 | model.zero_grad()
489 |
490 | # Backward pass: compute gradient of the loss with respect to all the learnable
491 | # parameters of the model. Internally, the parameters of each Module are stored
492 | # in Tensors with requires_grad=True, so this call will compute gradients for
493 | # all learnable parameters in the model.
494 | loss.backward()
495 |
496 | # Update the weights using gradient descent. Each parameter is a Tensor, so
497 | # we can access its data and gradients like we did before.
498 | with torch.no_grad():
499 | for param in model.parameters():
500 | param.data -= learning_rate * param.grad
501 | ```
502 |
503 |
504 | ## PyTorch: optim
505 | Up to this point we have updated the weights of our models by manually mutating
506 | Tensors holding learnable parameters. This is not a huge burden
507 | for simple optimization algorithms like stochastic gradient descent, but in practice
508 | we often train neural networks using more sophisticated optimizers like AdaGrad,
509 | RMSProp, Adam, etc.
510 |
511 | The `optim` package in PyTorch abstracts the idea of an optimization algorithm and
512 | provides implementations of commonly used optimization algorithms.
513 |
514 | In this example we will use the `nn` package to define our model as before, but we
515 | will optimize the model using the Adam algorithm provided by the `optim` package:
516 |
517 | ```python
518 | # Code in file nn/two_layer_net_optim.py
519 | import torch
520 |
521 | # N is batch size; D_in is input dimension;
522 | # H is hidden dimension; D_out is output dimension.
523 | N, D_in, H, D_out = 64, 1000, 100, 10
524 |
525 | # Create random Tensors to hold inputs and outputs.
526 | x = torch.randn(N, D_in)
527 | y = torch.randn(N, D_out)
528 |
529 | # Use the nn package to define our model and loss function.
530 | model = torch.nn.Sequential(
531 | torch.nn.Linear(D_in, H),
532 | torch.nn.ReLU(),
533 | torch.nn.Linear(H, D_out),
534 | )
535 | loss_fn = torch.nn.MSELoss(reduction='sum')
536 |
537 | # Use the optim package to define an Optimizer that will update the weights of
538 | # the model for us. Here we will use Adam; the optim package contains many other
539 | # optimization algorithms. The first argument to the Adam constructor tells the
540 | # optimizer which Tensors it should update.
541 | learning_rate = 1e-4
542 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
543 | for t in range(500):
544 | # Forward pass: compute predicted y by passing x to the model.
545 | y_pred = model(x)
546 |
547 | # Compute and print loss.
548 | loss = loss_fn(y_pred, y)
549 | print(t, loss.item())
550 |
551 | # Before the backward pass, use the optimizer object to zero all of the
552 | # gradients for the Tensors it will update (which are the learnable weights
553 | # of the model)
554 | optimizer.zero_grad()
555 |
556 | # Backward pass: compute gradient of the loss with respect to model parameters
557 | loss.backward()
558 |
559 | # Calling the step function on an Optimizer makes an update to its parameters
560 | optimizer.step()
561 | ```
562 |
563 |
564 | ## PyTorch: Custom nn Modules
565 | Sometimes you will want to specify models that are more complex than a sequence of
566 | existing Modules; for these cases you can define your own Modules by subclassing
567 | `nn.Module` and defining a `forward` method which receives input Tensors and produces
568 | output Tensors using other modules or other autograd operations on Tensors.
569 |
570 | In this example we implement our two-layer network as a custom Module subclass:
571 |
572 | ```python
573 | # Code in file nn/two_layer_net_module.py
574 | import torch
575 |
576 | class TwoLayerNet(torch.nn.Module):
577 | def __init__(self, D_in, H, D_out):
578 | """
579 | In the constructor we instantiate two nn.Linear modules and assign them as
580 | member variables.
581 | """
582 | super(TwoLayerNet, self).__init__()
583 | self.linear1 = torch.nn.Linear(D_in, H)
584 | self.linear2 = torch.nn.Linear(H, D_out)
585 |
586 | def forward(self, x):
587 | """
588 | In the forward function we accept a Tensor of input data and we must return
589 | a Tensor of output data. We can use Modules defined in the constructor as
590 | well as arbitrary (differentiable) operations on Tensors.
591 | """
592 | h_relu = self.linear1(x).clamp(min=0)
593 | y_pred = self.linear2(h_relu)
594 | return y_pred
595 |
596 | # N is batch size; D_in is input dimension;
597 | # H is hidden dimension; D_out is output dimension.
598 | N, D_in, H, D_out = 64, 1000, 100, 10
599 |
600 | # Create random Tensors to hold inputs and outputs
601 | x = torch.randn(N, D_in)
602 | y = torch.randn(N, D_out)
603 |
604 | # Construct our model by instantiating the class defined above.
605 | model = TwoLayerNet(D_in, H, D_out)
606 |
607 | # Construct our loss function and an Optimizer. The call to model.parameters()
608 | # in the SGD constructor will contain the learnable parameters of the two
609 | # nn.Linear modules which are members of the model.
610 | loss_fn = torch.nn.MSELoss(reduction='sum')
611 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
612 | for t in range(500):
613 | # Forward pass: Compute predicted y by passing x to the model
614 | y_pred = model(x)
615 |
616 | # Compute and print loss
617 | loss = loss_fn(y_pred, y)
618 | print(t, loss.item())
619 |
620 | # Zero gradients, perform a backward pass, and update the weights.
621 | optimizer.zero_grad()
622 | loss.backward()
623 | optimizer.step()
624 |
625 | ```
626 |
627 |
628 | ## PyTorch: Control Flow + Weight Sharing
629 | As an example of dynamic graphs and weight sharing, we implement a very strange
630 | model: a fully-connected ReLU network that on each forward pass chooses a random
631 | number between 1 and 4 and uses that many hidden layers, reusing the same weights
632 | multiple times to compute the innermost hidden layers.
633 |
634 | For this model we can use normal Python flow control to implement the loop, and we
635 | can implement weight sharing among the innermost layers by simply reusing the
636 | same Module multiple times when defining the forward pass.
637 |
638 | We can easily implement this model as a Module subclass:
639 |
640 | ```python
641 | # Code in file nn/dynamic_net.py
642 | import random
643 | import torch
644 |
645 | class DynamicNet(torch.nn.Module):
646 | def __init__(self, D_in, H, D_out):
647 | """
648 | In the constructor we construct three nn.Linear instances that we will use
649 | in the forward pass.
650 | """
651 | super(DynamicNet, self).__init__()
652 | self.input_linear = torch.nn.Linear(D_in, H)
653 | self.middle_linear = torch.nn.Linear(H, H)
654 | self.output_linear = torch.nn.Linear(H, D_out)
655 |
656 | def forward(self, x):
657 | """
658 | For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
659 | and reuse the middle_linear Module that many times to compute hidden layer
660 | representations.
661 |
662 | Since each forward pass builds a dynamic computation graph, we can use normal
663 | Python control-flow operators like loops or conditional statements when
664 | defining the forward pass of the model.
665 |
666 | Here we also see that it is perfectly safe to reuse the same Module many
667 | times when defining a computational graph. This is a big improvement from Lua
668 | Torch, where each Module could be used only once.
669 | """
670 | h_relu = self.input_linear(x).clamp(min=0)
671 | for _ in range(random.randint(0, 3)):
672 | h_relu = self.middle_linear(h_relu).clamp(min=0)
673 | y_pred = self.output_linear(h_relu)
674 | return y_pred
675 |
676 |
677 | # N is batch size; D_in is input dimension;
678 | # H is hidden dimension; D_out is output dimension.
679 | N, D_in, H, D_out = 64, 1000, 100, 10
680 |
681 | # Create random Tensors to hold inputs and outputs.
682 | x = torch.randn(N, D_in)
683 | y = torch.randn(N, D_out)
684 |
685 | # Construct our model by instantiating the class defined above
686 | model = DynamicNet(D_in, H, D_out)
687 |
688 | # Construct our loss function and an Optimizer. Training this strange model with
689 | # vanilla stochastic gradient descent is tough, so we use momentum
690 | criterion = torch.nn.MSELoss(reduction='sum')
691 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
692 | for t in range(500):
693 | # Forward pass: Compute predicted y by passing x to the model
694 | y_pred = model(x)
695 |
696 | # Compute and print loss
697 | loss = criterion(y_pred, y)
698 | print(t, loss.item())
699 |
700 | # Zero gradients, perform a backward pass, and update the weights.
701 | optimizer.zero_grad()
702 | loss.backward()
703 | optimizer.step()
704 | ```
705 |
--------------------------------------------------------------------------------
/README_raw.md:
--------------------------------------------------------------------------------
1 | This repository introduces the fundamental concepts of
2 | [PyTorch](https://github.com/pytorch/pytorch)
3 | through self-contained examples.
4 |
5 | At its core, PyTorch provides two main features:
6 | - An n-dimensional Tensor, similar to a numpy array but able to run on GPUs
7 | - Automatic differentiation for building and training neural networks
8 |
9 | We will use a fully-connected ReLU network as our running example. The network
10 | will have a single hidden layer, and will be trained with gradient descent to
11 | fit random data by minimizing the Euclidean distance between the network output
12 | and the true output.
13 |
14 | **NOTE:** These examples have been updated for PyTorch 0.4, which made several
15 | major changes to the core PyTorch API. Most notably, prior to 0.4 Tensors had
16 | to be wrapped in Variable objects to use autograd; this functionality has now
17 | been added directly to Tensors, and Variables are now deprecated.
18 |
19 | ### Table of Contents
20 | - Warm-up: numpy
21 | - PyTorch: Tensors
22 | - PyTorch: Autograd
23 | - PyTorch: Defining new autograd functions
24 | - TensorFlow: Static Graphs
25 | - PyTorch: nn
26 | - PyTorch: optim
27 | - PyTorch: Custom nn Modules
28 | - PyTorch: Control Flow and Weight Sharing
29 |
30 | ## Warm-up: numpy
31 |
32 | Before introducing PyTorch, we will first implement the network using numpy.
33 |
34 | Numpy provides an n-dimensional array object, and many functions for manipulating
35 | these arrays. Numpy is a generic framework for scientific computing; it does not
36 | know anything about computation graphs, or deep learning, or gradients. However
37 | we can easily use numpy to fit a two-layer network to random data by manually
38 | implementing the forward and backward passes through the network using numpy
39 | operations:
40 |
41 | ```python
42 | :INCLUDE tensor/two_layer_net_numpy.py
43 | ```
44 |
45 | ## PyTorch: Tensors
46 |
47 | Numpy is a great framework, but it cannot utilize GPUs to accelerate its
48 | numerical computations. For modern deep neural networks, GPUs often provide
49 | speedups of [50x or greater](https://github.com/jcjohnson/cnn-benchmarks), so
50 | unfortunately numpy won't be enough for modern deep learning.
51 |
52 | Here we introduce the most fundamental PyTorch concept: the **Tensor**. A PyTorch
53 | Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional
54 | array, and PyTorch provides many functions for operating on these Tensors.
55 | Any computation you might want to perform with numpy can also be accomplished
56 | with PyTorch Tensors; you should think of them as a generic tool for scientific
57 | computing.
58 |
59 | However, unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their
60 | numeric computations. To run a PyTorch Tensor on a GPU, simply pass the
61 | appropriate `device` argument when constructing the Tensor.
62 |
63 | Here we use PyTorch Tensors to fit a two-layer network to random data. Like the
64 | numpy example above we manually implement the forward and backward
65 | passes through the network, using operations on PyTorch Tensors:
66 |
67 | ```python
68 | :INCLUDE tensor/two_layer_net_tensor.py
69 | ```
70 |
71 | ## PyTorch: Autograd
72 |
73 | In the above examples, we had to manually implement both the forward and
74 | backward passes of our neural network. Manually implementing the backward pass
75 | is not a big deal for a small two-layer network, but can quickly get very hairy
76 | for large complex networks.
77 |
78 | Thankfully, we can use
79 | [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
80 | to automate the computation of backward passes in neural networks.
81 | The **autograd** package in PyTorch provides exactly this functionality.
82 | When using autograd, the forward pass of your network will define a
83 | **computational graph**; nodes in the graph will be Tensors, and edges will be
84 | functions that produce output Tensors from input Tensors. Backpropagating through
85 | this graph then allows you to easily compute gradients.
86 |
87 | This sounds complicated, but it's pretty simple to use in practice. If we want to
88 | compute gradients with respect to some Tensor, then we set `requires_grad=True`
89 | when constructing that Tensor. Any PyTorch operations on that Tensor will cause
90 | a computational graph to be constructed, allowing us to later perform backpropagation
91 | through the graph. If `x` is a Tensor with `requires_grad=True`, then after
92 | backpropagation `x.grad` will be another Tensor holding the gradient of `x` with
93 | respect to some scalar value.
94 |
95 | Sometimes you may wish to prevent PyTorch from building computational graphs when
96 | performing certain operations on Tensors with `requires_grad=True`; for example
97 | we usually don't want to backpropagate through the weight update steps when
98 | training a neural network. In such scenarios we can use the `torch.no_grad()`
99 | context manager to prevent the construction of a computational graph.
100 |
101 | Here we use PyTorch Tensors and autograd to implement our two-layer network;
102 | now we no longer need to manually implement the backward pass through the
103 | network:
104 |
105 | ```python
106 | :INCLUDE autograd/two_layer_net_autograd.py
107 | ```
108 |
109 | ## PyTorch: Defining new autograd functions
110 | Under the hood, each primitive autograd operator is really two functions that
111 | operate on Tensors. The **forward** function computes output Tensors from input
112 | Tensors. The **backward** function receives the gradient of the output Tensors
113 | with respect to some scalar value, and computes the gradient of the input Tensors
114 | with respect to that same scalar value.
115 |
116 | In PyTorch we can easily define our own autograd operator by defining a subclass
117 | of `torch.autograd.Function` and implementing the `forward` and `backward` functions.
118 | We can then use our new autograd operator by calling the `apply` method of our
119 | Function subclass, passing Tensors containing input data.
120 |
121 | In this example we define our own custom autograd function for performing the ReLU
122 | nonlinearity, and use it to implement our two-layer network:
123 |
124 | ```python
125 | :INCLUDE autograd/two_layer_net_custom_function.py
126 | ```
127 |
128 | ## TensorFlow: Static Graphs
129 | PyTorch autograd looks a lot like TensorFlow: in both frameworks we define
130 | a computational graph, and use automatic differentiation to compute gradients.
131 | The biggest difference between the two is that TensorFlow's computational graphs
132 | are **static** and PyTorch uses **dynamic** computational graphs.
133 |
134 | In TensorFlow, we define the computational graph once and then execute the same
135 | graph over and over again, possibly feeding different input data to the graph.
136 | In PyTorch, each forward pass defines a new computational graph.
137 |
138 | Static graphs are nice because you can optimize the graph up front; for example
139 | a framework might decide to fuse some graph operations for efficiency, or to
140 | come up with a strategy for distributing the graph across many GPUs or many
141 | machines. If you are reusing the same graph over and over, then this potentially
142 | costly up-front optimization can be amortized as the same graph is rerun over
143 | and over.
144 |
145 | One aspect where static and dynamic graphs differ is control flow. For some models
146 | we may wish to perform different computation for each data point; for example a
147 | recurrent network might be unrolled for different numbers of time steps for each
148 | data point; this unrolling can be implemented as a loop. With a static graph the
149 | loop construct needs to be a part of the graph; for this reason TensorFlow
150 | provides operators such as `tf.scan` for embedding loops into the graph. With
151 | dynamic graphs the situation is simpler: since we build graphs on-the-fly for
152 | each example, we can use normal imperative flow control to perform computation
153 | that differs for each input.
154 |
155 | To contrast with the PyTorch autograd example above, here we use TensorFlow to
156 | fit a simple two-layer net:
157 |
158 | ```python
159 | :INCLUDE autograd/tf_two_layer_net.py
160 | ```
161 |
162 |
163 | ## PyTorch: nn
164 | Computational graphs and autograd are a very powerful paradigm for defining
165 | complex operators and automatically taking derivatives; however for large
166 | neural networks raw autograd can be a bit too low-level.
167 |
168 | When building neural networks we frequently think of arranging the computation
169 | into **layers**, some of which have **learnable parameters** which will be
170 | optimized during learning.
171 |
172 | In TensorFlow, packages like [Keras](https://github.com/fchollet/keras),
173 | [TensorFlow-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim),
174 | and [TFLearn](http://tflearn.org/) provide higher-level abstractions over
175 | raw computational graphs that are useful for building neural networks.
176 |
177 | In PyTorch, the `nn` package serves this same purpose. The `nn` package defines a set of
178 | **Modules**, which are roughly equivalent to neural network layers. A Module receives
179 | input Tensors and computes output Tensors, but may also hold internal state such as
180 | Tensors containing learnable parameters. The `nn` package also defines a set of useful
181 | loss functions that are commonly used when training neural networks.
182 |
183 | In this example we use the `nn` package to implement our two-layer network:
184 |
185 | ```python
186 | :INCLUDE nn/two_layer_net_nn.py
187 | ```
188 |
189 |
190 | ## PyTorch: optim
191 | Up to this point we have updated the weights of our models by manually mutating
192 | Tensors holding learnable parameters. This is not a huge burden
193 | for simple optimization algorithms like stochastic gradient descent, but in practice
194 | we often train neural networks using more sophisticated optimizers like AdaGrad,
195 | RMSProp, Adam, etc.
196 |
197 | The `optim` package in PyTorch abstracts the idea of an optimization algorithm and
198 | provides implementations of commonly used optimization algorithms.
199 |
200 | In this example we will use the `nn` package to define our model as before, but we
201 | will optimize the model using the Adam algorithm provided by the `optim` package:
202 |
203 | ```python
204 | :INCLUDE nn/two_layer_net_optim.py
205 | ```
206 |
207 |
208 | ## PyTorch: Custom nn Modules
209 | Sometimes you will want to specify models that are more complex than a sequence of
210 | existing Modules; for these cases you can define your own Modules by subclassing
211 | `nn.Module` and defining a `forward` method which receives input Tensors and produces
212 | output Tensors using other modules or other autograd operations on Tensors.
213 |
214 | In this example we implement our two-layer network as a custom Module subclass:
215 |
216 | ```python
217 | :INCLUDE nn/two_layer_net_module.py
218 | ```
219 |
220 |
221 | ## PyTorch: Control Flow + Weight Sharing
222 | As an example of dynamic graphs and weight sharing, we implement a very strange
223 | model: a fully-connected ReLU network that on each forward pass chooses a random
224 | number between 1 and 4 and uses that many hidden layers, reusing the same weights
225 | multiple times to compute the innermost hidden layers.
226 |
227 | For this model we can use normal Python flow control to implement the loop, and we
228 | can implement weight sharing among the innermost layers by simply reusing the
229 | same Module multiple times when defining the forward pass.
230 |
231 | We can easily implement this model as a Module subclass:
232 |
233 | ```python
234 | :INCLUDE nn/dynamic_net.py
235 | ```
236 |
--------------------------------------------------------------------------------
/autograd/tf_two_layer_net.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | """
5 | A fully-connected ReLU network with one hidden layer and no biases, trained to
6 | predict y from x by minimizing squared Euclidean distance.
7 |
8 | This implementation uses basic TensorFlow operations to set up a computational
9 | graph, then executes the graph many times to actually train the network.
10 |
11 | One of the main differences between TensorFlow and PyTorch is that TensorFlow
12 | uses static computational graphs while PyTorch uses dynamic computational
13 | graphs.
14 |
15 | In TensorFlow we first set up the computational graph, then execute the same
16 | graph many times.
17 | """
18 |
19 | # First we set up the computational graph:
20 |
21 | # N is batch size; D_in is input dimension;
22 | # H is hidden dimension; D_out is output dimension.
23 | N, D_in, H, D_out = 64, 1000, 100, 10
24 |
25 | # Create placeholders for the input and target data; these will be filled
26 | # with real data when we execute the graph.
27 | x = tf.placeholder(tf.float32, shape=(None, D_in))
28 | y = tf.placeholder(tf.float32, shape=(None, D_out))
29 |
30 | # Create Variables for the weights and initialize them with random data.
31 | # A TensorFlow Variable persists its value across executions of the graph.
32 | w1 = tf.Variable(tf.random_normal((D_in, H)))
33 | w2 = tf.Variable(tf.random_normal((H, D_out)))
34 |
35 | # Forward pass: Compute the predicted y using operations on TensorFlow Tensors.
36 | # Note that this code does not actually perform any numeric operations; it
37 | # merely sets up the computational graph that we will later execute.
38 | h = tf.matmul(x, w1)
39 | h_relu = tf.maximum(h, tf.zeros(1))
40 | y_pred = tf.matmul(h_relu, w2)
41 |
42 | # Compute loss using operations on TensorFlow Tensors
43 | loss = tf.reduce_sum((y - y_pred) ** 2.0)
44 |
45 | # Compute gradient of the loss with respect to w1 and w2.
46 | grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])
47 |
48 | # Update the weights using gradient descent. To actually update the weights
49 | # we need to evaluate new_w1 and new_w2 when executing the graph. Note that
50 | # in TensorFlow the act of updating the value of the weights is part of
51 | # the computational graph; in PyTorch this happens outside the computational
52 | # graph.
53 | learning_rate = 1e-6
54 | new_w1 = w1.assign(w1 - learning_rate * grad_w1)
55 | new_w2 = w2.assign(w2 - learning_rate * grad_w2)
56 |
57 | # Now we have built our computational graph, so we enter a TensorFlow session to
58 | # actually execute the graph.
59 | with tf.Session() as sess:
60 | # Run the graph once to initialize the Variables w1 and w2.
61 | sess.run(tf.global_variables_initializer())
62 |
63 | # Create numpy arrays holding the actual data for the inputs x and targets y
64 | x_value = np.random.randn(N, D_in)
65 | y_value = np.random.randn(N, D_out)
66 | for _ in range(500):
67 | # Execute the graph many times. Each time it executes we want to bind
68 | # x_value to x and y_value to y, specified with the feed_dict argument.
69 | # Each time we execute the graph we want to compute the values for loss,
70 | # new_w1, and new_w2; the values of these Tensors are returned as numpy
71 | # arrays.
72 | loss_value, _, _ = sess.run([loss, new_w1, new_w2],
73 | feed_dict={x: x_value, y: y_value})
74 | print(loss_value)
75 |
--------------------------------------------------------------------------------
/autograd/two_layer_net_autograd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """
4 | A fully-connected ReLU network with one hidden layer and no biases, trained to
5 | predict y from x by minimizing squared Euclidean distance.
6 |
7 | This implementation computes the forward pass using operations on PyTorch
8 | Tensors, and uses PyTorch autograd to compute gradients.
9 |
10 | When we create a PyTorch Tensor with requires_grad=True, then operations
11 | involving that Tensor will not just compute values; they will also build up
12 | a computational graph in the background, allowing us to easily backpropagate
13 | through the graph to compute gradients of some downstream (scalar) loss with
14 | respect to a Tensor. Concretely if x is a Tensor with x.requires_grad == True
15 | then after backpropagation x.grad will be another Tensor holding the gradient
16 | of x with respect to some scalar value.
17 | """
18 |
19 | device = torch.device('cpu')
20 | # device = torch.device('cuda') # Uncomment this to run on GPU
21 |
22 | # N is batch size; D_in is input dimension;
23 | # H is hidden dimension; D_out is output dimension.
24 | N, D_in, H, D_out = 64, 1000, 100, 10
25 |
26 | # Create random Tensors to hold input and outputs
27 | x = torch.randn(N, D_in, device=device)
28 | y = torch.randn(N, D_out, device=device)
29 |
30 | # Create random Tensors for weights; setting requires_grad=True means that we
31 | # want to compute gradients for these Tensors during the backward pass.
32 | w1 = torch.randn(D_in, H, device=device, requires_grad=True)
33 | w2 = torch.randn(H, D_out, device=device, requires_grad=True)
34 |
35 | learning_rate = 1e-6
36 | for t in range(500):
37 | # Forward pass: compute predicted y using operations on Tensors. Since w1 and
38 | # w2 have requires_grad=True, operations involving these Tensors will cause
39 | # PyTorch to build a computational graph, allowing automatic computation of
40 | # gradients. Since we are no longer implementing the backward pass by hand we
41 | # don't need to keep references to intermediate values.
42 | y_pred = x.mm(w1).clamp(min=0).mm(w2)
43 |
44 | # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
45 | # is a Python number giving its value.
46 | loss = (y_pred - y).pow(2).sum()
47 | print(t, loss.item())
48 |
49 | # Use autograd to compute the backward pass. This call will compute the
50 | # gradient of loss with respect to all Tensors with requires_grad=True.
51 | # After this call w1.grad and w2.grad will be Tensors holding the gradient
52 | # of the loss with respect to w1 and w2 respectively.
53 | loss.backward()
54 |
55 | # Update weights using gradient descent. For this step we just want to mutate
56 | # the values of w1 and w2 in-place; we don't want to build up a computational
57 | # graph for the update steps, so we use the torch.no_grad() context manager
58 | # to prevent PyTorch from building a computational graph for the updates
59 | with torch.no_grad():
60 | w1 -= learning_rate * w1.grad
61 | w2 -= learning_rate * w2.grad
62 |
63 | # Manually zero the gradients after running the backward pass
64 | w1.grad.zero_()
65 | w2.grad.zero_()
66 |
--------------------------------------------------------------------------------
/autograd/two_layer_net_custom_function.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """
4 | A fully-connected ReLU network with one hidden layer and no biases, trained to
5 | predict y from x by minimizing squared Euclidean distance.
6 |
7 | This implementation computes the forward pass using operations on PyTorch
8 | Tensors, and uses PyTorch autograd to compute gradients.
9 |
10 | In this implementation we define our own custom autograd function to perform
11 | the ReLU nonlinearity.
12 | """
13 |
14 | class MyReLU(torch.autograd.Function):
15 | """
16 | We can implement our own custom autograd Functions by subclassing
17 | torch.autograd.Function and implementing the forward and backward passes
18 | which operate on Tensors.
19 | """
20 | @staticmethod
21 | def forward(ctx, x):
22 | """
23 | In the forward pass we receive a context object and a Tensor containing the
24 | input; we must return a Tensor containing the output, and we can use the
25 | context object to cache objects for use in the backward pass.
26 | """
27 | ctx.save_for_backward(x)
28 | return x.clamp(min=0)
29 |
30 | @staticmethod
31 | def backward(ctx, grad_output):
32 | """
33 | In the backward pass we receive the context object and a Tensor containing
34 | the gradient of the loss with respect to the output produced during the
35 | forward pass. We can retrieve cached data from the context object, and must
36 | compute and return the gradient of the loss with respect to the input to the
37 | forward function.
38 | """
39 | x, = ctx.saved_tensors
40 | grad_x = grad_output.clone()
41 | grad_x[x < 0] = 0
42 | return grad_x
43 |
44 |
45 | device = torch.device('cpu')
46 | # device = torch.device('cuda') # Uncomment this to run on GPU
47 |
48 | # N is batch size; D_in is input dimension;
49 | # H is hidden dimension; D_out is output dimension.
50 | N, D_in, H, D_out = 64, 1000, 100, 10
51 |
52 | # Create random Tensors to hold input and output
53 | x = torch.randn(N, D_in, device=device)
54 | y = torch.randn(N, D_out, device=device)
55 |
56 | # Create random Tensors for weights.
57 | w1 = torch.randn(D_in, H, device=device, requires_grad=True)
58 | w2 = torch.randn(H, D_out, device=device, requires_grad=True)
59 |
60 | learning_rate = 1e-6
61 | for t in range(500):
62 | # Forward pass: compute predicted y using operations on Tensors; we call our
63 | # custom ReLU implementation using the MyReLU.apply function
64 | y_pred = MyReLU.apply(x.mm(w1)).mm(w2)
65 |
66 | # Compute and print loss
67 | loss = (y_pred - y).pow(2).sum()
68 | print(t, loss.item())
69 |
70 | # Use autograd to compute the backward pass.
71 | loss.backward()
72 |
73 | with torch.no_grad():
74 | # Update weights using gradient descent
75 | w1 -= learning_rate * w1.grad
76 | w2 -= learning_rate * w2.grad
77 |
78 | # Manually zero the gradients after running the backward pass
79 | w1.grad.zero_()
80 | w2.grad.zero_()
81 |
82 |
--------------------------------------------------------------------------------
/build_readme.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | """
4 | GitHub doesn't provide an include mechanism for README files so we have to
5 | implement our own.
6 | """
7 |
8 | def main():
9 | build_readme('README_raw.md', 'README.md')
10 | for d in os.listdir('.'):
11 | if not os.path.isdir(d) or d.startswith('.'):
12 | continue
13 | in_path = os.path.join(d, 'README_raw.md')
14 | out_path = os.path.join(d, 'README.md')
15 | build_readme(in_path, out_path)
16 |
17 |
18 | def build_readme(in_path, out_path):
19 | if not os.path.isfile(in_path):
20 | return
21 | with open(in_path, 'r') as fin, open(out_path, 'w') as fout:
22 | for line in fin:
23 | if not line.startswith(':INCLUDE'):
24 | fout.write('%s' % line)
25 | else:
26 | include_path = line.split(' ')[1].strip()
27 | include_path = os.path.join(os.path.split(in_path)[0], include_path)
28 | fout.write('# Code in file %s\n' % include_path)
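# Copy the body of the included file while stripping its module-level
# docstring: skip_toggle tracks whether we are inside a triple-quoted block,
# and skip_next drops the single line (usually blank) that immediately
# follows the closing quotes.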
29 | skip_toggle = False
30 | skip_next = False
31 | with open(include_path, 'r') as finc:
32 | for ll in finc:
33 | if ll.startswith('"""'):
34 | skip_next = skip_toggle
35 | skip_toggle = not skip_toggle
36 | elif not skip_toggle and not skip_next:
37 | fout.write(ll)
38 | elif skip_next:
39 | skip_next = False
40 |
41 | if __name__ == '__main__':
42 | main()
43 |
44 |
--------------------------------------------------------------------------------
/nn/dynamic_net.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | """
5 | To showcase the power of PyTorch dynamic graphs, we will implement a very strange
6 | model: a fully-connected ReLU network that on each forward pass randomly chooses
7 | a number between 1 and 4 and has that many hidden layers, reusing the same
8 | weights multiple times to compute the innermost hidden layers.
9 | """
10 |
11 | class DynamicNet(torch.nn.Module):
12 | def __init__(self, D_in, H, D_out):
13 | """
14 | In the constructor we construct three nn.Linear instances that we will use
15 | in the forward pass.
16 | """
17 | super(DynamicNet, self).__init__()
18 | self.input_linear = torch.nn.Linear(D_in, H)
19 | self.middle_linear = torch.nn.Linear(H, H)
20 | self.output_linear = torch.nn.Linear(H, D_out)
21 |
22 | def forward(self, x):
23 | """
24 | For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
25 | and reuse the middle_linear Module that many times to compute hidden layer
26 | representations.
27 |
28 | Since each forward pass builds a dynamic computation graph, we can use normal
29 | Python control-flow operators like loops or conditional statements when
30 | defining the forward pass of the model.
31 |
32 | Here we also see that it is perfectly safe to reuse the same Module many
33 | times when defining a computational graph. This is a big improvement over Lua
34 | Torch, where each Module could be used only once.
35 | """
36 | h_relu = self.input_linear(x).clamp(min=0)
37 | for _ in range(random.randint(0, 3)):
38 | h_relu = self.middle_linear(h_relu).clamp(min=0)
39 | y_pred = self.output_linear(h_relu)
40 | return y_pred
41 |
42 |
43 | # N is batch size; D_in is input dimension;
44 | # H is hidden dimension; D_out is output dimension.
45 | N, D_in, H, D_out = 64, 1000, 100, 10
46 |
47 | # Create random Tensors to hold inputs and outputs.
48 | x = torch.randn(N, D_in)
49 | y = torch.randn(N, D_out)
50 |
51 | # Construct our model by instantiating the class defined above
52 | model = DynamicNet(D_in, H, D_out)
53 |
54 | # Construct our loss function and an Optimizer. Training this strange model with
55 | # vanilla stochastic gradient descent is tough, so we use momentum
56 | criterion = torch.nn.MSELoss(reduction='sum')
57 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
58 | for t in range(500):
59 | # Forward pass: Compute predicted y by passing x to the model
60 | y_pred = model(x)
61 |
62 | # Compute and print loss
63 | loss = criterion(y_pred, y)
64 | print(t, loss.item())
65 |
66 | # Zero gradients, perform a backward pass, and update the weights.
67 | optimizer.zero_grad()
68 | loss.backward()
69 | optimizer.step()
70 |
--------------------------------------------------------------------------------
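To make the weight sharing in DynamicNet concrete (a sketch assuming the class defined above is in scope; not part of the repo): the model owns exactly three Linear layers' worth of parameters no matter how many times middle_linear is applied, and each forward pass may run through a different depth.

    import torch

    model = DynamicNet(D_in=1000, H=100, D_out=10)

    # Fixed parameter count: input_linear + middle_linear + output_linear,
    # regardless of how often middle_linear is reused in a forward pass.
    print(sum(p.numel() for p in model.parameters()))

    # Each call samples its own depth, but every call shares the same weights.
    x = torch.randn(4, 1000)
    for _ in range(3):
        print(model(x).shape)   # torch.Size([4, 10]) every time
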
/nn/two_layer_net_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """
4 | A fully-connected ReLU network with one hidden layer, trained to predict y from x
5 | by minimizing squared Euclidean distance.
6 |
7 | This implementation defines the model as a custom Module subclass. Whenever you
8 | want a model more complex than a simple sequence of existing Modules you will
9 | need to define your model this way.
10 | """
11 |
12 | class TwoLayerNet(torch.nn.Module):
13 | def __init__(self, D_in, H, D_out):
14 | """
15 | In the constructor we instantiate two nn.Linear modules and assign them as
16 | member variables.
17 | """
18 | super(TwoLayerNet, self).__init__()
19 | self.linear1 = torch.nn.Linear(D_in, H)
20 | self.linear2 = torch.nn.Linear(H, D_out)
21 |
22 | def forward(self, x):
23 | """
24 | In the forward function we accept a Tensor of input data and we must return
25 | a Tensor of output data. We can use Modules defined in the constructor as
26 | well as arbitrary (differentiable) operations on Tensors.
27 | """
28 | h_relu = self.linear1(x).clamp(min=0)
29 | y_pred = self.linear2(h_relu)
30 | return y_pred
31 |
32 | # N is batch size; D_in is input dimension;
33 | # H is hidden dimension; D_out is output dimension.
34 | N, D_in, H, D_out = 64, 1000, 100, 10
35 |
36 | # Create random Tensors to hold inputs and outputs
37 | x = torch.randn(N, D_in)
38 | y = torch.randn(N, D_out)
39 |
40 | # Construct our model by instantiating the class defined above.
41 | model = TwoLayerNet(D_in, H, D_out)
42 |
43 | # Construct our loss function and an Optimizer. The call to model.parameters()
44 | # in the SGD constructor will contain the learnable parameters of the two
45 | # nn.Linear modules which are members of the model.
46 | loss_fn = torch.nn.MSELoss(reduction='sum')
47 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
48 | for t in range(500):
49 | # Forward pass: Compute predicted y by passing x to the model
50 | y_pred = model(x)
51 |
52 | # Compute and print loss
53 | loss = loss_fn(y_pred, y)
54 | print(t, loss.item())
55 |
56 | # Zero gradients, perform a backward pass, and update the weights.
57 | optimizer.zero_grad()
58 | loss.backward()
59 | optimizer.step()
60 |
61 |
--------------------------------------------------------------------------------
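A side note on why model.parameters() works in the SGD constructor above (a sketch assuming the TwoLayerNet class is in scope): assigning nn.Linear instances as attributes in __init__ registers them as submodules, so their weights and biases are discovered automatically.

    import torch

    model = TwoLayerNet(D_in=1000, H=100, D_out=10)

    # Both Linear submodules contribute their weight and bias tensors.
    for name, p in model.named_parameters():
        print(name, tuple(p.shape))
    # linear1.weight (100, 1000)
    # linear1.bias   (100,)
    # linear2.weight (10, 100)
    # linear2.bias   (10,)
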
/nn/two_layer_net_nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """
4 | A fully-connected ReLU network with one hidden layer, trained to predict y from x
5 | by minimizing squared Euclidean distance.
6 |
7 | This implementation uses the nn package from PyTorch to build the network.
8 | PyTorch autograd makes it easy to define computational graphs and take gradients,
9 | but raw autograd can be a bit too low-level for defining complex neural networks;
10 | this is where the nn package can help. The nn package defines a set of Modules,
11 | which you can think of as neural network layers that produce output from
12 | input and may have some trainable weights or other state.
13 | """
14 |
15 | device = torch.device('cpu')
16 | # device = torch.device('cuda') # Uncomment this to run on GPU
17 |
18 | # N is batch size; D_in is input dimension;
19 | # H is hidden dimension; D_out is output dimension.
20 | N, D_in, H, D_out = 64, 1000, 100, 10
21 |
22 | # Create random Tensors to hold inputs and outputs
23 | x = torch.randn(N, D_in, device=device)
24 | y = torch.randn(N, D_out, device=device)
25 |
26 | # Use the nn package to define our model as a sequence of layers. nn.Sequential
27 | # is a Module which contains other Modules, and applies them in sequence to
28 | # produce its output. Each Linear Module computes output from input using a
29 | # linear function, and holds internal Tensors for its weight and bias.
30 | # After constructing the model we use the .to() method to move it to the
31 | # desired device.
32 | model = torch.nn.Sequential(
33 | torch.nn.Linear(D_in, H),
34 | torch.nn.ReLU(),
35 | torch.nn.Linear(H, D_out),
36 | ).to(device)
37 |
38 | # The nn package also contains definitions of popular loss functions; in this
39 | # case we will use Mean Squared Error (MSE) as our loss function. Setting
40 | # reduction='sum' means that we are computing the *sum* of squared errors rather
41 | # than the mean; this is for consistency with the examples above where we
42 | # manually compute the loss, but in practice it is more common to use mean
43 | # squared error as a loss by setting reduction='elementwise_mean'.
44 | loss_fn = torch.nn.MSELoss(reduction='sum')
45 |
46 | learning_rate = 1e-4
47 | for t in range(500):
48 | # Forward pass: compute predicted y by passing x to the model. Module objects
49 | # override the __call__ operator so you can call them like functions. When
50 | # doing so you pass a Tensor of input data to the Module and it produces
51 | # a Tensor of output data.
52 | y_pred = model(x)
53 |
54 | # Compute and print loss. We pass Tensors containing the predicted and true
55 | # values of y, and the loss function returns a Tensor containing the loss.
56 | loss = loss_fn(y_pred, y)
57 | print(t, loss.item())
58 |
59 | # Zero the gradients before running the backward pass.
60 | model.zero_grad()
61 |
62 | # Backward pass: compute gradient of the loss with respect to all the learnable
63 | # parameters of the model. Internally, the parameters of each Module are stored
64 | # in Tensors with requires_grad=True, so this call will compute gradients for
65 | # all learnable parameters in the model.
66 | loss.backward()
67 |
68 | # Update the weights using gradient descent. Each parameter is a Tensor, so
69 | # we can access its data and gradients like we did before.
70 | with torch.no_grad():
71 | for param in model.parameters():
72 | param -= learning_rate * param.grad
73 |
--------------------------------------------------------------------------------
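To connect the reduction='sum' choice above back to the manually computed losses in the earlier examples (a small sketch, not part of the repo): with reduction='sum', nn.MSELoss produces exactly the (y_pred - y).pow(2).sum() scalar used before.

    import torch

    a = torch.randn(3, 4)
    b = torch.randn(3, 4)

    manual = (a - b).pow(2).sum()
    via_nn = torch.nn.MSELoss(reduction='sum')(a, b)

    print(torch.allclose(manual, via_nn))  # True, up to floating-point error
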
/nn/two_layer_net_optim.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """
4 | A fully-connected ReLU network with one hidden layer, trained to predict y from x
5 | by minimizing squared Euclidean distance.
6 |
7 | This implementation uses the nn package from PyTorch to build the network.
8 |
9 | Rather than manually updating the weights of the model as we have been doing,
10 | we use the optim package to define an Optimizer that will update the weights
11 | for us. The optim package defines many optimization algorithms that are commonly
12 | used for deep learning, including SGD+momentum, RMSProp, Adam, etc.
13 | """
14 |
15 | # N is batch size; D_in is input dimension;
16 | # H is hidden dimension; D_out is output dimension.
17 | N, D_in, H, D_out = 64, 1000, 100, 10
18 |
19 | # Create random Tensors to hold inputs and outputs.
20 | x = torch.randn(N, D_in)
21 | y = torch.randn(N, D_out)
22 |
23 | # Use the nn package to define our model and loss function.
24 | model = torch.nn.Sequential(
25 | torch.nn.Linear(D_in, H),
26 | torch.nn.ReLU(),
27 | torch.nn.Linear(H, D_out),
28 | )
29 | loss_fn = torch.nn.MSELoss(reduction='sum')
30 |
31 | # Use the optim package to define an Optimizer that will update the weights of
32 | # the model for us. Here we will use Adam; the optim package contains many other
33 | optimization algorithms. The first argument to the Adam constructor tells the
34 | # optimizer which Tensors it should update.
35 | learning_rate = 1e-4
36 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
37 | for t in range(500):
38 | # Forward pass: compute predicted y by passing x to the model.
39 | y_pred = model(x)
40 |
41 | # Compute and print loss.
42 | loss = loss_fn(y_pred, y)
43 | print(t, loss.item())
44 |
45 | # Before the backward pass, use the optimizer object to zero all of the
46 | # gradients for the Tensors it will update (which are the learnable weights
47 | # of the model)
48 | optimizer.zero_grad()
49 |
50 | # Backward pass: compute gradient of the loss with respect to model parameters
51 | loss.backward()
52 |
53 | # Calling the step function on an Optimizer makes an update to its parameters
54 | optimizer.step()
55 |
--------------------------------------------------------------------------------
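Since every optimizer in torch.optim exposes the same zero_grad()/step() interface, swapping the algorithm is a one-line change (a sketch assuming the model defined above; neither optimizer below is used in the file itself):

    # SGD with momentum, one of the algorithms mentioned in the docstring...
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # ...or RMSProp; the training loop itself stays exactly the same.
    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)
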
/tensor/two_layer_net_numpy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | """
4 | A fully-connected ReLU network with one hidden layer and no biases, trained to
5 | predict y from x using Euclidean error.
6 |
7 | This implementation uses numpy to manually compute the forward pass, loss, and
8 | backward pass.
9 |
10 | A numpy array is a generic n-dimensional array; it does not know anything about
11 | deep learning or gradients or computational graphs, and is just a way to perform
12 | generic numeric computations.
13 | """
14 |
15 | # N is batch size; D_in is input dimension;
16 | # H is hidden dimension; D_out is output dimension.
17 | N, D_in, H, D_out = 64, 1000, 100, 10
18 |
19 | # Create random input and output data
20 | x = np.random.randn(N, D_in)
21 | y = np.random.randn(N, D_out)
22 |
23 | # Randomly initialize weights
24 | w1 = np.random.randn(D_in, H)
25 | w2 = np.random.randn(H, D_out)
26 |
27 | learning_rate = 1e-6
28 | for t in range(500):
29 | # Forward pass: compute predicted y
30 | h = x.dot(w1)
31 | h_relu = np.maximum(h, 0)
32 | y_pred = h_relu.dot(w2)
33 |
34 | # Compute and print loss
35 | loss = np.square(y_pred - y).sum()
36 | print(t, loss)
37 |
38 | # Backprop to compute gradients of w1 and w2 with respect to loss
39 | grad_y_pred = 2.0 * (y_pred - y)
40 | grad_w2 = h_relu.T.dot(grad_y_pred)
41 | grad_h_relu = grad_y_pred.dot(w2.T)
42 | grad_h = grad_h_relu.copy()
43 | grad_h[h < 0] = 0
44 | grad_w1 = x.T.dot(grad_h)
45 |
46 | # Update weights
47 | w1 -= learning_rate * grad_w1
48 | w2 -= learning_rate * grad_w2
49 |
--------------------------------------------------------------------------------
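One way to convince yourself the hand-derived backward pass above is correct (a self-contained sketch, not part of the repo) is to compare a single entry of the analytic grad_w1 against a central finite difference on a tiny problem.

    import numpy as np

    np.random.seed(0)
    x = np.random.randn(4, 5)
    y = np.random.randn(4, 3)
    w1 = np.random.randn(5, 6)
    w2 = np.random.randn(6, 3)

    def loss(w1, w2):
        h_relu = np.maximum(x.dot(w1), 0)
        return np.square(h_relu.dot(w2) - y).sum()

    # Analytic gradient w.r.t. w1, following the same derivation as above.
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    grad_y_pred = 2.0 * (h_relu.dot(w2) - y)
    grad_h = grad_y_pred.dot(w2.T)
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Central finite difference for one entry of w1.
    eps = 1e-6
    w1p, w1m = w1.copy(), w1.copy()
    w1p[0, 0] += eps
    w1m[0, 0] -= eps
    numeric = (loss(w1p, w2) - loss(w1m, w2)) / (2 * eps)

    print(grad_w1[0, 0], numeric)  # the two values should agree closely
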
/tensor/two_layer_net_tensor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """
4 | A fully-connected ReLU network with one hidden layer and no biases, trained to
5 | predict y from x by minimizing squared Euclidean distance.
6 |
7 | This implementation uses PyTorch tensors to manually compute the forward pass,
8 | loss, and backward pass.
9 |
10 | A PyTorch Tensor is basically the same as a numpy array: it does not know
11 | anything about deep learning or computational graphs or gradients, and is just
12 | a generic n-dimensional array to be used for arbitrary numeric computation.
13 |
14 | The biggest difference between a numpy array and a PyTorch Tensor is that
15 | a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
16 | just pass a different value to the `device` argument when constructing the
17 | Tensor.
18 | """
19 |
20 | device = torch.device('cpu')
21 | # device = torch.device('cuda') # Uncomment this to run on GPU
22 |
23 | # N is batch size; D_in is input dimension;
24 | # H is hidden dimension; D_out is output dimension.
25 | N, D_in, H, D_out = 64, 1000, 100, 10
26 |
27 | # Create random input and output data
28 | x = torch.randn(N, D_in, device=device)
29 | y = torch.randn(N, D_out, device=device)
30 |
31 | # Randomly initialize weights
32 | w1 = torch.randn(D_in, H, device=device)
33 | w2 = torch.randn(H, D_out, device=device)
34 |
35 | learning_rate = 1e-6
36 | for t in range(500):
37 | # Forward pass: compute predicted y
38 | h = x.mm(w1)
39 | h_relu = h.clamp(min=0)
40 | y_pred = h_relu.mm(w2)
41 |
42 | # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
43 | # of shape (); we can get its value as a Python number with loss.item().
44 | loss = (y_pred - y).pow(2).sum()
45 | print(t, loss.item())
46 |
47 | # Backprop to compute gradients of w1 and w2 with respect to loss
48 | grad_y_pred = 2.0 * (y_pred - y)
49 | grad_w2 = h_relu.t().mm(grad_y_pred)
50 | grad_h_relu = grad_y_pred.mm(w2.t())
51 | grad_h = grad_h_relu.clone()
52 | grad_h[h < 0] = 0
53 | grad_w1 = x.t().mm(grad_h)
54 |
55 | # Update weights using gradient descent
56 | w1 -= learning_rate * grad_w1
57 | w2 -= learning_rate * grad_w2
58 |
--------------------------------------------------------------------------------
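Finally, a short sketch of the device switch described in the docstring above (not part of the repo): tensors can be created directly on a device or moved there afterwards with .to(), and the rest of the code is unchanged.

    import torch

    # Fall back to CPU when no GPU is present, so the sketch runs anywhere.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    a = torch.randn(3, 3, device=device)   # created directly on the chosen device
    b = torch.randn(3, 3).to(device)       # created on CPU, then moved
    print(a.mm(b).device)                  # the product lives on the same device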