├── .gitignore ├── LICENSE ├── README.md ├── README_raw.md ├── autograd ├── tf_two_layer_net.py ├── two_layer_net_autograd.py └── two_layer_net_custom_function.py ├── build_readme.py ├── nn ├── dynamic_net.py ├── two_layer_net_module.py ├── two_layer_net_nn.py └── two_layer_net_optim.py └── tensor ├── two_layer_net_numpy.py └── two_layer_net_tensor.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Justin Johnson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository introduces the fundamental concepts of 2 | [PyTorch](https://github.com/pytorch/pytorch) 3 | through self-contained examples. 4 | 5 | At its core, PyTorch provides two main features: 6 | - An n-dimensional Tensor, similar to numpy but can run on GPUs 7 | - Automatic differentiation for building and training neural networks 8 | 9 | We will use a fully-connected ReLU network as our running example. The network 10 | will have a single hidden layer, and will be trained with gradient descent to 11 | fit random data by minimizing the Euclidean distance between the network output 12 | and the true output. 13 | 14 | **NOTE:** These examples have been updated for PyTorch 0.4, which made several 15 | major changes to the core PyTorch API. Most notably, prior to 0.4 Tensors had 16 | to be wrapped in Variable objects to use autograd; this functionality has now 17 | been added directly to Tensors, and Variables are now deprecated. 18 | 19 | ### Table of Contents 20 | - Warm-up: numpy 21 | - PyTorch: Tensors 22 | - PyTorch: Autograd 23 | - PyTorch: Defining new autograd functions 24 | - TensorFlow: Static Graphs 25 | - PyTorch: nn 26 | - PyTorch: optim 27 | - PyTorch: Custom nn Modules 28 | - PyTorch: Control Flow and Weight Sharing 29 | 30 | ## Warm-up: numpy 31 | 32 | Before introducing PyTorch, we will first implement the network using numpy. 33 | 34 | Numpy provides an n-dimensional array object, and many functions for manipulating 35 | these arrays.
Numpy is a generic framework for scientific computing; it does not 36 | know anything about computation graphs, or deep learning, or gradients. However 37 | we can easily use numpy to fit a two-layer network to random data by manually 38 | implementing the forward and backward passes through the network using numpy 39 | operations: 40 | 41 | ```python 42 | # Code in file tensor/two_layer_net_numpy.py 43 | import numpy as np 44 | 45 | # N is batch size; D_in is input dimension; 46 | # H is hidden dimension; D_out is output dimension. 47 | N, D_in, H, D_out = 64, 1000, 100, 10 48 | 49 | # Create random input and output data 50 | x = np.random.randn(N, D_in) 51 | y = np.random.randn(N, D_out) 52 | 53 | # Randomly initialize weights 54 | w1 = np.random.randn(D_in, H) 55 | w2 = np.random.randn(H, D_out) 56 | 57 | learning_rate = 1e-6 58 | for t in range(500): 59 | # Forward pass: compute predicted y 60 | h = x.dot(w1) 61 | h_relu = np.maximum(h, 0) 62 | y_pred = h_relu.dot(w2) 63 | 64 | # Compute and print loss 65 | loss = np.square(y_pred - y).sum() 66 | print(t, loss) 67 | 68 | # Backprop to compute gradients of w1 and w2 with respect to loss 69 | grad_y_pred = 2.0 * (y_pred - y) 70 | grad_w2 = h_relu.T.dot(grad_y_pred) 71 | grad_h_relu = grad_y_pred.dot(w2.T) 72 | grad_h = grad_h_relu.copy() 73 | grad_h[h < 0] = 0 74 | grad_w1 = x.T.dot(grad_h) 75 | 76 | # Update weights 77 | w1 -= learning_rate * grad_w1 78 | w2 -= learning_rate * grad_w2 79 | ``` 80 | 81 | ## PyTorch: Tensors 82 | 83 | Numpy is a great framework, but it cannot utilize GPUs to accelerate its 84 | numerical computations. For modern deep neural networks, GPUs often provide 85 | speedups of [50x or greater](https://github.com/jcjohnson/cnn-benchmarks), so 86 | unfortunately numpy won't be enough for modern deep learning. 87 | 88 | Here we introduce the most fundamental PyTorch concept: the **Tensor**. A PyTorch 89 | Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional 90 | array, and PyTorch provides many functions for operating on these Tensors. 91 | Any computation you might want to perform with numpy can also be accomplished 92 | with PyTorch Tensors; you should think of them as a generic tool for scientific 93 | computing. 94 | 95 | However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their 96 | numeric computations. To run a PyTorch Tensor on GPU, you use the `device` 97 | argument when constructing a Tensor to place the Tensor on a GPU. 98 | 99 | Here we use PyTorch Tensors to fit a two-layer network to random data. Like the 100 | numpy example above we manually implement the forward and backward 101 | passes through the network, using operations on PyTorch Tensors: 102 | 103 | ```python 104 | # Code in file tensor/two_layer_net_tensor.py 105 | import torch 106 | 107 | device = torch.device('cpu') 108 | # device = torch.device('cuda') # Uncomment this to run on GPU 109 | 110 | # N is batch size; D_in is input dimension; 111 | # H is hidden dimension; D_out is output dimension. 
112 | N, D_in, H, D_out = 64, 1000, 100, 10 113 | 114 | # Create random input and output data 115 | x = torch.randn(N, D_in, device=device) 116 | y = torch.randn(N, D_out, device=device) 117 | 118 | # Randomly initialize weights 119 | w1 = torch.randn(D_in, H, device=device) 120 | w2 = torch.randn(H, D_out, device=device) 121 | 122 | learning_rate = 1e-6 123 | for t in range(500): 124 | # Forward pass: compute predicted y 125 | h = x.mm(w1) 126 | h_relu = h.clamp(min=0) 127 | y_pred = h_relu.mm(w2) 128 | 129 | # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor 130 | # of shape (); we can get its value as a Python number with loss.item(). 131 | loss = (y_pred - y).pow(2).sum() 132 | print(t, loss.item()) 133 | 134 | # Backprop to compute gradients of w1 and w2 with respect to loss 135 | grad_y_pred = 2.0 * (y_pred - y) 136 | grad_w2 = h_relu.t().mm(grad_y_pred) 137 | grad_h_relu = grad_y_pred.mm(w2.t()) 138 | grad_h = grad_h_relu.clone() 139 | grad_h[h < 0] = 0 140 | grad_w1 = x.t().mm(grad_h) 141 | 142 | # Update weights using gradient descent 143 | w1 -= learning_rate * grad_w1 144 | w2 -= learning_rate * grad_w2 145 | ``` 146 | 147 | ## PyTorch: Autograd 148 | 149 | In the above examples, we had to manually implement both the forward and 150 | backward passes of our neural network. Manually implementing the backward pass 151 | is not a big deal for a small two-layer network, but can quickly get very hairy 152 | for large complex networks. 153 | 154 | Thankfully, we can use 155 | [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) 156 | to automate the computation of backward passes in neural networks. 157 | The **autograd** package in PyTorch provides exactly this functionality. 158 | When using autograd, the forward pass of your network will define a 159 | **computational graph**; nodes in the graph will be Tensors, and edges will be 160 | functions that produce output Tensors from input Tensors. Backpropagating through 161 | this graph then allows you to easily compute gradients. 162 | 163 | This sounds complicated, but it's pretty simple to use in practice. If we want to 164 | compute gradients with respect to some Tensor, then we set `requires_grad=True` 165 | when constructing that Tensor. Any PyTorch operations on that Tensor will cause 166 | a computational graph to be constructed, allowing us to later perform backpropagation 167 | through the graph. If `x` is a Tensor with `requires_grad=True`, then after 168 | backpropagation `x.grad` will be another Tensor holding the gradient of `x` with 169 | respect to some scalar value. 170 | 171 | Sometimes you may wish to prevent PyTorch from building computational graphs when 172 | performing certain operations on Tensors with `requires_grad=True`; for example 173 | we usually don't want to backpropagate through the weight update steps when 174 | training a neural network. In such scenarios we can use the `torch.no_grad()` 175 | context manager to prevent the construction of a computational graph.
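As a minimal, self-contained illustration of these mechanics (this snippet is an addition for clarity, not one of the repository's files), the following shows `requires_grad=True`, `backward()`, `.grad`, and `torch.no_grad()` together on a single scalar:

```python
import torch

# Building expressions from a Tensor with requires_grad=True records a graph.
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2.0 * x

# Backpropagate from the scalar y; afterwards x.grad holds dy/dx = 2x + 2 = 8.
y.backward()
print(x.grad)

# Mutate x without recording the update in the graph, then clear the gradient
# before the next backward pass.
with torch.no_grad():
    x -= 0.1 * x.grad
x.grad.zero_()
```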
176 | 177 | Here we use PyTorch Tensors and autograd to implement our two-layer network; 178 | now we no longer need to manually implement the backward pass through the 179 | network: 180 | 181 | ```python 182 | # Code in file autograd/two_layer_net_autograd.py 183 | import torch 184 | 185 | device = torch.device('cpu') 186 | # device = torch.device('cuda') # Uncomment this to run on GPU 187 | 188 | # N is batch size; D_in is input dimension; 189 | # H is hidden dimension; D_out is output dimension. 190 | N, D_in, H, D_out = 64, 1000, 100, 10 191 | 192 | # Create random Tensors to hold input and outputs 193 | x = torch.randn(N, D_in, device=device) 194 | y = torch.randn(N, D_out, device=device) 195 | 196 | # Create random Tensors for weights; setting requires_grad=True means that we 197 | # want to compute gradients for these Tensors during the backward pass. 198 | w1 = torch.randn(D_in, H, device=device, requires_grad=True) 199 | w2 = torch.randn(H, D_out, device=device, requires_grad=True) 200 | 201 | learning_rate = 1e-6 202 | for t in range(500): 203 | # Forward pass: compute predicted y using operations on Tensors. Since w1 and 204 | # w2 have requires_grad=True, operations involving these Tensors will cause 205 | # PyTorch to build a computational graph, allowing automatic computation of 206 | # gradients. Since we are no longer implementing the backward pass by hand we 207 | # don't need to keep references to intermediate values. 208 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 209 | 210 | # Compute and print loss. Loss is a Tensor of shape (), and loss.item() 211 | # is a Python number giving its value. 212 | loss = (y_pred - y).pow(2).sum() 213 | print(t, loss.item()) 214 | 215 | # Use autograd to compute the backward pass. This call will compute the 216 | # gradient of loss with respect to all Tensors with requires_grad=True. 217 | # After this call w1.grad and w2.grad will be Tensors holding the gradient 218 | # of the loss with respect to w1 and w2 respectively. 219 | loss.backward() 220 | 221 | # Update weights using gradient descent. For this step we just want to mutate 222 | # the values of w1 and w2 in-place; we don't want to build up a computational 223 | # graph for the update steps, so we use the torch.no_grad() context manager 224 | # to prevent PyTorch from building a computational graph for the updates 225 | with torch.no_grad(): 226 | w1 -= learning_rate * w1.grad 227 | w2 -= learning_rate * w2.grad 228 | 229 | # Manually zero the gradients after running the backward pass 230 | w1.grad.zero_() 231 | w2.grad.zero_() 232 | ``` 233 | 234 | ## PyTorch: Defining new autograd functions 235 | Under the hood, each primitive autograd operator is really two functions that 236 | operate on Tensors. The **forward** function computes output Tensors from input 237 | Tensors. The **backward** function receives the gradient of the output Tensors 238 | with respect to some scalar value, and computes the gradient of the input Tensors 239 | with respect to that same scalar value. 240 | 241 | In PyTorch we can easily define our own autograd operator by defining a subclass 242 | of `torch.autograd.Function` and implementing the `forward` and `backward` functions. 243 | We can then use our new autograd operator by constructing an instance and calling it 244 | like a function, passing Tensors containing input data. 
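To make the forward/backward pattern concrete, here is a tiny custom Function (an added illustration, not one of the repository's files) that squares its input; the ReLU Function in the example below has exactly the same structure:

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Cache the input so the backward pass can use it.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: d(x^2)/dx = 2x, scaled by the incoming gradient.
        x, = ctx.saved_tensors
        return 2.0 * x * grad_output

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(torch.allclose(x.grad, 2.0 * x.detach()))  # True
```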
245 | 246 | In this example we define our own custom autograd function for performing the ReLU 247 | nonlinearity, and use it to implement our two-layer network: 248 | 249 | ```python 250 | # Code in file autograd/two_layer_net_custom_function.py 251 | import torch 252 | 253 | class MyReLU(torch.autograd.Function): 254 | """ 255 | We can implement our own custom autograd Functions by subclassing 256 | torch.autograd.Function and implementing the forward and backward passes 257 | which operate on Tensors. 258 | """ 259 | @staticmethod 260 | def forward(ctx, x): 261 | """ 262 | In the forward pass we receive a context object and a Tensor containing the 263 | input; we must return a Tensor containing the output, and we can use the 264 | context object to cache objects for use in the backward pass. 265 | """ 266 | ctx.save_for_backward(x) 267 | return x.clamp(min=0) 268 | 269 | @staticmethod 270 | def backward(ctx, grad_output): 271 | """ 272 | In the backward pass we receive the context object and a Tensor containing 273 | the gradient of the loss with respect to the output produced during the 274 | forward pass. We can retrieve cached data from the context object, and must 275 | compute and return the gradient of the loss with respect to the input to the 276 | forward function. 277 | """ 278 | x, = ctx.saved_tensors 279 | grad_x = grad_output.clone() 280 | grad_x[x < 0] = 0 281 | return grad_x 282 | 283 | 284 | device = torch.device('cpu') 285 | # device = torch.device('cuda') # Uncomment this to run on GPU 286 | 287 | # N is batch size; D_in is input dimension; 288 | # H is hidden dimension; D_out is output dimension. 289 | N, D_in, H, D_out = 64, 1000, 100, 10 290 | 291 | # Create random Tensors to hold input and output 292 | x = torch.randn(N, D_in, device=device) 293 | y = torch.randn(N, D_out, device=device) 294 | 295 | # Create random Tensors for weights. 296 | w1 = torch.randn(D_in, H, device=device, requires_grad=True) 297 | w2 = torch.randn(H, D_out, device=device, requires_grad=True) 298 | 299 | learning_rate = 1e-6 300 | for t in range(500): 301 | # Forward pass: compute predicted y using operations on Tensors; we call our 302 | # custom ReLU implementation using the MyReLU.apply function 303 | y_pred = MyReLU.apply(x.mm(w1)).mm(w2) 304 | 305 | # Compute and print loss 306 | loss = (y_pred - y).pow(2).sum() 307 | print(t, loss.item()) 308 | 309 | # Use autograd to compute the backward pass. 310 | loss.backward() 311 | 312 | with torch.no_grad(): 313 | # Update weights using gradient descent 314 | w1 -= learning_rate * w1.grad 315 | w2 -= learning_rate * w2.grad 316 | 317 | # Manually zero the gradients after running the backward pass 318 | w1.grad.zero_() 319 | w2.grad.zero_() 320 | 321 | ``` 322 | 323 | ## TensorFlow: Static Graphs 324 | PyTorch autograd looks a lot like TensorFlow: in both frameworks we define 325 | a computational graph, and use automatic differentiation to compute gradients. 326 | The biggest difference between the two is that TensorFlow's computational graphs 327 | are **static** and PyTorch uses **dynamic** computational graphs. 328 | 329 | In TensorFlow, we define the computational graph once and then execute the same 330 | graph over and over again, possibly feeding different input data to the graph. 331 | In PyTorch, each forward pass defines a new computational graph. 
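For instance (an added sketch, not part of the repository; the shapes and the stopping threshold are arbitrary), ordinary data-dependent Python control flow is enough to change the recorded graph from one forward pass to the next:

```python
import torch

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(3)

# The number of matrix multiplies recorded in the graph depends on the data:
# each forward pass traces whatever code actually ran.
h = w.mv(x).clamp(min=0)
steps = 1
while h.norm() < 10.0 and steps < 20:
    h = w.mv(h).clamp(min=0)
    steps += 1

# Backpropagation walks back through however many steps were actually taken.
h.sum().backward()
print(steps, w.grad.shape)
```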
332 | 333 | Static graphs are nice because you can optimize the graph up front; for example 334 | a framework might decide to fuse some graph operations for efficiency, or to 335 | come up with a strategy for distributing the graph across many GPUs or many 336 | machines. If you are reusing the same graph over and over, then this potentially 337 | costly up-front optimization can be amortized as the same graph is rerun over 338 | and over. 339 | 340 | One aspect where static and dynamic graphs differ is control flow. For some models 341 | we may wish to perform different computation for each data point; for example a 342 | recurrent network might be unrolled for different numbers of time steps for each 343 | data point; this unrolling can be implemented as a loop. With a static graph the 344 | loop construct needs to be a part of the graph; for this reason TensorFlow 345 | provides operators such as `tf.scan` for embedding loops into the graph. With 346 | dynamic graphs the situation is simpler: since we build graphs on-the-fly for 347 | each example, we can use normal imperative flow control to perform computation 348 | that differs for each input. 349 | 350 | To contrast with the PyTorch autograd example above, here we use TensorFlow to 351 | fit a simple two-layer net: 352 | 353 | ```python 354 | # Code in file autograd/tf_two_layer_net.py 355 | import tensorflow as tf 356 | import numpy as np 357 | 358 | # First we set up the computational graph: 359 | 360 | # N is batch size; D_in is input dimension; 361 | # H is hidden dimension; D_out is output dimension. 362 | N, D_in, H, D_out = 64, 1000, 100, 10 363 | 364 | # Create placeholders for the input and target data; these will be filled 365 | # with real data when we execute the graph. 366 | x = tf.placeholder(tf.float32, shape=(None, D_in)) 367 | y = tf.placeholder(tf.float32, shape=(None, D_out)) 368 | 369 | # Create Variables for the weights and initialize them with random data. 370 | # A TensorFlow Variable persists its value across executions of the graph. 371 | w1 = tf.Variable(tf.random_normal((D_in, H))) 372 | w2 = tf.Variable(tf.random_normal((H, D_out))) 373 | 374 | # Forward pass: Compute the predicted y using operations on TensorFlow Tensors. 375 | # Note that this code does not actually perform any numeric operations; it 376 | # merely sets up the computational graph that we will later execute. 377 | h = tf.matmul(x, w1) 378 | h_relu = tf.maximum(h, tf.zeros(1)) 379 | y_pred = tf.matmul(h_relu, w2) 380 | 381 | # Compute loss using operations on TensorFlow Tensors 382 | loss = tf.reduce_sum((y - y_pred) ** 2.0) 383 | 384 | # Compute gradient of the loss with respect to w1 and w2. 385 | grad_w1, grad_w2 = tf.gradients(loss, [w1, w2]) 386 | 387 | # Update the weights using gradient descent. To actually update the weights 388 | # we need to evaluate new_w1 and new_w2 when executing the graph. Note that 389 | # in TensorFlow the act of updating the value of the weights is part of 390 | # the computational graph; in PyTorch this happens outside the computational 391 | # graph. 392 | learning_rate = 1e-6 393 | new_w1 = w1.assign(w1 - learning_rate * grad_w1) 394 | new_w2 = w2.assign(w2 - learning_rate * grad_w2) 395 | 396 | # Now we have built our computational graph, so we enter a TensorFlow session to 397 | # actually execute the graph. 398 | with tf.Session() as sess: 399 | # Run the graph once to initialize the Variables w1 and w2.
400 | sess.run(tf.global_variables_initializer()) 401 | 402 | # Create numpy arrays holding the actual data for the inputs x and targets y 403 | x_value = np.random.randn(N, D_in) 404 | y_value = np.random.randn(N, D_out) 405 | for _ in range(500): 406 | # Execute the graph many times. Each time it executes we want to bind 407 | # x_value to x and y_value to y, specified with the feed_dict argument. 408 | # Each time we execute the graph we want to compute the values for loss, 409 | # new_w1, and new_w2; the values of these Tensors are returned as numpy 410 | # arrays. 411 | loss_value, _, _ = sess.run([loss, new_w1, new_w2], 412 | feed_dict={x: x_value, y: y_value}) 413 | print(loss_value) 414 | ``` 415 | 416 | 417 | ## PyTorch: nn 418 | Computational graphs and autograd are a very powerful paradigm for defining 419 | complex operators and automatically taking derivatives; however for large 420 | neural networks raw autograd can be a bit too low-level. 421 | 422 | When building neural networks we frequently think of arranging the computation 423 | into **layers**, some of which have **learnable parameters** which will be 424 | optimized during learning. 425 | 426 | In TensorFlow, packages like [Keras](https://github.com/fchollet/keras), 427 | [TensorFlow-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim), 428 | and [TFLearn](http://tflearn.org/) provide higher-level abstractions over 429 | raw computational graphs that are useful for building neural networks. 430 | 431 | In PyTorch, the `nn` package serves this same purpose. The `nn` package defines a set of 432 | **Modules**, which are roughly equivalent to neural network layers. A Module receives 433 | input Tensors and computes output Tensors, but may also hold internal state such as 434 | Tensors containing learnable parameters. The `nn` package also defines a set of useful 435 | loss functions that are commonly used when training neural networks. 436 | 437 | In this example we use the `nn` package to implement our two-layer network: 438 | 439 | ```python 440 | # Code in file nn/two_layer_net_nn.py 441 | import torch 442 | 443 | device = torch.device('cpu') 444 | # device = torch.device('cuda') # Uncomment this to run on GPU 445 | 446 | # N is batch size; D_in is input dimension; 447 | # H is hidden dimension; D_out is output dimension. 448 | N, D_in, H, D_out = 64, 1000, 100, 10 449 | 450 | # Create random Tensors to hold inputs and outputs 451 | x = torch.randn(N, D_in, device=device) 452 | y = torch.randn(N, D_out, device=device) 453 | 454 | # Use the nn package to define our model as a sequence of layers. nn.Sequential 455 | # is a Module which contains other Modules, and applies them in sequence to 456 | # produce its output. Each Linear Module computes output from input using a 457 | # linear function, and holds internal Tensors for its weight and bias. 458 | # After constructing the model we use the .to() method to move it to the 459 | # desired device. 460 | model = torch.nn.Sequential( 461 | torch.nn.Linear(D_in, H), 462 | torch.nn.ReLU(), 463 | torch.nn.Linear(H, D_out), 464 | ).to(device) 465 | 466 | # The nn package also contains definitions of popular loss functions; in this 467 | # case we will use Mean Squared Error (MSE) as our loss function. 
Setting 468 | # reduction='sum' means that we are computing the *sum* of squared errors rather 469 | # than the mean; this is for consistency with the examples above where we 470 | # manually compute the loss, but in practice it is more common to use mean 471 | # squared error as a loss by setting reduction='elementwise_mean'. 472 | loss_fn = torch.nn.MSELoss(reduction='sum') 473 | 474 | learning_rate = 1e-4 475 | for t in range(500): 476 | # Forward pass: compute predicted y by passing x to the model. Module objects 477 | # override the __call__ operator so you can call them like functions. When 478 | # doing so you pass a Tensor of input data to the Module and it produces 479 | # a Tensor of output data. 480 | y_pred = model(x) 481 | 482 | # Compute and print loss. We pass Tensors containing the predicted and true 483 | # values of y, and the loss function returns a Tensor containing the loss. 484 | loss = loss_fn(y_pred, y) 485 | print(t, loss.item()) 486 | 487 | # Zero the gradients before running the backward pass. 488 | model.zero_grad() 489 | 490 | # Backward pass: compute gradient of the loss with respect to all the learnable 491 | # parameters of the model. Internally, the parameters of each Module are stored 492 | # in Tensors with requires_grad=True, so this call will compute gradients for 493 | # all learnable parameters in the model. 494 | loss.backward() 495 | 496 | # Update the weights using gradient descent. Each parameter is a Tensor, so 497 | # we can access its data and gradients like we did before. 498 | with torch.no_grad(): 499 | for param in model.parameters(): 500 | param.data -= learning_rate * param.grad 501 | ``` 502 | 503 | 504 | ## PyTorch: optim 505 | Up to this point we have updated the weights of our models by manually mutating 506 | Tensors holding learnable parameters. This is not a huge burden 507 | for simple optimization algorithms like stochastic gradient descent, but in practice 508 | we often train neural networks using more sophisticated optimizers like AdaGrad, 509 | RMSProp, Adam, etc. 510 | 511 | The `optim` package in PyTorch abstracts the idea of an optimization algorithm and 512 | provides implementations of commonly used optimization algorithms. 513 | 514 | In this example we will use the `nn` package to define our model as before, but we 515 | will optimize the model using the Adam algorithm provided by the `optim` package: 516 | 517 | ```python 518 | # Code in file nn/two_layer_net_optim.py 519 | import torch 520 | 521 | # N is batch size; D_in is input dimension; 522 | # H is hidden dimension; D_out is output dimension. 523 | N, D_in, H, D_out = 64, 1000, 100, 10 524 | 525 | # Create random Tensors to hold inputs and outputs. 526 | x = torch.randn(N, D_in) 527 | y = torch.randn(N, D_out) 528 | 529 | # Use the nn package to define our model and loss function. 530 | model = torch.nn.Sequential( 531 | torch.nn.Linear(D_in, H), 532 | torch.nn.ReLU(), 533 | torch.nn.Linear(H, D_out), 534 | ) 535 | loss_fn = torch.nn.MSELoss(reduction='sum') 536 | 537 | # Use the optim package to define an Optimizer that will update the weights of 538 | # the model for us. Here we will use Adam; the optim package contains many other 539 | # optimization algorithms. The first argument to the Adam constructor tells the 540 | # optimizer which Tensors it should update. 541 | learning_rate = 1e-4 542 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 543 | for t in range(500): 544 | # Forward pass: compute predicted y by passing x to the model.
545 | y_pred = model(x) 546 | 547 | # Compute and print loss. 548 | loss = loss_fn(y_pred, y) 549 | print(t, loss.item()) 550 | 551 | # Before the backward pass, use the optimizer object to zero all of the 552 | # gradients for the Tensors it will update (which are the learnable weights 553 | # of the model) 554 | optimizer.zero_grad() 555 | 556 | # Backward pass: compute gradient of the loss with respect to model parameters 557 | loss.backward() 558 | 559 | # Calling the step function on an Optimizer makes an update to its parameters 560 | optimizer.step() 561 | ``` 562 | 563 | 564 | ## PyTorch: Custom nn Modules 565 | Sometimes you will want to specify models that are more complex than a sequence of 566 | existing Modules; for these cases you can define your own Modules by subclassing 567 | `nn.Module` and defining a `forward` which receives input Tensors and produces 568 | output Tensors using other modules or other autograd operations on Tensors. 569 | 570 | In this example we implement our two-layer network as a custom Module subclass: 571 | 572 | ```python 573 | # Code in file nn/two_layer_net_module.py 574 | import torch 575 | 576 | class TwoLayerNet(torch.nn.Module): 577 | def __init__(self, D_in, H, D_out): 578 | """ 579 | In the constructor we instantiate two nn.Linear modules and assign them as 580 | member variables. 581 | """ 582 | super(TwoLayerNet, self).__init__() 583 | self.linear1 = torch.nn.Linear(D_in, H) 584 | self.linear2 = torch.nn.Linear(H, D_out) 585 | 586 | def forward(self, x): 587 | """ 588 | In the forward function we accept a Tensor of input data and we must return 589 | a Tensor of output data. We can use Modules defined in the constructor as 590 | well as arbitrary (differentiable) operations on Tensors. 591 | """ 592 | h_relu = self.linear1(x).clamp(min=0) 593 | y_pred = self.linear2(h_relu) 594 | return y_pred 595 | 596 | # N is batch size; D_in is input dimension; 597 | # H is hidden dimension; D_out is output dimension. 598 | N, D_in, H, D_out = 64, 1000, 100, 10 599 | 600 | # Create random Tensors to hold inputs and outputs 601 | x = torch.randn(N, D_in) 602 | y = torch.randn(N, D_out) 603 | 604 | # Construct our model by instantiating the class defined above. 605 | model = TwoLayerNet(D_in, H, D_out) 606 | 607 | # Construct our loss function and an Optimizer. The call to model.parameters() 608 | # in the SGD constructor will contain the learnable parameters of the two 609 | # nn.Linear modules which are members of the model. 610 | loss_fn = torch.nn.MSELoss(reduction='sum') 611 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) 612 | for t in range(500): 613 | # Forward pass: Compute predicted y by passing x to the model 614 | y_pred = model(x) 615 | 616 | # Compute and print loss 617 | loss = loss_fn(y_pred, y) 618 | print(t, loss.item()) 619 | 620 | # Zero gradients, perform a backward pass, and update the weights. 621 | optimizer.zero_grad() 622 | loss.backward() 623 | optimizer.step() 624 | 625 | ``` 626 | 627 | 628 | ## PyTorch: Control Flow + Weight Sharing 629 | As an example of dynamic graphs and weight sharing, we implement a very strange 630 | model: a fully-connected ReLU network that on each forward pass chooses a random 631 | number between 1 and 4 and uses that many hidden layers, reusing the same weights 632 | multiple times to compute the innermost hidden layers. 
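Reusing a Module is all it takes to share weights: every application uses the same parameter Tensors, and during the backward pass their gradients accumulate contributions from each use. A tiny check (an added illustration, not one of the repository's files):

```python
import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

# Applying the same Linear layer twice reuses one weight matrix and one bias.
y = layer(layer(x))
y.sum().backward()

# A single gradient Tensor sums the contributions from both applications.
print(layer.weight.grad.shape)  # torch.Size([4, 4])
```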
633 | 634 | For this model we can use normal Python flow control to implement the loop, and we 635 | can implement weight sharing among the innermost layers by simply reusing the 636 | same Module multiple times when defining the forward pass. 637 | 638 | We can easily implement this model as a Module subclass: 639 | 640 | ```python 641 | # Code in file nn/dynamic_net.py 642 | import random 643 | import torch 644 | 645 | class DynamicNet(torch.nn.Module): 646 | def __init__(self, D_in, H, D_out): 647 | """ 648 | In the constructor we construct three nn.Linear instances that we will use 649 | in the forward pass. 650 | """ 651 | super(DynamicNet, self).__init__() 652 | self.input_linear = torch.nn.Linear(D_in, H) 653 | self.middle_linear = torch.nn.Linear(H, H) 654 | self.output_linear = torch.nn.Linear(H, D_out) 655 | 656 | def forward(self, x): 657 | """ 658 | For the forward pass of the model, we randomly choose either 0, 1, 2, or 3 659 | and reuse the middle_linear Module that many times to compute hidden layer 660 | representations. 661 | 662 | Since each forward pass builds a dynamic computation graph, we can use normal 663 | Python control-flow operators like loops or conditional statements when 664 | defining the forward pass of the model. 665 | 666 | Here we also see that it is perfectly safe to reuse the same Module many 667 | times when defining a computational graph. This is a big improvement from Lua 668 | Torch, where each Module could be used only once. 669 | """ 670 | h_relu = self.input_linear(x).clamp(min=0) 671 | for _ in range(random.randint(0, 3)): 672 | h_relu = self.middle_linear(h_relu).clamp(min=0) 673 | y_pred = self.output_linear(h_relu) 674 | return y_pred 675 | 676 | 677 | # N is batch size; D_in is input dimension; 678 | # H is hidden dimension; D_out is output dimension. 679 | N, D_in, H, D_out = 64, 1000, 100, 10 680 | 681 | # Create random Tensors to hold inputs and outputs. 682 | x = torch.randn(N, D_in) 683 | y = torch.randn(N, D_out) 684 | 685 | # Construct our model by instantiating the class defined above 686 | model = DynamicNet(D_in, H, D_out) 687 | 688 | # Construct our loss function and an Optimizer. Training this strange model with 689 | # vanilla stochastic gradient descent is tough, so we use momentum 690 | criterion = torch.nn.MSELoss(reduction='sum') 691 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) 692 | for t in range(500): 693 | # Forward pass: Compute predicted y by passing x to the model 694 | y_pred = model(x) 695 | 696 | # Compute and print loss 697 | loss = criterion(y_pred, y) 698 | print(t, loss.item()) 699 | 700 | # Zero gradients, perform a backward pass, and update the weights. 701 | optimizer.zero_grad() 702 | loss.backward() 703 | optimizer.step() 704 | ``` 705 | -------------------------------------------------------------------------------- /README_raw.md: -------------------------------------------------------------------------------- 1 | This repository introduces the fundamental concepts of 2 | [PyTorch](https://github.com/pytorch/pytorch) 3 | through self-contained examples. 4 | 5 | At its core, PyTorch provides two main features: 6 | - An n-dimensional Tensor, similar to numpy but can run on GPUs 7 | - Automatic differentiation for building and training neural networks 8 | 9 | We will use a fully-connected ReLU network as our running example.
The network 10 | will have a single hidden layer, and will be trained with gradient descent to 11 | fit random data by minimizing the Euclidean distance between the network output 12 | and the true output. 13 | 14 | **NOTE:** These examples have been update for PyTorch 0.4, which made several 15 | major changes to the core PyTorch API. Most notably, prior to 0.4 Tensors had 16 | to be wrapped in Variable objects to use autograd; this functionality has now 17 | been added directly to Tensors, and Variables are now deprecated. 18 | 19 | ### Table of Contents 20 | - Warm-up: numpy 21 | - PyTorch: Tensors 22 | - PyTorch: Autograd 23 | - PyTorch: Defining new autograd functions 24 | - TensorFlow: Static Graphs 25 | - PyTorch: nn 26 | - PyTorch: optim 27 | - PyTorch: Custom nn Modules 28 | - PyTorch: Control Flow and Weight Sharing 29 | 30 | ## Warm-up: numpy 31 | 32 | Before introducing PyTorch, we will first implement the network using numpy. 33 | 34 | Numpy provides an n-dimensional array object, and many functions for manipulating 35 | these arrays. Numpy is a generic framework for scientific computing; it does not 36 | know anything about computation graphs, or deep learning, or gradients. However 37 | we can easily use numpy to fit a two-layer network to random data by manually 38 | implementing the forward and backward passes through the network using numpy 39 | operations: 40 | 41 | ```python 42 | :INCLUDE tensor/two_layer_net_numpy.py 43 | ``` 44 | 45 | ## PyTorch: Tensors 46 | 47 | Numpy is a great framework, but it cannot utilize GPUs to accelerate its 48 | numerical computations. For modern deep neural networks, GPUs often provide 49 | speedups of [50x or greater](https://github.com/jcjohnson/cnn-benchmarks), so 50 | unfortunately numpy won't be enough for modern deep learning. 51 | 52 | Here we introduce the most fundamental PyTorch concept: the **Tensor**. A PyTorch 53 | Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional 54 | array, and PyTorch provides many functions for operating on these Tensors. 55 | Any computation you might want to perform with numpy can also be accomplished 56 | with PyTorch Tensors; you should think of them as a generic tool for scientific 57 | computing. 58 | 59 | However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their 60 | numeric computations. To run a PyTorch Tensor on GPU, you use the `device` 61 | argument when constructing a Tensor to place the Tensor on a GPU. 62 | 63 | Here we use PyTorch Tensors to fit a two-layer network to random data. Like the 64 | numpy example above we manually implement the forward and backward 65 | passes through the network, using operations on PyTorch Tensors: 66 | 67 | ```python 68 | :INCLUDE tensor/two_layer_net_tensor.py 69 | ``` 70 | 71 | ## PyTorch: Autograd 72 | 73 | In the above examples, we had to manually implement both the forward and 74 | backward passes of our neural network. Manually implementing the backward pass 75 | is not a big deal for a small two-layer network, but can quickly get very hairy 76 | for large complex networks. 77 | 78 | Thankfully, we can use 79 | [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) 80 | to automate the computation of backward passes in neural networks. 81 | The **autograd** package in PyTorch provides exactly this functionality. 
82 | When using autograd, the forward pass of your network will define a 83 | **computational graph**; nodes in the graph will be Tensors, and edges will be 84 | functions that produce output Tensors from input Tensors. Backpropagating through 85 | this graph then allows you to easily compute gradients. 86 | 87 | This sounds complicated, it's pretty simple to use in practice. If we want to 88 | compute gradients with respect to some Tensor, then we set `requires_grad=True` 89 | when constructing that Tensor. Any PyTorch operations on that Tensor will cause 90 | a computational graph to be constructed, allowing us to later perform backpropagation 91 | through the graph. If `x` is a Tensor with `requires_grad=True`, then after 92 | backpropagation `x.grad` will be another Tensor holding the gradient of `x` with 93 | respect to some scalar value. 94 | 95 | Sometimes you may wish to prevent PyTorch from building computational graphs when 96 | performing certain operations on Tensors with `requires_grad=True`; for example 97 | we usually don't want to backpropagate through the weight update steps when 98 | training a neural network. In such scenarios we can use the `torch.no_grad()` 99 | context manager to prevent the construction of a computational graph. 100 | 101 | Here we use PyTorch Tensors and autograd to implement our two-layer network; 102 | now we no longer need to manually implement the backward pass through the 103 | network: 104 | 105 | ```python 106 | :INCLUDE autograd/two_layer_net_autograd.py 107 | ``` 108 | 109 | ## PyTorch: Defining new autograd functions 110 | Under the hood, each primitive autograd operator is really two functions that 111 | operate on Tensors. The **forward** function computes output Tensors from input 112 | Tensors. The **backward** function receives the gradient of the output Tensors 113 | with respect to some scalar value, and computes the gradient of the input Tensors 114 | with respect to that same scalar value. 115 | 116 | In PyTorch we can easily define our own autograd operator by defining a subclass 117 | of `torch.autograd.Function` and implementing the `forward` and `backward` functions. 118 | We can then use our new autograd operator by constructing an instance and calling it 119 | like a function, passing Tensors containing input data. 120 | 121 | In this example we define our own custom autograd function for performing the ReLU 122 | nonlinearity, and use it to implement our two-layer network: 123 | 124 | ```python 125 | :INCLUDE autograd/two_layer_net_custom_function.py 126 | ``` 127 | 128 | ## TensorFlow: Static Graphs 129 | PyTorch autograd looks a lot like TensorFlow: in both frameworks we define 130 | a computational graph, and use automatic differentiation to compute gradients. 131 | The biggest difference between the two is that TensorFlow's computational graphs 132 | are **static** and PyTorch uses **dynamic** computational graphs. 133 | 134 | In TensorFlow, we define the computational graph once and then execute the same 135 | graph over and over again, possibly feeding different input data to the graph. 136 | In PyTorch, each forward pass defines a new computational graph. 137 | 138 | Static graphs are nice because you can optimize the graph up front; for example 139 | a framework might decide to fuse some graph operations for efficiency, or to 140 | come up with a strategy for distributing the graph across many GPUs or many 141 | machines. 
If you are reusing the same graph over and over, then this potentially 142 | costly up-front optimization can be amortized as the same graph is rerun over 143 | and over. 144 | 145 | One aspect where static and dynamic graphs differ is control flow. For some models 146 | we may wish to perform different computation for each data point; for example a 147 | recurrent network might be unrolled for different numbers of time steps for each 148 | data point; this unrolling can be implemented as a loop. With a static graph the 149 | loop construct needs to be a part of the graph; for this reason TensorFlow 150 | provides operators such as `tf.scan` for embedding loops into the graph. With 151 | dynamic graphs the situation is simpler: since we build graphs on-the-fly for 152 | each example, we can use normal imperative flow control to perform computation 153 | that differs for each input. 154 | 155 | To contrast with the PyTorch autograd example above, here we use TensorFlow to 156 | fit a simple two-layer net: 157 | 158 | ```python 159 | :INCLUDE autograd/tf_two_layer_net.py 160 | ``` 161 | 162 | 163 | ## PyTorch: nn 164 | Computational graphs and autograd are a very powerful paradigm for defining 165 | complex operators and automatically taking derivatives; however for large 166 | neural networks raw autograd can be a bit too low-level. 167 | 168 | When building neural networks we frequently think of arranging the computation 169 | into **layers**, some of which have **learnable parameters** which will be 170 | optimized during learning. 171 | 172 | In TensorFlow, packages like [Keras](https://github.com/fchollet/keras), 173 | [TensorFlow-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim), 174 | and [TFLearn](http://tflearn.org/) provide higher-level abstractions over 175 | raw computational graphs that are useful for building neural networks. 176 | 177 | In PyTorch, the `nn` package serves this same purpose. The `nn` package defines a set of 178 | **Modules**, which are roughly equivalent to neural network layers. A Module receives 179 | input Tensors and computes output Tensors, but may also hold internal state such as 180 | Tensors containing learnable parameters. The `nn` package also defines a set of useful 181 | loss functions that are commonly used when training neural networks. 182 | 183 | In this example we use the `nn` package to implement our two-layer network: 184 | 185 | ```python 186 | :INCLUDE nn/two_layer_net_nn.py 187 | ``` 188 | 189 | 190 | ## PyTorch: optim 191 | Up to this point we have updated the weights of our models by manually mutating 192 | Tensors holding learnable parameters. This is not a huge burden 193 | for simple optimization algorithms like stochastic gradient descent, but in practice 194 | we often train neural networks using more sophisiticated optimizers like AdaGrad, 195 | RMSProp, Adam, etc. 196 | 197 | The `optim` package in PyTorch abstracts the idea of an optimization algorithm and 198 | provides implementations of commonly used optimization algorithms. 
199 | 200 | In this example we will use the `nn` package to define our model as before, but we 201 | will optimize the model using the Adam algorithm provided by the `optim` package: 202 | 203 | ```python 204 | :INCLUDE nn/two_layer_net_optim.py 205 | ``` 206 | 207 | 208 | ## PyTorch: Custom nn Modules 209 | Sometimes you will want to specify models that are more complex than a sequence of 210 | existing Modules; for these cases you can define your own Modules by subclassing 211 | `nn.Module` and defining a `forward` which receives input Tensors and produces 212 | output Tensors using other modules or other autograd operations on Tensors. 213 | 214 | In this example we implement our two-layer network as a custom Module subclass: 215 | 216 | ```python 217 | :INCLUDE nn/two_layer_net_module.py 218 | ``` 219 | 220 | 221 | ## PyTorch: Control Flow + Weight Sharing 222 | As an example of dynamic graphs and weight sharing, we implement a very strange 223 | model: a fully-connected ReLU network that on each forward pass chooses a random 224 | number between 1 and 4 and uses that many hidden layers, reusing the same weights 225 | multiple times to compute the innermost hidden layers. 226 | 227 | For this model can use normal Python flow control to implement the loop, and we 228 | can implement weight sharing among the innermost layers by simply reusing the 229 | same Module multiple times when defining the forward pass. 230 | 231 | We can easily implement this model as a Module subclass: 232 | 233 | ```python 234 | :INCLUDE nn/dynamic_net.py 235 | ``` 236 | -------------------------------------------------------------------------------- /autograd/tf_two_layer_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | """ 5 | A fully-connected ReLU network with one hidden layer and no biases, trained to 6 | predict y from x by minimizing squared Euclidean distance. 7 | 8 | This implementation uses basic TensorFlow operations to set up a computational 9 | graph, then executes the graph many times to actually train the network. 10 | 11 | One of the main differences between TensorFlow and PyTorch is that TensorFlow 12 | uses static computational graphs while PyTorch uses dynamic computational 13 | graphs. 14 | 15 | In TensorFlow we first set up the computational graph, then execute the same 16 | graph many times. 17 | """ 18 | 19 | # First we set up the computational graph: 20 | 21 | # N is batch size; D_in is input dimension; 22 | # H is hidden dimension; D_out is output dimension. 23 | N, D_in, H, D_out = 64, 1000, 100, 10 24 | 25 | # Create placeholders for the input and target data; these will be filled 26 | # with real data when we execute the graph. 27 | x = tf.placeholder(tf.float32, shape=(None, D_in)) 28 | y = tf.placeholder(tf.float32, shape=(None, D_out)) 29 | 30 | # Create Variables for the weights and initialize them with random data. 31 | # A TensorFlow Variable persists its value across executions of the graph. 32 | w1 = tf.Variable(tf.random_normal((D_in, H))) 33 | w2 = tf.Variable(tf.random_normal((H, D_out))) 34 | 35 | # Forward pass: Compute the predicted y using operations on TensorFlow Tensors. 36 | # Note that this code does not actually perform any numeric operations; it 37 | # merely sets up the computational graph that we will later execute. 
38 | h = tf.matmul(x, w1) 39 | h_relu = tf.maximum(h, tf.zeros(1)) 40 | y_pred = tf.matmul(h_relu, w2) 41 | 42 | # Compute loss using operations on TensorFlow Tensors 43 | loss = tf.reduce_sum((y - y_pred) ** 2.0) 44 | 45 | # Compute gradient of the loss with respect to w1 and w2. 46 | grad_w1, grad_w2 = tf.gradients(loss, [w1, w2]) 47 | 48 | # Update the weights using gradient descent. To actually update the weights 49 | # we need to evaluate new_w1 and new_w2 when executing the graph. Note that 50 | # in TensorFlow the the act of updating the value of the weights is part of 51 | # the computational graph; in PyTorch this happens outside the computational 52 | # graph. 53 | learning_rate = 1e-6 54 | new_w1 = w1.assign(w1 - learning_rate * grad_w1) 55 | new_w2 = w2.assign(w2 - learning_rate * grad_w2) 56 | 57 | # Now we have built our computational graph, so we enter a TensorFlow session to 58 | # actually execute the graph. 59 | with tf.Session() as sess: 60 | # Run the graph once to initialize the Variables w1 and w2. 61 | sess.run(tf.global_variables_initializer()) 62 | 63 | # Create numpy arrays holding the actual data for the inputs x and targets y 64 | x_value = np.random.randn(N, D_in) 65 | y_value = np.random.randn(N, D_out) 66 | for _ in range(500): 67 | # Execute the graph many times. Each time it executes we want to bind 68 | # x_value to x and y_value to y, specified with the feed_dict argument. 69 | # Each time we execute the graph we want to compute the values for loss, 70 | # new_w1, and new_w2; the values of these Tensors are returned as numpy 71 | # arrays. 72 | loss_value, _, _ = sess.run([loss, new_w1, new_w2], 73 | feed_dict={x: x_value, y: y_value}) 74 | print(loss_value) 75 | -------------------------------------------------------------------------------- /autograd/two_layer_net_autograd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ 4 | A fully-connected ReLU network with one hidden layer and no biases, trained to 5 | predict y from x by minimizing squared Euclidean distance. 6 | 7 | This implementation computes the forward pass using operations on PyTorch 8 | Tensors, and uses PyTorch autograd to compute gradients. 9 | 10 | When we create a PyTorch Tensor with requires_grad=True, then operations 11 | involving that Tensor will not just compute values; they will also build up 12 | a computational graph in the background, allowing us to easily backpropagate 13 | through the graph to compute gradients of some downstream (scalar) loss with 14 | respect to a Tensor. Concretely if x is a Tensor with x.requires_grad == True 15 | then after backpropagation x.grad will be another Tensor holding the gradient 16 | of x with respect to some scalar value. 17 | """ 18 | 19 | device = torch.device('cpu') 20 | # device = torch.device('cuda') # Uncomment this to run on GPU 21 | 22 | # N is batch size; D_in is input dimension; 23 | # H is hidden dimension; D_out is output dimension. 24 | N, D_in, H, D_out = 64, 1000, 100, 10 25 | 26 | # Create random Tensors to hold input and outputs 27 | x = torch.randn(N, D_in, device=device) 28 | y = torch.randn(N, D_out, device=device) 29 | 30 | # Create random Tensors for weights; setting requires_grad=True means that we 31 | # want to compute gradients for these Tensors during the backward pass. 
32 | w1 = torch.randn(D_in, H, device=device, requires_grad=True) 33 | w2 = torch.randn(H, D_out, device=device, requires_grad=True) 34 | 35 | learning_rate = 1e-6 36 | for t in range(500): 37 | # Forward pass: compute predicted y using operations on Tensors. Since w1 and 38 | # w2 have requires_grad=True, operations involving these Tensors will cause 39 | # PyTorch to build a computational graph, allowing automatic computation of 40 | # gradients. Since we are no longer implementing the backward pass by hand we 41 | # don't need to keep references to intermediate values. 42 | y_pred = x.mm(w1).clamp(min=0).mm(w2) 43 | 44 | # Compute and print loss. Loss is a Tensor of shape (), and loss.item() 45 | # is a Python number giving its value. 46 | loss = (y_pred - y).pow(2).sum() 47 | print(t, loss.item()) 48 | 49 | # Use autograd to compute the backward pass. This call will compute the 50 | # gradient of loss with respect to all Tensors with requires_grad=True. 51 | # After this call w1.grad and w2.grad will be Tensors holding the gradient 52 | # of the loss with respect to w1 and w2 respectively. 53 | loss.backward() 54 | 55 | # Update weights using gradient descent. For this step we just want to mutate 56 | # the values of w1 and w2 in-place; we don't want to build up a computational 57 | # graph for the update steps, so we use the torch.no_grad() context manager 58 | # to prevent PyTorch from building a computational graph for the updates 59 | with torch.no_grad(): 60 | w1 -= learning_rate * w1.grad 61 | w2 -= learning_rate * w2.grad 62 | 63 | # Manually zero the gradients after running the backward pass 64 | w1.grad.zero_() 65 | w2.grad.zero_() 66 | -------------------------------------------------------------------------------- /autograd/two_layer_net_custom_function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ 4 | A fully-connected ReLU network with one hidden layer and no biases, trained to 5 | predict y from x by minimizing squared Euclidean distance. 6 | 7 | This implementation computes the forward pass using operations on PyTorch 8 | Tensors, and uses PyTorch autograd to compute gradients. 9 | 10 | In this implementation we implement our own custom autograd function to perform 11 | the ReLU function. 12 | """ 13 | 14 | class MyReLU(torch.autograd.Function): 15 | """ 16 | We can implement our own custom autograd Functions by subclassing 17 | torch.autograd.Function and implementing the forward and backward passes 18 | which operate on Tensors. 19 | """ 20 | @staticmethod 21 | def forward(ctx, x): 22 | """ 23 | In the forward pass we receive a context object and a Tensor containing the 24 | input; we must return a Tensor containing the output, and we can use the 25 | context object to cache objects for use in the backward pass. 26 | """ 27 | ctx.save_for_backward(x) 28 | return x.clamp(min=0) 29 | 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | """ 33 | In the backward pass we receive the context object and a Tensor containing 34 | the gradient of the loss with respect to the output produced during the 35 | forward pass. We can retrieve cached data from the context object, and must 36 | compute and return the gradient of the loss with respect to the input to the 37 | forward function. 
38 | """ 39 | x, = ctx.saved_tensors 40 | grad_x = grad_output.clone() 41 | grad_x[x < 0] = 0 42 | return grad_x 43 | 44 | 45 | device = torch.device('cpu') 46 | # device = torch.device('cuda') # Uncomment this to run on GPU 47 | 48 | # N is batch size; D_in is input dimension; 49 | # H is hidden dimension; D_out is output dimension. 50 | N, D_in, H, D_out = 64, 1000, 100, 10 51 | 52 | # Create random Tensors to hold input and output 53 | x = torch.randn(N, D_in, device=device) 54 | y = torch.randn(N, D_out, device=device) 55 | 56 | # Create random Tensors for weights. 57 | w1 = torch.randn(D_in, H, device=device, requires_grad=True) 58 | w2 = torch.randn(H, D_out, device=device, requires_grad=True) 59 | 60 | learning_rate = 1e-6 61 | for t in range(500): 62 | # Forward pass: compute predicted y using operations on Tensors; we call our 63 | # custom ReLU implementation using the MyReLU.apply function 64 | y_pred = MyReLU.apply(x.mm(w1)).mm(w2) 65 | 66 | # Compute and print loss 67 | loss = (y_pred - y).pow(2).sum() 68 | print(t, loss.item()) 69 | 70 | # Use autograd to compute the backward pass. 71 | loss.backward() 72 | 73 | with torch.no_grad(): 74 | # Update weights using gradient descent 75 | w1 -= learning_rate * w1.grad 76 | w2 -= learning_rate * w2.grad 77 | 78 | # Manually zero the gradients after running the backward pass 79 | w1.grad.zero_() 80 | w2.grad.zero_() 81 | 82 | -------------------------------------------------------------------------------- /build_readme.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | """ 4 | GitHub doesn't provide an include mechanism for README files so we have to 5 | implement our own. 6 | """ 7 | 8 | def main(): 9 | build_readme('README_raw.md', 'README.md') 10 | for d in os.listdir('.'): 11 | if not os.path.isdir(d) or d.startswith('.'): 12 | continue 13 | in_path = os.path.join(d, 'README_raw.md') 14 | out_path = os.path.join(d, 'README.md') 15 | build_readme(in_path, out_path) 16 | 17 | 18 | def build_readme(in_path, out_path): 19 | if not os.path.isfile(in_path): 20 | return 21 | with open(in_path, 'r') as fin, open(out_path, 'w') as fout: 22 | for line in fin: 23 | if not line.startswith(':INCLUDE'): 24 | fout.write('%s' % line) 25 | else: 26 | include_path = line.split(' ')[1].strip() 27 | include_path = os.path.join(os.path.split(in_path)[0], include_path) 28 | fout.write('# Code in file %s\n' % include_path) 29 | skip_toggle = False 30 | skip_next = False 31 | with open(include_path, 'r') as finc: 32 | for ll in finc: 33 | if ll.startswith('"""'): 34 | skip_next = skip_toggle 35 | skip_toggle = not skip_toggle 36 | elif not skip_toggle and not skip_next: 37 | fout.write(ll) 38 | elif skip_next: 39 | skip_next = False 40 | 41 | if __name__ == '__main__': 42 | main() 43 | 44 | -------------------------------------------------------------------------------- /nn/dynamic_net.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | """ 5 | To showcase the power of PyTorch dynamic graphs, we will implement a very strange 6 | model: a fully-connected ReLU network that on each forward pass randomly chooses 7 | a number between 1 and 4 and has that many hidden layers, reusing the same 8 | weights multiple times to compute the innermost hidden layers. 
9 | """ 10 | 11 | class DynamicNet(torch.nn.Module): 12 | def __init__(self, D_in, H, D_out): 13 | """ 14 | In the constructor we construct three nn.Linear instances that we will use 15 | in the forward pass. 16 | """ 17 | super(DynamicNet, self).__init__() 18 | self.input_linear = torch.nn.Linear(D_in, H) 19 | self.middle_linear = torch.nn.Linear(H, H) 20 | self.output_linear = torch.nn.Linear(H, D_out) 21 | 22 | def forward(self, x): 23 | """ 24 | For the forward pass of the model, we randomly choose either 0, 1, 2, or 3 25 | and reuse the middle_linear Module that many times to compute hidden layer 26 | representations. 27 | 28 | Since each forward pass builds a dynamic computation graph, we can use normal 29 | Python control-flow operators like loops or conditional statements when 30 | defining the forward pass of the model. 31 | 32 | Here we also see that it is perfectly safe to reuse the same Module many 33 | times when defining a computational graph. This is a big improvement from Lua 34 | Torch, where each Module could be used only once. 35 | """ 36 | h_relu = self.input_linear(x).clamp(min=0) 37 | for _ in range(random.randint(0, 3)): 38 | h_relu = self.middle_linear(h_relu).clamp(min=0) 39 | y_pred = self.output_linear(h_relu) 40 | return y_pred 41 | 42 | 43 | # N is batch size; D_in is input dimension; 44 | # H is hidden dimension; D_out is output dimension. 45 | N, D_in, H, D_out = 64, 1000, 100, 10 46 | 47 | # Create random Tensors to hold inputs and outputs. 48 | x = torch.randn(N, D_in) 49 | y = torch.randn(N, D_out) 50 | 51 | # Construct our model by instantiating the class defined above 52 | model = DynamicNet(D_in, H, D_out) 53 | 54 | # Construct our loss function and an Optimizer. Training this strange model with 55 | # vanilla stochastic gradient descent is tough, so we use momentum 56 | criterion = torch.nn.MSELoss(reduction='sum') 57 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) 58 | for t in range(500): 59 | # Forward pass: Compute predicted y by passing x to the model 60 | y_pred = model(x) 61 | 62 | # Compute and print loss 63 | loss = criterion(y_pred, y) 64 | print(t, loss.item()) 65 | 66 | # Zero gradients, perform a backward pass, and update the weights. 67 | optimizer.zero_grad() 68 | loss.backward() 69 | optimizer.step() 70 | -------------------------------------------------------------------------------- /nn/two_layer_net_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ 4 | A fully-connected ReLU network with one hidden layer, trained to predict y from x 5 | by minimizing squared Euclidean distance. 6 | 7 | This implementation defines the model as a custom Module subclass. Whenever you 8 | want a model more complex than a simple sequence of existing Modules you will 9 | need to define your model this way. 10 | """ 11 | 12 | class TwoLayerNet(torch.nn.Module): 13 | def __init__(self, D_in, H, D_out): 14 | """ 15 | In the constructor we instantiate two nn.Linear modules and assign them as 16 | member variables. 17 | """ 18 | super(TwoLayerNet, self).__init__() 19 | self.linear1 = torch.nn.Linear(D_in, H) 20 | self.linear2 = torch.nn.Linear(H, D_out) 21 | 22 | def forward(self, x): 23 | """ 24 | In the forward function we accept a Tensor of input data and we must return 25 | a Tensor of output data. We can use Modules defined in the constructor as 26 | well as arbitrary (differentiable) operations on Tensors. 
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above.
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will include the learnable parameters of the two
# nn.Linear modules which are members of the model.
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

--------------------------------------------------------------------------------
/nn/two_layer_net_nn.py:
--------------------------------------------------------------------------------
import torch

"""
A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that produces output from its
input and may have some trainable weights or other state.
"""

device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
# After constructing the model we use the .to() method to move it to the
# desired device.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
).to(device)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function. Setting
# reduction='sum' means that we are computing the *sum* of squared errors rather
# than the mean; this is for consistency with the examples above where we
# manually compute the loss, but in practice it is more common to use mean
# squared error as a loss by setting reduction='elementwise_mean'.
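# (Note: newer PyTorch releases spell this averaging option reduction='mean'
# rather than reduction='elementwise_mean'.)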
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its data and gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param.data -= learning_rate * param.grad

--------------------------------------------------------------------------------
/nn/two_layer_net_optim.py:
--------------------------------------------------------------------------------
import torch

"""
A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.

Rather than manually updating the weights of the model as we have been doing,
we use the optim package to define an Optimizer that will update the weights
for us. The optim package defines many optimization algorithms that are commonly
used for deep learning, including SGD+momentum, RMSProp, Adam, etc.
"""

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
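    # loss_fn returns a Tensor holding a single scalar value; loss.item() below
    # extracts that value as an ordinary Python number so it can be printed.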
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the Tensors it will update (which are the learnable weights
    # of the model)
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its parameters
    optimizer.step()

--------------------------------------------------------------------------------
/tensor/two_layer_net_numpy.py:
--------------------------------------------------------------------------------
import numpy as np

"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x using Euclidean error.

This implementation uses numpy to manually compute the forward pass, loss, and
backward pass.

A numpy array is a generic n-dimensional array; it does not know anything about
deep learning or gradients or computational graphs, and is just a way to perform
generic numeric computations.
"""

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

--------------------------------------------------------------------------------
/tensor/two_layer_net_tensor.py:
--------------------------------------------------------------------------------
import torch

"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.

A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.

The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just pass a different value to the `device` argument when constructing the
Tensor.
"""

device = torch.device('cpu')
# device = torch.device('cuda') # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device)
w2 = torch.randn(H, D_out, device=device)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
    # of shape (); we can get its value as a Python number with loss.item().
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

--------------------------------------------------------------------------------
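All of the training scripts above share the same basic loop, so they combine
naturally. The sketch below is not one of the repository files; it is a minimal
variation that assumes a PyTorch release recent enough to accept
`reduction='mean'`, picks the GPU when `torch.cuda.is_available()` reports one,
and otherwise follows the `nn` + `optim` examples.

```python
# A minimal combined sketch (not part of this repository): train the same
# two-layer ReLU network with nn, optim, and an automatically chosen device.
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and targets, created directly on the chosen device.
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)

# Same two-layer ReLU network as above, moved to the chosen device.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
).to(device)

# Averaged (rather than summed) squared error, plus the Adam optimizer.
loss_fn = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for t in range(500):
    y_pred = model(x)              # forward pass
    loss = loss_fn(y_pred, y)      # scalar loss Tensor
    print(t, loss.item())

    optimizer.zero_grad()          # clear old gradients
    loss.backward()                # backward pass
    optimizer.step()               # parameter update
```

Because the loss is averaged here, the printed values are smaller than in the
summed-loss examples by a factor of N * D_out, but the gradient direction is
unchanged.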