# Effective PyTorch

Table of Contents
=================
## Part I: PyTorch Fundamentals
1. [PyTorch basics](#basics)
2. [Encapsulate your model with Modules](#modules)
3. [Broadcasting the good and the ugly](#broadcast)
4. [Take advantage of the overloaded operators](#overloaded_ops)
5. [Optimizing runtime with TorchScript](#torchscript)
6. [Building efficient custom data loaders](#dataloader)
7. [Numerical stability in PyTorch](#stable)
8. [Faster training with automatic mixed precision](#amp)
---

_To install PyTorch follow the [instructions on the official website](https://pytorch.org/):_
```bash
pip install torch torchvision
```

_We aim to gradually expand this series by adding new articles and to keep the content up to date with the latest releases of the PyTorch API. If you have suggestions on how to improve this series or find the explanations ambiguous, feel free to create an issue, send patches, or reach out by email._

# Part I: PyTorch Fundamentals


## PyTorch basics

PyTorch is one of the most popular libraries for numerical computation and is currently among the most widely used libraries for machine learning research. In many ways PyTorch is similar to NumPy, with the added benefit that PyTorch lets you run your computations on CPUs, GPUs, and TPUs without any material change to your code. PyTorch also makes it easy to distribute your computation across multiple devices or machines. One of its most important features is automatic differentiation, which computes exact gradients of your functions efficiently; this is crucial for training machine learning models with gradient descent. Our goal here is to provide a gentle introduction to PyTorch and discuss best practices for using it.

The first thing to learn about PyTorch is the concept of Tensors. Tensors are simply multidimensional arrays. A PyTorch Tensor is very similar to a NumPy array with some ~~magical~~ additional functionality.

A tensor can store a scalar value:
```python
import torch
a = torch.tensor(3)
print(a)  # tensor(3)
```

or an array:
```python
b = torch.tensor([1, 2])
print(b)  # tensor([1, 2])
```

a matrix:
```python
c = torch.zeros([2, 2])
print(c)  # tensor([[0., 0.], [0., 0.]])
```

or a tensor of any arbitrary dimension:
```python
d = torch.rand([2, 2, 2])
```

Tensors can be used to perform algebraic operations efficiently. One of the most commonly used operations in machine learning applications is matrix multiplication. Say you want to multiply two random matrices of size 3x5 and 5x4; this can be done with the matrix multiplication (`@`) operator:
```python
import torch

x = torch.randn([3, 5])
y = torch.randn([5, 4])
z = x @ y

print(z)
```

Similarly, to add two tensors of the same shape, you can do:
```python
z = x + y  # here x and y must have matching (or broadcastable) shapes
```

To convert a tensor into a NumPy array you can call the tensor's `numpy()` method:
```python
print(z.numpy())
```

And you can always convert a NumPy array into a tensor (use `torch.from_numpy` if you want to avoid a copy):
```python
import numpy as np
x = torch.tensor(np.random.normal(size=[3, 5]))
```

### Automatic differentiation

The most important advantage of PyTorch over NumPy is its automatic differentiation functionality, which is very useful in optimization applications such as optimizing the parameters of a neural network. Let's try to understand it with an example.

Say you have a composite function which is a chain of two functions: `g(u(x))`.
To compute the derivative of `g` with respect to `x` we can use the chain rule, which states that `dg/dx = dg/du * du/dx`. PyTorch can compute these derivatives for us automatically.

To compute the derivatives in PyTorch, first we create a tensor and set its `requires_grad` attribute to `True`. We can use tensor operations to define our functions. We assume `u` is a quadratic function and `g` is a simple linear function:
```python
x = torch.tensor(1.0, requires_grad=True)

def u(x):
    return x * x

def g(u):
    return -u
```

In this case our composite function is `g(u(x)) = -x*x`. So its derivative with respect to `x` is `-2x`. At the point `x=1`, this is equal to `-2`.

Let's verify this using the `torch.autograd.grad` function:
```python
dgdx = torch.autograd.grad(g(u(x)), x)[0]
print(dgdx)  # tensor(-2.)
```
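
Alternatively, you can call `backward()` on the output and read the gradient off the leaf tensor's `.grad` attribute; this is the pattern we will use for training models below. A minimal sketch of the same computation:

```python
x = torch.tensor(1.0, requires_grad=True)
y = g(u(x))    # y = -x*x
y.backward()   # accumulates dy/dx into x.grad
print(x.grad)  # tensor(-2.)
```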

### Curve fitting

To understand how powerful automatic differentiation can be, let's have a look at another example. Assume that we have samples from a curve (say `f(x) = 5x^2 + 3`) and we want to estimate `f(x)` based on these samples. We define a parametric function `g(x, w) = w0 x^2 + w1 x + w2`, which is a function of the input `x` and latent parameters `w`. Our goal is then to find the latent parameters such that `g(x, w) ≈ f(x)`. This can be done by minimizing the following loss function: `L(w) = Σ (f(x) - g(x, w))^2`. Although there's a closed form solution for this simple problem, we opt to use a more general approach that can be applied to any arbitrary differentiable function, namely stochastic gradient descent. We simply compute the average gradient of `L(w)` with respect to `w` over a set of sample points and move in the opposite direction.

Here's how it can be done in PyTorch:

```python
import torch

# Assuming we know that the desired function is a polynomial of 2nd degree, we
# allocate a vector of size 3 to hold the coefficients and initialize it with
# random noise.
w = torch.randn([3, 1], requires_grad=True)

# We use the Adam optimizer with learning rate set to 0.1 to minimize the loss.
opt = torch.optim.Adam([w], 0.1)

def model(x):
    # We define yhat to be our estimate of y.
    f = torch.stack([x * x, x, torch.ones_like(x)], 1)
    yhat = torch.squeeze(f @ w, 1)
    return yhat

def compute_loss(y, yhat):
    # The loss is defined to be the mean squared error distance between our
    # estimate of y and its true value.
    loss = torch.nn.functional.mse_loss(yhat, y)
    return loss

def generate_data():
    # Generate some training data based on the true function.
    x = torch.rand(100) * 20 - 10
    y = 5 * x * x + 3
    return x, y

def train_step():
    x, y = generate_data()

    yhat = model(x)
    loss = compute_loss(y, yhat)

    opt.zero_grad()
    loss.backward()
    opt.step()

for _ in range(1000):
    train_step()

print(w.detach().numpy())
```
By running this piece of code you should see a result close to this:
```python
[4.9924135, 0.00040895029, 3.4504161]
```
This is a reasonably close approximation of the true parameters (5, 0, 3).

This is just the tip of the iceberg for what PyTorch can do. Many problems, such as optimizing large neural networks with millions of parameters, can be implemented efficiently in PyTorch in just a few lines of code. PyTorch takes care of scaling across multiple devices and threads, and supports a variety of platforms.

## Encapsulate your model with Modules

In the previous example we used bare-bones tensors and tensor operations to build our model. To make your code slightly more organized it's recommended to use PyTorch's modules. A module is simply a container for your parameters and encapsulates model operations. For example, say you want to represent a linear model `y = ax + b`. This model can be represented with the following code:

```python
import torch

class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, x):
        yhat = self.a * x + self.b
        return yhat
```

To use this model in practice you instantiate the module and simply call it like a function:
```python
x = torch.arange(100, dtype=torch.float32)

net = Net()
y = net(x)
```

Parameters are essentially tensors with `requires_grad` set to `True`. It's convenient to use parameters because you can simply retrieve them all with the module's `parameters()` method:
```python
for p in net.parameters():
    print(p)
```

Now, say you have an unknown function `y = 5x + 3 + some noise`, and you want to optimize the parameters of your model to fit this function. You can start by sampling some points from your function:
```python
x = torch.arange(100, dtype=torch.float32) / 100
y = 5 * x + 3 + torch.rand(100) * 0.3
```

Similar to the previous example, you can define a loss function and optimize the parameters of your model as follows:
```python
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

for i in range(10000):
    net.zero_grad()
    yhat = net(x)
    loss = criterion(yhat, y)
    loss.backward()
    optimizer.step()

print(net.a, net.b)  # Should be close to 5 and 3
```

PyTorch comes with a number of predefined modules. One such module is `torch.nn.Linear`, which is a more general form of a linear function than what we defined above. We can rewrite our module above using `torch.nn.Linear` like this:

```python
class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        yhat = self.linear(x.unsqueeze(1)).squeeze(1)
        return yhat
```

Note that we used `unsqueeze` and `squeeze` since `torch.nn.Linear` operates on batches of vectors rather than on scalars.
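
To make the shape handling concrete, here is a small sketch (the sizes are just for illustration) of why the extra dimension is needed:

```python
x = torch.arange(5, dtype=torch.float32)  # shape [5]
linear = torch.nn.Linear(1, 1)

# torch.nn.Linear expects a batch of feature vectors of shape [batch_size, in_features].
out = linear(x.unsqueeze(1))  # unsqueeze: [5] -> [5, 1]; linear output: [5, 1]
yhat = out.squeeze(1)         # squeeze: [5, 1] -> [5]
print(yhat.shape)             # torch.Size([5])
```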

By default, calling `parameters()` on a module returns the parameters of all its submodules:
```python
net = Net()
for p in net.parameters():
    print(p)
```

There are some predefined modules that act as a container for other modules. The most commonly used container module is `torch.nn.Sequential`. As its name implies, it's used to stack multiple modules (or layers) on top of each other. For example, to stack two `Linear` layers with a `ReLU` nonlinearity in between you can do:

```python
model = torch.nn.Sequential(
    torch.nn.Linear(64, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 10),
)
```
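
Calling the container runs the layers in order, so a batch of 64-dimensional feature vectors comes out as 10-dimensional outputs (the batch size here is just an example):

```python
x = torch.rand(16, 64)  # a batch of 16 vectors with 64 features each
y = model(x)
print(y.shape)          # torch.Size([16, 10])
```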

## Broadcasting the good and the ugly

PyTorch supports broadcasting of elementwise operations. Normally when you want to perform operations like addition and multiplication, you need to make sure that the shapes of the operands match, e.g. you can't add a tensor of shape `[3, 2]` to a tensor of shape `[3, 4]`. But there's a special case, and that's when you have a singleton dimension (a dimension of size 1). PyTorch implicitly tiles the tensor across its singleton dimensions to match the shape of the other operand. So it's valid to add a tensor of shape `[3, 2]` to a tensor of shape `[3, 1]`.

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[1.], [2.]])
# c = a + b.repeat([1, 2])
c = a + b

print(c)
```

Broadcasting allows us to perform implicit tiling, which makes the code shorter and more memory efficient since we don't need to store the result of the tiling operation. One neat place this can be used is when combining features of different sizes. To concatenate features of different sizes we commonly tile the input tensors, concatenate the result and apply a nonlinearity. This is a common pattern across a variety of neural network architectures:

```python
a = torch.rand([5, 3, 5])
b = torch.rand([5, 1, 6])

linear = torch.nn.Linear(11, 10)

# concat a and b and apply nonlinearity
tiled_b = b.repeat([1, 3, 1])
c = torch.cat([a, tiled_b], 2)
d = torch.nn.functional.relu(linear(c))

print(d.shape)  # torch.Size([5, 3, 10])
```

But this can be done more efficiently with broadcasting. We use the fact that a linear layer applied to the concatenation of `a` and `b` is equivalent to two linear layers applied to `a` and `b` separately, with their outputs summed (splitting the weight matrix into two blocks: `m [a; b] = m_a a + m_b b`). So we can do the linear operations separately and use broadcasting to do implicit concatenation:

```python
a = torch.rand([5, 3, 5])
b = torch.rand([5, 1, 6])

linear1 = torch.nn.Linear(5, 10)
linear2 = torch.nn.Linear(6, 10)

pa = linear1(a)
pb = linear2(b)
d = torch.nn.functional.relu(pa + pb)

print(d.shape)  # torch.Size([5, 3, 10])
```

In fact this piece of code is pretty general and can be applied to tensors of arbitrary shape as long as broadcasting between the tensors is possible:

```python
class Merge(torch.nn.Module):
    def __init__(self, in_features1, in_features2, out_features, activation=None):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features1, out_features)
        self.linear2 = torch.nn.Linear(in_features2, out_features)
        self.activation = activation

    def forward(self, a, b):
        pa = self.linear1(a)
        pb = self.linear2(b)
        c = pa + pb
        if self.activation is not None:
            c = self.activation(c)
        return c
```
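
For example, using the same `a` and `b` as above, the module reproduces the earlier result (a small usage sketch):

```python
a = torch.rand([5, 3, 5])
b = torch.rand([5, 1, 6])

merge = Merge(5, 6, 10, activation=torch.nn.functional.relu)
d = merge(a, b)

print(d.shape)  # torch.Size([5, 3, 10])
```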

So far we discussed the good part of broadcasting. But what's the ugly part, you may ask? Implicit assumptions almost always make debugging harder. Consider the following example:

```python
a = torch.tensor([[1.], [2.]])
b = torch.tensor([1., 2.])
c = torch.sum(a + b)

print(c)
```

What do you think the value of `c` would be after evaluation? If you guessed 6, that's wrong. It's going to be 12. This is because when the ranks of two tensors don't match, PyTorch automatically prepends singleton dimensions to the tensor with the lower rank before the elementwise operation. So `b` is treated as having shape `[1, 2]`, the result of the addition is `[[2, 3], [3, 4]]`, and reducing over all elements gives us 12.

The way to avoid this problem is to be as explicit as possible. Had we specified which dimension we wanted to reduce across, catching this bug would have been much easier:

```python
a = torch.tensor([[1.], [2.]])
b = torch.tensor([1., 2.])
c = torch.sum(a + b, 0)

print(c)
```

Here the value of `c` would be `[5, 7]`, and we would immediately suspect, based on the shape of the result, that something is wrong. A general rule of thumb is to always specify the dimensions in reduction operations and when using `torch.squeeze`.
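
The same advice applies to `torch.squeeze`: without an explicit dimension it removes every dimension of size 1, which can silently change shapes you didn't intend to touch. A short sketch:

```python
x = torch.rand([1, 3, 1])

print(x.squeeze().shape)   # torch.Size([3]) -- all singleton dims dropped
print(x.squeeze(2).shape)  # torch.Size([1, 3]) -- only the intended dim dropped
```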

## Take advantage of the overloaded operators

Just like NumPy, PyTorch overloads a number of Python operators to make PyTorch code shorter and more readable.

The slicing op is one of the overloaded operators that can make indexing tensors very easy:
```python
z = x[begin:end]  # z = torch.narrow(x, 0, begin, end-begin)
```
Be very careful when using this op though. The slicing op, like any other op, has some overhead. Because it's a common and innocent-looking op, it's easy to overuse it in ways that lead to inefficiencies. To understand how inefficient this op can be, let's look at an example. We want to manually perform a reduction across the rows of a matrix:
```python
import torch
import time

x = torch.rand([500, 10])

z = torch.zeros([10])

start = time.time()
for i in range(500):
    z += x[i]
print("Took %f seconds." % (time.time() - start))
```
This runs quite slowly, and the reason is that we are calling the slice op 500 times, which adds a lot of overhead. A better choice is to use the `torch.unbind` op to slice the matrix into a list of vectors all at once:
```python
z = torch.zeros([10])
for x_i in torch.unbind(x):
    z += x_i
```
This is significantly (~30% on my machine) faster.

Of course, the right way to do this simple reduction is to use the `torch.sum` op to do it in one call:
```python
z = torch.sum(x, dim=0)
```
which is extremely fast (~100x faster on my machine).

PyTorch also overloads a range of arithmetic and logical operators:
```python
z = -x       # z = torch.neg(x)
z = x + y    # z = torch.add(x, y)
z = x - y    # z = torch.sub(x, y)
z = x * y    # z = torch.mul(x, y)
z = x / y    # z = torch.div(x, y)
z = x // y   # z = torch.floor_divide(x, y)
z = x % y    # z = torch.remainder(x, y)
z = x ** y   # z = torch.pow(x, y)
z = x @ y    # z = torch.matmul(x, y)
z = x > y    # z = torch.gt(x, y)
z = x >= y   # z = torch.ge(x, y)
z = x < y    # z = torch.lt(x, y)
z = x <= y   # z = torch.le(x, y)
z = abs(x)   # z = torch.abs(x)
z = x & y    # z = torch.bitwise_and(x, y)
z = x | y    # z = torch.bitwise_or(x, y)
z = x ^ y    # z = torch.bitwise_xor(x, y)
z = ~x       # z = torch.bitwise_not(x)
z = x == y   # z = torch.eq(x, y)
z = x != y   # z = torch.ne(x, y)
```

You can also use the augmented version of these ops. For example `x += y` and `x **= 2` are also valid.

Note that Python doesn't allow overloading the `and`, `or`, and `not` keywords.


## Optimizing runtime with TorchScript

PyTorch is optimized to perform operations on large tensors. Doing many operations on small tensors is quite inefficient in PyTorch. So, whenever possible you should rewrite your computations in batch form to reduce overhead and improve performance. If there's no way you can manually batch your operations, using TorchScript may improve your code's performance. TorchScript is simply a subset of Python that PyTorch can recognize and optimize with its just-in-time (JIT) compiler, reducing some of the overhead.

Let's look at an example. A very common operation in ML applications is "batch gather". This operation can simply be written as `output[i] = input[i, index[i]]`, and it can be implemented in PyTorch as follows:
```python
import torch

def batch_gather(tensor, indices):
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)
```

To implement the same function using TorchScript simply use the `torch.jit.script` decorator:
```python
@torch.jit.script
def batch_gather_jit(tensor, indices):
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)
```
In my tests this is about 10% faster.

But nothing beats manually batching your operations. A vectorized implementation is, in my tests, 100 times faster:
```python
def batch_gather_vec(tensor, indices):
    shape = list(tensor.shape)
    # Flatten the first two dimensions so each row can be indexed directly.
    flat_first = torch.reshape(
        tensor, [shape[0] * shape[1]] + shape[2:])
    # Offset the index of each batch element into the flattened tensor.
    offset = torch.reshape(
        torch.arange(shape[0], device=tensor.device) * shape[1],
        [shape[0]] + [1] * (len(indices.shape) - 1))
    output = flat_first[indices + offset]
    return output
```
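
As a quick sanity check (a small sketch with made-up shapes), the loop version and the vectorized version should agree:

```python
tensor = torch.rand([32, 7, 5])
indices = torch.randint(0, 7, [32])

print(torch.allclose(batch_gather(tensor, indices),
                     batch_gather_vec(tensor, indices)))  # True
```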

## Building efficient custom data loaders


In the last lesson we talked about writing efficient PyTorch code. But to make your code run with maximum efficiency you also need to load your data efficiently into your device's memory. Fortunately PyTorch offers a tool to make data loading easy. It's called a `DataLoader`. A `DataLoader` uses multiple workers to simultaneously load data from a `Dataset` and optionally uses a `Sampler` to sample data entries and form a batch.

If you can randomly access your data, using a `DataLoader` is very easy: you simply need to implement a `Dataset` class with a `__getitem__` method (to read each data item) and a `__len__` method (to return the number of items in the dataset). For example, here's how to load images from a given directory:

```python
import glob
import os
import cv2
import torch

class ImageDirectoryDataset(torch.utils.data.Dataset):
    def __init__(self, path, pattern):
        self.paths = list(glob.glob(os.path.join(path, pattern)))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        return cv2.imread(self.paths[index], 1)
```

To load all jpeg images from a given directory you can then do the following:
```python
dataloader = torch.utils.data.DataLoader(ImageDirectoryDataset("/data/imagenet", "*.jpg"), num_workers=8)
for data in dataloader:
    ...  # do something with data
```

Here we are using 8 workers to simultaneously read our data from the disk. You can tune the number of workers on your machine for optimal results.

Using a `DataLoader` to read data with random access may be OK if you have fast storage or if your data items are large. But imagine having a network file system with a slow connection. Requesting individual files this way can be extremely slow and would probably end up becoming the bottleneck of your training pipeline.

A better approach is to store your data in a contiguous file format which can be read sequentially. For example, if you have a large collection of images you can use tar to create a single archive and extract files from the archive sequentially in Python. To do this you can use PyTorch's `IterableDataset`. To create an `IterableDataset` class you only need to implement an `__iter__` method which sequentially reads and yields data items from the dataset.

A naive implementation would look like this:

```python
import tarfile

import cv2
import numpy as np
import torch

def tar_image_iterator(path):
    tar = tarfile.open(path, "r")
    for tar_info in tar:
        file = tar.extractfile(tar_info)
        content = file.read()
        # Decode the raw bytes into an image.
        yield cv2.imdecode(np.frombuffer(content, np.uint8), 1)
        file.close()
        # Clear the list of processed members to keep memory usage bounded.
        tar.members = []
    tar.close()

class TarImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, path):
        super().__init__()
        self.path = path

    def __iter__(self):
        yield from tar_image_iterator(self.path)
```

But there's a major problem with this implementation. If you try to use a `DataLoader` to read from this dataset with more than one worker, you'd observe a lot of duplicated images:

```python
dataloader = torch.utils.data.DataLoader(TarImageDataset("/data/imagenet.tar"), num_workers=8)
for data in dataloader:
    ...  # data contains duplicated items
```

The problem is that each worker creates a separate instance of the dataset, and each would start from the beginning of the dataset. One way to avoid this is, instead of having one tar file, to split your data into `num_workers` separate tar files and load each with a separate worker:

```python
class TarImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, paths):
        super().__init__()
        self.paths = paths

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # For simplicity we assume num_workers is equal to the number of tar files.
        if worker_info is None or worker_info.num_workers != len(self.paths):
            raise ValueError("Number of workers doesn't match number of files.")
        yield from tar_image_iterator(self.paths[worker_info.id])
```

This is how our dataset class can be used:
```python
dataloader = torch.utils.data.DataLoader(
    TarImageDataset(["/data/imagenet_part1.tar", "/data/imagenet_part2.tar"]), num_workers=2)
for data in dataloader:
    ...  # do something with data
```

We discussed a simple strategy to avoid the duplicated entries problem. The [tfrecord](https://github.com/vahidk/tfrecord) package uses slightly more sophisticated strategies to shard your data on the fly.
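
If the number of tar files doesn't have to match the number of workers, a slightly more flexible variant is to assign shards to workers round-robin inside `__iter__`. This is only a sketch (the class name is hypothetical), not how tfrecord implements it:

```python
class ShardedTarImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, paths):
        super().__init__()
        self.paths = paths

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: read every shard.
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers
        # Each worker reads every num_workers-th shard, starting at its own id.
        for path in self.paths[worker_id::num_workers]:
            yield from tar_image_iterator(path)
```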

## Numerical stability in PyTorch

When using any numerical computation library such as NumPy or PyTorch, it's important to note that writing mathematically correct code doesn't necessarily lead to correct results. You also need to make sure that the computations are stable.

Let's start with a simple example. Mathematically, it's easy to see that `x * y / y = x` for any non-zero value of `y`. But let's see if that's always true in practice:
```python
import numpy as np

x = np.float32(1)

y = np.float32(1e-50)  # y would be stored as zero
z = x * y / y

print(z)  # prints nan
```

The reason for the incorrect result is that `y` is simply too small for the float32 type. A similar problem occurs when `y` is too large:

```python
y = np.float32(1e39)  # y would be stored as inf
z = x * y / y

print(z)  # prints nan
```

The smallest positive value that the float32 type can represent is 1.4013e-45, and anything below that is stored as zero. Also, any number beyond 3.40282e+38 is stored as inf.

```python
print(np.nextafter(np.float32(0), np.float32(1)))  # prints 1.4013e-45
print(np.finfo(np.float32).max)  # prints 3.40282e+38
```

To make sure that your computations are stable, you want to avoid values with very small or very large absolute values. This may sound very obvious, but these kinds of problems can become extremely hard to debug, especially when doing gradient descent in PyTorch. This is because you not only need to make sure that all the values in the forward pass are within the valid range of your data types, but you also need to make sure of the same for the backward pass (during gradient computation).
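
As an illustration of how the backward pass can blow up even when the forward pass is perfectly fine, consider the square root: it is well defined at zero, but its derivative `1 / (2 * sqrt(x))` is not:

```python
x = torch.tensor(0.0, requires_grad=True)
y = torch.sqrt(x)  # forward pass is fine: y = 0
y.backward()
print(x.grad)      # tensor(inf) -- the gradient blows up at x = 0
```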

Let's look at a real example. We want to compute the softmax over a vector of logits. A naive implementation would look something like this:
```python
import torch

def unstable_softmax(logits):
    exp = torch.exp(logits)
    return exp / torch.sum(exp)

print(unstable_softmax(torch.tensor([1000., 0.])).numpy())  # prints [ nan, 0.]
```
Note that exponentiating even moderately large logits produces values that are out of the float32 range. The largest valid logit for our naive softmax implementation is `ln(3.40282e+38) = 88.7`; anything beyond that leads to a nan outcome.
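
You can check this bound directly:

```python
import numpy as np

print(np.log(np.finfo(np.float32).max))  # prints ~88.72
```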

But how can we make this more stable? The solution is rather simple. It's easy to see that `exp(x - c) / Σ exp(x - c) = exp(x) / Σ exp(x)`. Therefore we can subtract any constant from the logits and the result remains the same. We choose this constant to be the maximum of the logits. This way the domain of the exponential function is limited to `[-inf, 0]`, and consequently its range is `[0.0, 1.0]`, which is desirable:

```python
import torch

def softmax(logits):
    exp = torch.exp(logits - torch.max(logits))
    return exp / torch.sum(exp)

print(softmax(torch.tensor([1000., 0.])).numpy())  # prints [ 1., 0.]
```

Let's look at a more complicated case. Say we have a classification problem. We use the softmax function to produce probabilities from our logits. We then define our loss function to be the cross entropy between our predictions and the labels. Recall that the cross entropy for a categorical distribution can be simply defined as `xe(p, q) = -Σ p_i log(q_i)`. So a naive implementation of the cross entropy would look like this:

```python
def unstable_softmax_cross_entropy(labels, logits):
    logits = torch.log(softmax(logits))
    return -torch.sum(labels * logits)

labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])

xe = unstable_softmax_cross_entropy(labels, logits)

print(xe.numpy())  # prints inf
```

Note that in this implementation, as the softmax output approaches zero, the log's output approaches negative infinity, which causes instability in our computation. We can rewrite this by expanding the softmax and doing some simplifications:

```python
def softmax_cross_entropy(labels, logits, dim=-1):
    scaled_logits = logits - torch.max(logits)
    normalized_logits = scaled_logits - torch.logsumexp(scaled_logits, dim)
    return -torch.sum(labels * normalized_logits)

labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])

xe = softmax_cross_entropy(labels, logits)

print(xe.numpy())  # prints 500.0
```

We can also verify that the gradients are computed correctly:
```python
logits.requires_grad_(True)
xe = softmax_cross_entropy(labels, logits)
g = torch.autograd.grad(xe, logits)[0]
print(g.numpy())  # prints [0.5, -0.5]
```

Let me emphasize again that extra care must be taken when doing gradient descent to make sure that the outputs of your functions, as well as the gradients at each layer, stay within a valid range. Exponential and logarithmic functions, when used naively, are especially problematic because they can map small numbers to enormous ones and vice versa.
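
In practice you rarely need to roll your own: PyTorch's built-in losses implement these tricks internally. For example, `torch.nn.functional.cross_entropy` operates directly on unnormalized logits (the soft-label form below requires a reasonably recent PyTorch version):

```python
import torch.nn.functional as F

labels = torch.tensor([[0.5, 0.5]])
logits = torch.tensor([[1000., 0.]])

# cross_entropy fuses log_softmax and the negative log likelihood
# in a numerically stable way.
print(F.cross_entropy(logits, labels))  # tensor(500.)
```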


## Faster training with automatic mixed precision

By default, tensors and model parameters in PyTorch are stored in 32-bit floating point precision. Training neural networks using 32-bit floats is usually stable and doesn't cause major numerical issues; however, neural networks have been shown to perform quite well in 16-bit and even lower precisions. Computation in lower precision can be significantly faster on modern GPUs. It also has the extra benefit of using less memory, enabling training of larger models and/or larger batch sizes, which can boost performance further. The problem, though, is that training in 16 bits often becomes very unstable because the precision is usually not enough for some operations like accumulations.

To help with this problem PyTorch supports training in mixed precision. In a nutshell, mixed-precision training is done by performing some expensive operations (like convolutions and matrix multiplications) in 16-bit by casting down the inputs, while performing other, numerically sensitive operations like accumulations in 32-bit. This way we get the benefits of 16-bit computation without its drawbacks. Next we talk about using `autocast` and `GradScaler` to do automatic mixed-precision training.

### Autocast

`autocast` helps improve runtime performance by automatically casting down data to 16-bit for some computations. To understand how it works let's look at an example:

```python
import torch

x = torch.rand([32, 32]).cuda()
y = torch.rand([32, 32]).cuda()

with torch.amp.autocast("cuda"):
    a = x + y
    b = x @ y
    print(a.dtype)  # prints torch.float32
    print(b.dtype)  # prints torch.float16
```

Note that both `x` and `y` are 32-bit tensors, but `autocast` performs the matrix multiplication in 16-bit while keeping the addition in 32-bit. What if one of the operands is in 16-bit?

```python
import torch

x = torch.rand([32, 32]).cuda()
y = torch.rand([32, 32]).cuda().half()

with torch.amp.autocast("cuda"):
    a = x + y
    b = x @ y
    print(a.dtype)  # prints torch.float32
    print(b.dtype)  # prints torch.float16
```

Again `autocast` casts down the 32-bit operand to 16-bit to perform the matrix multiplication, but it doesn't do the same for the addition: adding a 16-bit tensor to a 32-bit tensor promotes the result to the higher (32-bit) precision.

In practice, you can trust `autocast` to do the right casting to improve runtime efficiency. The important thing is to keep all your forward pass computations under the `autocast` context:

```python
model = ...
loss_fn = ...

with torch.amp.autocast("cuda"):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
```

This may be all you need if you have a relatively stable optimization problem and use a relatively low learning rate. Adding this one extra line of code can cut your training time by up to half on modern hardware.

### GradScaler

As we mentioned in the beginning of this section, 16-bit precision may not always be enough for some computations. One particular case of interest is representing gradient values, a great portion of which are usually small. Representing them with 16-bit floats often leads to underflow (i.e. they'd be rounded to zero), which makes training neural networks very unstable. `GradScaler` is designed to resolve this issue. It takes your loss value and multiplies it by a large scale factor, inflating the gradient values and therefore making them representable in 16-bit precision. It then scales the gradients back down during the parameter update to ensure the parameters are updated correctly. This is generally what `GradScaler` does. But under the hood `GradScaler` is a bit smarter than that: inflating the gradients may actually result in overflows, which is equally bad. So `GradScaler` monitors the gradient values, and if it detects overflows it skips the update and scales down the scale factor according to a configurable schedule. (The default schedule usually works, but you may need to adjust it for your use case.)

Using `GradScaler` is very easy in practice:

```python
scaler = torch.amp.GradScaler()

loss = ...
optimizer = ...  # an instance of torch.optim.Optimizer

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```

Note that we first create an instance of `GradScaler`. In the training loop we call `scaler.scale` to scale the loss before calling `backward`, producing inflated gradients. We then use `scaler.step`, which unscales the gradients and (if no overflow is detected) updates the model parameters. Finally we call `scaler.update`, which adjusts the scale factor if needed. That's all!

The following sample code showcases mixed-precision training on a synthetic problem: learning to generate a checkerboard pattern from image coordinates. You can paste it into a [Google Colab](https://colab.research.google.com/), set the runtime to GPU, and compare the single- and mixed-precision performance. Note that this is a small toy example; in practice, with larger networks, you may see larger performance boosts from mixed precision.

### An example

### Generating a checkerboard

```python
import torch
import matplotlib.pyplot as plt
import time

def grid(width, height):
    hrange = torch.arange(width).unsqueeze(0).repeat([height, 1]).div(width)
    vrange = torch.arange(height).unsqueeze(1).repeat([1, width]).div(height)
    output = torch.stack([hrange, vrange], 0)
    return output


def checker(width, height, freq):
    hrange = torch.arange(width).reshape([1, width]).mul(freq / width / 2.0).fmod(1.0).gt(0.5)
    vrange = torch.arange(height).reshape([height, 1]).mul(freq / height / 2.0).fmod(1.0).gt(0.5)
    output = hrange.logical_xor(vrange).float()
    return output

# Note the inputs are grid coordinates and the target is a checkerboard
inputs = grid(512, 512).unsqueeze(0).cuda()
targets = checker(512, 512, 8).unsqueeze(0).unsqueeze(1).cuda()
```

### Defining a convolutional neural network

```python
class Net(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(2, 256, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 1, 1))

    @torch.jit.script_method
    def forward(self, x):
        return self.net(x)
```

### Single precision training
```python
net = Net().cuda()
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), 0.001)

start_time = time.time()

for i in range(500):
    opt.zero_grad()
    outputs = net(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    opt.step()
print(loss)

print(time.time() - start_time)

plt.subplot(1,2,1); plt.imshow(outputs.squeeze().detach().cpu());
plt.subplot(1,2,2); plt.imshow(targets.squeeze().cpu()); plt.show()
```

### Mixed precision training
```python
net = Net().cuda()
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), 0.001)

scaler = torch.amp.GradScaler()

start_time = time.time()

for i in range(500):
    opt.zero_grad()
    with torch.amp.autocast("cuda"):
        outputs = net(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
print(loss)

print(time.time() - start_time)

plt.subplot(1,2,1); plt.imshow(outputs.squeeze().detach().cpu().float());
plt.subplot(1,2,2); plt.imshow(targets.squeeze().cpu().float()); plt.show()
```


### Reference
- https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html
