├── .gitignore
├── README.md
├── assets
├── detached_y1.png
├── groups_1.png
├── groups_2.png
└── undetached_y1.png
└── markdowns
├── IN_LN_manual_calculations.md
├── Understanding losses.md
├── backprop.md
├── basics.md
├── batchnorm.md
├── dropout.md
├── einstein_summation.md
├── groups.md
├── hooks.md
├── register_buffer.md
└── weightnorm.md
/.gitignore:
--------------------------------------------------------------------------------
1 | create_toc.py
2 | toc.md
3 | .ipynb_checkpoints/**
4 | __pycache__/**
5 | *.ipynb
6 | **/.DS_Store
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Understanding nuts and bolts of neural networks with PyTorch
2 |
3 | __(Work in progress)__
4 |
5 | This is a series of articles attempting to understand the witchcraft that is neural networks (using PyTorch).
6 |
7 | ### Basic
8 |
9 | * [__Gradient descent and various ways to implement it in PyTorch__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/basics.md)
10 |
11 | * [__Dropout__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/dropout.md)
12 |
13 | * [__Batch normalization__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/batchnorm.md)
14 |
15 | * [__Backpropagation and detach__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/backprop.md)
16 |
17 | * [__Losses in PyTorch__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/Understanding%20losses.md)
18 |
19 | * [__register_buffer__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/register_buffer.md)
20 |
21 | ### Intermediate
22 |
23 | * [__Groups in convolution__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/groups.md)
24 |
25 | * [__Hooks__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/hooks.md)
26 |
27 | * [__Weight normalization__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/weightnorm.md)
28 |
29 | * [__Einstein summation__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/einstein_summation.md)
30 |
31 | * [__Instance and layer normalizations__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/IN_LN_manual_calculations.md)
--------------------------------------------------------------------------------
/assets/detached_y1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/detached_y1.png
--------------------------------------------------------------------------------
/assets/groups_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/groups_1.png
--------------------------------------------------------------------------------
/assets/groups_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/groups_2.png
--------------------------------------------------------------------------------
/assets/undetached_y1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/undetached_y1.png
--------------------------------------------------------------------------------
/markdowns/IN_LN_manual_calculations.md:
--------------------------------------------------------------------------------
1 | ```python
2 | import torch
3 | import torch.nn as nn
4 | ```
5 |
6 | ### 3D tensors
7 |
8 |
9 | ```python
10 | x = torch.arange(2*3*4).view(2,3,4).float()
11 | y = nn.InstanceNorm1d(3)(x)
12 | z = nn.LayerNorm(4, elementwise_affine=False)(x)
13 | ```
14 |
15 |
16 | ```python
17 | num = x - x.mean(-1, keepdim=True)
18 | den = (num.pow(2).mean(-1, keepdim=True) + 1e-5).sqrt()
19 | num/den
20 | ```
21 |
22 |
23 |
24 |
25 | tensor([[[-1.3416, -0.4472, 0.4472, 1.3416],
26 | [-1.3416, -0.4472, 0.4472, 1.3416],
27 | [-1.3416, -0.4472, 0.4472, 1.3416]],
28 |
29 | [[-1.3416, -0.4472, 0.4472, 1.3416],
30 | [-1.3416, -0.4472, 0.4472, 1.3416],
31 | [-1.3416, -0.4472, 0.4472, 1.3416]]])
32 |
33 |
34 |
35 |
36 | ```python
37 | z
38 | ```
39 |
40 |
41 |
42 |
43 | tensor([[[-1.3416, -0.4472, 0.4472, 1.3416],
44 | [-1.3416, -0.4472, 0.4472, 1.3416],
45 | [-1.3416, -0.4472, 0.4472, 1.3416]],
46 |
47 | [[-1.3416, -0.4472, 0.4472, 1.3416],
48 | [-1.3416, -0.4472, 0.4472, 1.3416],
49 | [-1.3416, -0.4472, 0.4472, 1.3416]]])
50 |
51 |
52 |
53 |
54 | ```python
55 | y
56 | ```
57 |
58 |
59 |
60 |
61 | tensor([[[-1.3416, -0.4472, 0.4472, 1.3416],
62 | [-1.3416, -0.4472, 0.4472, 1.3416],
63 | [-1.3416, -0.4472, 0.4472, 1.3416]],
64 |
65 | [[-1.3416, -0.4472, 0.4472, 1.3416],
66 | [-1.3416, -0.4472, 0.4472, 1.3416],
67 | [-1.3416, -0.4472, 0.4472, 1.3416]]])
68 |
69 |
70 |
71 | However, `y` and `z` are not exactly identical; the largest difference is on the order of float32 rounding error:
72 |
73 |
74 | ```python
75 | (y - z).abs().max()
76 | ```
77 |
78 |
79 |
80 |
81 | tensor(1.1921e-06)
82 |
83 |
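A quick check (a sketch reusing the `y` and `z` from above) confirms they are equal within default floating-point tolerances:

```python
# torch.allclose uses rtol=1e-5, atol=1e-8 by default; the ~1e-6 gap is just rounding noise
torch.allclose(y, z)  # True for this example
```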
84 |
85 | ### 4D tensors
86 |
87 |
88 | ```python
89 | x = torch.arange(2*3*4*5).view(2,3,4,5).float()
90 | y = nn.InstanceNorm2d(3)(x)
91 | z = nn.LayerNorm(5, elementwise_affine=False)(x)
92 | z2 = nn.LayerNorm([4,5], elementwise_affine=False)(x)
93 | ```
94 |
95 |
96 | ```python
97 | num = x - x.view(2,3,-1).mean(-1, keepdim=True).unsqueeze(-1)
98 | den = (num.view(2,3,-1).pow(2).mean(-1, keepdim=True).unsqueeze(-1) + 1e-5).sqrt()
99 | (num/den)[0,0]
100 | ```
101 |
102 |
103 |
104 |
105 | tensor([[-1.6475, -1.4741, -1.3007, -1.1272, -0.9538],
106 | [-0.7804, -0.6070, -0.4336, -0.2601, -0.0867],
107 | [ 0.0867, 0.2601, 0.4336, 0.6070, 0.7804],
108 | [ 0.9538, 1.1272, 1.3007, 1.4741, 1.6475]])
109 |
110 |
111 |
112 |
113 | ```python
114 | y[0,0]
115 | ```
116 |
117 |
118 |
119 |
120 | tensor([[-1.6475, -1.4741, -1.3007, -1.1272, -0.9538],
121 | [-0.7804, -0.6070, -0.4336, -0.2601, -0.0867],
122 | [ 0.0867, 0.2601, 0.4336, 0.6070, 0.7804],
123 | [ 0.9538, 1.1272, 1.3007, 1.4741, 1.6475]])
124 |
125 |
126 |
127 |
128 | ```python
129 | z2[0,0]
130 | ```
131 |
132 |
133 |
134 |
135 | tensor([[-1.6475, -1.4741, -1.3007, -1.1272, -0.9538],
136 | [-0.7804, -0.6070, -0.4336, -0.2601, -0.0867],
137 | [ 0.0867, 0.2601, 0.4336, 0.6070, 0.7804],
138 | [ 0.9538, 1.1272, 1.3007, 1.4741, 1.6475]])
139 |
140 |
141 |
142 |
143 | ```python
144 | num = x - x.mean(-1, keepdim=True)
145 | den = (num.pow(2).mean(-1, keepdim=True) + 1e-5).sqrt()
146 | (num/den)[0,0]
147 | ```
148 |
149 |
150 |
151 |
152 | tensor([[-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
153 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
154 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
155 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142]])
156 |
157 |
158 |
159 |
160 | ```python
161 | z[0,0]
162 | ```
163 |
164 |
165 |
166 |
167 | tensor([[-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
168 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
169 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
170 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142]])
171 |
172 |
173 |
174 |
175 | ```python
176 | num = x - x.mean(-1, keepdim=True)
177 | den = (x.var(-1, keepdim=True, unbiased=False) + 1e-5).sqrt()
178 | (num/den)[0,0]
179 | ```
180 |
181 |
182 |
183 |
184 | tensor([[-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
185 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
186 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142],
187 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142]])
188 |
189 |
190 |
191 |
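To summarize the 4D case (a sketch reusing `y`, `z` and `z2` from above): `InstanceNorm2d` matches `LayerNorm` over the last two dimensions, while `LayerNorm` over only the last dimension is a different normalization.

```python
# InstanceNorm2d(C) and LayerNorm([H, W], elementwise_affine=False) normalize the same slices
print(torch.allclose(y, z2, atol=1e-5))   # True (up to floating-point noise)
# LayerNorm(W) normalizes each row separately, so it differs from the other two in general
print(torch.allclose(y, z, atol=1e-5))    # False
```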
--------------------------------------------------------------------------------
/markdowns/Understanding losses.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ```python
4 | import torch
5 | import torch.nn as nn
6 | ```
7 |
8 | ### 1. LogSoftmax
9 |
10 |
11 | ```python
12 | x = torch.randn(8, 5)
13 | nn.LogSoftmax(dim=1)(x)
14 | ```
15 |
16 |
17 |
18 |
19 | tensor([[-3.3124, -2.3836, -1.1339, -1.4844, -1.1303],
20 | [-1.9481, -1.3503, -2.9377, -1.4468, -1.1712],
21 | [-2.1393, -1.8147, -4.4871, -1.4645, -0.7404],
22 | [-3.1521, -2.6411, -1.9489, -2.7951, -0.3821],
23 | [-1.6078, -1.3891, -2.0410, -3.2716, -0.9610],
24 | [-1.2204, -2.6508, -1.1108, -2.0820, -1.7130],
25 | [-2.0168, -0.9498, -1.7237, -1.5997, -2.3051],
26 | [-2.3434, -1.8218, -2.2064, -0.5169, -3.3294]])
27 |
28 |
29 |
30 | ### Same as taking the softmax first and then taking the log of each value
31 | * Softmax: `x.exp() / x.exp().sum(dim=1, keepdim=True)`
32 | * Log: apply `.log()` to the result of the softmax
33 |
34 |
35 | ```python
36 | (x.exp() / x.exp().sum(dim=1, keepdim=True)).log()
37 | ```
38 |
39 |
40 |
41 |
42 | tensor([[-3.3124, -2.3836, -1.1339, -1.4844, -1.1303],
43 | [-1.9481, -1.3503, -2.9377, -1.4468, -1.1712],
44 | [-2.1393, -1.8147, -4.4871, -1.4645, -0.7404],
45 | [-3.1521, -2.6411, -1.9489, -2.7951, -0.3821],
46 | [-1.6078, -1.3891, -2.0410, -3.2716, -0.9610],
47 | [-1.2204, -2.6508, -1.1108, -2.0820, -1.7130],
48 | [-2.0168, -0.9498, -1.7237, -1.5997, -2.3051],
49 | [-2.3434, -1.8218, -2.2064, -0.5169, -3.3294]])
50 |
51 |
52 |
53 | ### 2. NLLLoss
54 |
55 | > The negative log likelihood loss. It is useful to train a classification problem with C classes.
56 | If provided, the optional argument weight should be a 1D Tensor assigning weight to each of the classes. This is particularly useful when you have an unbalanced training set.
57 |
58 | ### Simply put: the negative of the log-probability of the true class, averaged over the batch
59 |
60 |
61 | ```python
62 | targets = torch.tensor([4,2,3,4,0,0,1,2])
63 | softmax_x = x.exp() / x.exp().sum(dim=1,keepdim=True)
64 | log_likelihood_x = softmax_x.log()
65 | print(log_likelihood_x.size(), targets.size())
66 | nn.NLLLoss()(log_likelihood_x, targets)
67 | ```
68 |
69 | torch.Size([8, 5]) torch.Size([8])
70 |
71 |
72 |
73 |
74 |
75 | tensor(1.4874)
76 |
77 |
78 |
79 |
80 | ```python
81 | nn.NLLLoss()(nn.LogSoftmax(dim=1)(x), targets)
82 | ```
83 |
84 |
85 |
86 |
87 | tensor(1.4874)
88 |
89 |
90 |
91 | ### Manual calculation
92 |
93 |
94 | ```python
95 | loss = 0
96 | for i in range(8):
97 | loss += nn.LogSoftmax(dim=1)(x)[i][targets[i]]
98 |
99 | loss/8
100 | ```
101 |
102 |
103 |
104 |
105 | tensor(-1.4874)
106 |
107 |
108 |
109 | ### 3. CrossEntropyLoss
110 |
111 | This criterion combines `nn.LogSoftmax()` and `nn.NLLLoss()` in one single class.
112 |
113 |
114 | ```python
115 | nn.CrossEntropyLoss()(x, targets)
116 | ```
117 |
118 |
119 |
120 |
121 | tensor(1.4874)
122 |
123 |
124 |
125 | ### 4. BCELoss
126 |
127 | > Creates a criterion that measures the Binary Cross Entropy between the target and the output.
128 |
129 | `CrossEntropyLoss` measures how close the probability of the true class is to 1. It does not directly care about the probabilities of the other classes. Thus, it works well with `Softmax`.
130 | 
131 | When you want to measure how close the probabilities of the true classes are to 1 _and_ how close the probabilities of the non-true classes are to 0, using `BCELoss` makes sense.
132 |
133 |
134 | ```python
135 | x = (torch.randn(4,3)).sigmoid_()
136 | x
137 | ```
138 |
139 |
140 |
141 |
142 | tensor([[0.4929, 0.8562, 0.5976],
143 | [0.3183, 0.7167, 0.5629],
144 | [0.4342, 0.5078, 0.7811],
145 | [0.6997, 0.6790, 0.7381]])
146 |
147 |
148 |
149 |
150 | ```python
151 | targets = torch.FloatTensor([[0,1,1],[1,1,0],[0,0,0],[1,0,0]])
152 | targets
153 | ```
154 |
155 |
156 |
157 |
158 | tensor([[0., 1., 1.],
159 | [1., 1., 0.],
160 | [0., 0., 0.],
161 | [1., 0., 0.]])
162 |
163 |
164 |
165 |
166 | ```python
167 | nn.BCELoss()(x, targets)
168 | ```
169 |
170 |
171 |
172 |
173 | tensor(0.7738)
174 |
175 |
176 |
177 | ### Manual calculation
178 |
179 |
180 | ```python
181 | def loss_per_class(p,t):
182 | if t == 0:
183 | return -1 * (1-t) * torch.log(1-p)
184 | else:
185 | return -1 * t * torch.log(p)
186 |
187 | loss = 0
188 | for index in range(x.size(0)):
189 | predicted = x[index]
190 | true = targets[index]
191 | loss += torch.FloatTensor([loss_per_class(p,t) for p,t in zip(predicted, true)]).sum()
192 |
193 | loss / (4*3)
194 | ```
195 |
196 |
197 |
198 |
199 | tensor(0.7738)
200 |
201 |
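The same calculation can be written in one vectorized line (a sketch reusing `x` and `targets` from the cells above); with the default `reduction='mean'`, `BCELoss` averages over all 4×3 elements:

```python
# binary cross entropy per element: -( t*log(p) + (1-t)*log(1-p) ), then mean over all elements
(-(targets * x.log() + (1 - targets) * (1 - x).log())).mean()  # tensor(0.7738)
```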
202 |
203 | ### 5. KLDivLoss
204 |
205 | As with `NLLLoss`, the input given is expected to contain _log-probabilities_. However, unlike `NLLLoss`, input is not restricted to a 2D Tensor, because the criterion is applied element-wise. The targets are given as _probabilities_ (i.e. without taking the logarithm).
206 |
207 | This criterion expects a target Tensor of the same size as the input Tensor.
208 |
209 |
210 | ```python
211 | x = torch.randn(8,5)
212 | targets = torch.randn_like(x).sigmoid_()
213 | nn.KLDivLoss()(nn.LogSoftmax(dim=1)(x), targets)
214 | ```
215 |
216 |
217 |
218 |
219 | tensor(0.8478)
220 |
221 |
222 |
223 | ### Manual calculation
224 |
225 |
226 | ```python
227 | loss = 0
228 | for i in range(x.size(0)):
229 | for j in range(x.size(1)):
230 | loss += targets[i,j] * (torch.log(targets[i,j]) - x[i,j] + x[i,:].exp().sum().log())
231 | ```
232 |
233 |
234 | ```python
235 | nn.KLDivLoss(reduction='sum')(nn.LogSoftmax(dim=1)(x), targets), loss
236 | ```
237 |
238 |
239 |
240 |
241 | (tensor(33.9101), tensor(33.9101))
242 |
243 |
244 |
245 |
246 | ```python
247 | nn.KLDivLoss()(nn.LogSoftmax(dim=1)(x), targets), loss / (x.size(0) * x.size(1))
248 | ```
249 |
250 |
251 |
252 |
253 | (tensor(0.8478), tensor(0.8478))
254 |
255 |
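The double loop above can also be written in a vectorized form (a sketch reusing `x` and `targets`); per element the criterion computes `target * (log(target) - input)`, where the input is already a log-probability:

```python
log_probs = nn.LogSoftmax(dim=1)(x)
kl = (targets * (targets.log() - log_probs)).sum()
kl, kl / targets.numel()  # matches the 'sum' and default reductions shown above
```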
256 |
--------------------------------------------------------------------------------
/markdowns/backprop.md:
--------------------------------------------------------------------------------
1 | [1. Simple example](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#1-simple-example)
2 |
3 | * [Define a batch of 8 inputs](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#define-a-batch-of-8-inputs)
4 |
5 | * [Note on values of gradients:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#note-on-values-of-gradients)
6 |
7 | * [`x.sum(0)` is same as the weight gradients:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#xsum0-is-same-as-the-weight-gradients)
8 |
9 | [2. Losses calculated from two different batches in two different ways](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#2-losses-calculated-from-two-different-batches-in-two-different-ways)
10 |
11 | * [A note on values of gradients](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#a-note-on-values-of-gradients)
12 |
13 | [ 3. Using `nn.utils.clip_grad_norm` to prevent exploding gradients](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#3-using-nnutilsclip_grad_norm-to-prevent-exploding-gradients)
14 |
15 | * [`nn.utils.clip_grad_norm` returns total norm of parameter gradients _before_ they are clipped](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#nnutilsclip_grad_norm-returns-total-norm-of-parameter-gradients-before-they-are-clipped)
16 |
17 | [4. Using `detach()`](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#using-detach)
18 |
19 | * [Another example of `detach`](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#another-example-of-detach)
20 |
21 | * [Not detaching `y1`:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#not-detaching-y1)
22 |
23 | * [When is `detach()` used?](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#when-is-detach-used)
24 |
25 | ### 1. Simple example
26 |
27 | Consider a fully connected layer with 5 input neurons and 2 output neurons. Let's feed it a batch of 8 inputs, each with dimension 5:
28 |
29 |
30 | ```python
31 | import torch
32 | from torch.autograd import Variable
33 | import torch.nn as nn
34 | ```
35 |
36 | Since no gradients have been backpropagated yet, the parameter gradients are all `None`:
37 |
38 |
39 | ```python
40 | linear = nn.Linear(5,2)
41 | print('=== Weights and biases ===')
42 | _ = [print(p.data) for p in linear.parameters()]
43 | print('=== Gradients ===')
44 | print([p.grad for p in linear.parameters()])
45 | ```
46 |
47 | === Weights and biases ===
48 |
49 | -0.3933 0.3731 0.1929 -0.3506 0.4152
50 | -0.2833 -0.2681 -0.0346 0.2800 -0.2322
51 | [torch.FloatTensor of size 2x5]
52 |
53 |
54 | 0.4141
55 | 0.2196
56 | [torch.FloatTensor of size 2]
57 |
58 | === Gradients ===
59 | [None, None]
60 |
61 |
62 | #### Define a batch of 8 inputs
63 |
64 |
65 | ```python
66 | x = Variable(torch.randn(8,5))
67 | y = linear(x)
68 | y
69 | ```
70 |
71 |
72 |
73 |
74 | Variable containing:
75 | 1.0078 -0.4197
76 | 1.7015 -0.4344
77 | -0.1489 0.5115
78 | 0.0158 -0.3754
79 | 2.0604 -0.3813
80 | 0.0680 0.7753
81 | 1.1815 -0.1078
82 | 1.5967 0.0897
83 | [torch.FloatTensor of size 8x2]
84 |
85 |
86 |
87 | Let's assume the gradient of the loss with respect to `y` is 1. We can backpropagate these gradients using `y.backward`:
88 |
89 |
90 | ```python
91 | grads = Variable(torch.ones(8,2))
92 | y.backward(grads)
93 | ```
94 |
95 | The parameters of the linear layer (weights and biases) now have non-None gradients:
96 |
97 |
98 | ```python
99 | print('=== Weights and biases ===')
100 | _ = [print(p.data) for p in linear.parameters()]
101 | print('=== Gradients ===')
102 | print([p.grad for p in linear.parameters()])
103 | ```
104 |
105 | === Weights and biases ===
106 |
107 | -0.3933 0.3731 0.1929 -0.3506 0.4152
108 | -0.2833 -0.2681 -0.0346 0.2800 -0.2322
109 | [torch.FloatTensor of size 2x5]
110 |
111 |
112 | 0.4141
113 | 0.2196
114 | [torch.FloatTensor of size 2]
115 |
116 | === Gradients ===
117 | [Variable containing:
118 | -0.6444 3.9640 3.1489 -1.0274 3.5395
119 | -0.6444 3.9640 3.1489 -1.0274 3.5395
120 | [torch.FloatTensor of size 2x5]
121 | , Variable containing:
122 | 8
123 | 8
124 | [torch.FloatTensor of size 2]
125 | ]
126 |
127 |
128 | #### Note on values of gradients:
129 |
130 | To make it easy to understand, let's consider only one neuron.
131 |
132 | For a given linear neuron which does the following operation ...
133 |
134 | > $y = \sum_{i=1}^n w_i x_i + b $
135 |
136 | ... the derivatives wrt bias and weights are given as:
137 |
138 | > $\frac{\partial Loss}{\partial b} = \frac{\partial Loss}{\partial y}$
139 |
140 | and
141 |
142 | > $\frac{\partial Loss}{\partial w_i} = \frac{\partial Loss}{\partial y} * x_i$
143 |
144 | Also note that these operations take place _per input sample_. In our case, there are __eight__ samples and the above operations take place eight times. For each sample, we have $\frac{\partial Loss}{\partial y} = 1$. Hence, we do the following operation eight times:
145 |
146 | $\frac{\partial Loss}{\partial b} = 1$
147 |
148 | Each time, the new gradient gets added to the existing gradient. Hence, the final value of $\frac{\partial Loss}{\partial b}$ is equal to 8.
149 |
150 | Similarly, for each sample $x^{(j)}$ in the batch, we have:
151 | 
152 | $\frac{\partial Loss}{\partial w_i} = \frac{\partial Loss}{\partial y} * x^{(j)}_i = x^{(j)}_i$
153 | 
154 | After backpropagating for each of the eight samples, the final value of $\frac{\partial Loss}{\partial w_i}$ is
155 | $\sum_{j=1}^8 x^{(j)}_i$, i.e. the sum of the $i$-th input feature over the batch.
156 |
157 | Let's verify this:
158 |
159 | #### `x.sum(0)` is same as the weight gradients:
160 |
161 |
162 | ```python
163 | x.sum(0)
164 | ```
165 |
166 |
167 |
168 |
169 | Variable containing:
170 | -0.6444
171 | 3.9640
172 | 3.1489
173 | -1.0274
174 | 3.5395
175 | [torch.FloatTensor of size 5]
176 |
177 |
178 |
179 |
180 | ```python
181 | linear.weight.grad
182 | ```
183 |
184 |
185 |
186 |
187 | Variable containing:
188 | -0.6444 3.9640 3.1489 -1.0274 3.5395
189 | -0.6444 3.9640 3.1489 -1.0274 3.5395
190 | [torch.FloatTensor of size 2x5]
191 |
192 |
193 |
194 | Before we move on, make sure you've understood everything up until this point.
195 |
196 | ---
197 |
198 | ### 2. Losses calculated from two different batches in two different ways
199 |
200 | Consider two batches `x1` and `x2` of sizes 2 and 3 respectively:
201 |
202 |
203 | ```python
204 | x1 = Variable(torch.randn(2,5))
205 | x2 = Variable(torch.randn(3,5))
206 | y1, y2 = linear(x1), linear(x2)
207 | print(y1.data)
208 | print(y2.data)
209 | ```
210 |
211 |
212 | 1.3945 0.9929
213 | 0.3436 -0.1620
214 | [torch.FloatTensor of size 2x2]
215 |
216 |
217 | -0.1487 0.1283
218 | 1.3610 0.0088
219 | -0.3215 0.2808
220 | [torch.FloatTensor of size 3x2]
221 |
222 |
223 |
224 | Let's assume the target values for `y1` and `y2` are all 1. But we want to give the loss from `y2` twice the weight:
225 |
226 |
227 | ```python
228 | loss = (1 - y1).sum() + 2*(1 - y2).sum()
229 | print(loss)
230 | ```
231 |
232 | Variable containing:
233 | 10.8135
234 | [torch.FloatTensor of size 1]
235 |
236 |
237 |
238 | Make sure the gradients of our layer are reset to zero before we backprop again:
239 |
240 |
241 | ```python
242 | linear.zero_grad()
243 | [p.grad for p in linear.parameters()]
244 | ```
245 |
246 |
247 |
248 |
249 | [Variable containing:
250 | 0 0 0 0 0
251 | 0 0 0 0 0
252 | [torch.FloatTensor of size 2x5], Variable containing:
253 | 0
254 | 0
255 | [torch.FloatTensor of size 2]]
256 |
257 |
258 |
259 |
260 | ```python
261 | loss.backward()
262 | print('=== Gradients ===')
263 | print([p.grad for p in linear.parameters()])
264 | ```
265 |
266 | === Gradients ===
267 | [Variable containing:
268 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483
269 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483
270 | [torch.FloatTensor of size 2x5]
271 | , Variable containing:
272 | -8
273 | -8
274 | [torch.FloatTensor of size 2]
275 | ]
276 |
277 |
278 | #### A note on values of gradients
279 |
280 | Here, we have:
281 |
282 | > $\frac{\partial Loss}{\partial y1} = -1$. This operation takes place for each sample in `x1` (i.e. 2 times)
283 |
284 | and
285 |
286 | > $\frac{\partial Loss}{\partial y2} = -2$. This operation takes place for each sample in `x2` (i.e. 3 times)
287 |
288 | For each case we have:
289 |
290 | > $\frac{\partial Loss}{\partial b} = \frac{\partial Loss}{\partial y}$
291 |
292 | Thus the final value of $\frac{\partial Loss}{\partial b}$ is:
293 |
294 | $2 * -1 + 3 * -2 = -8$
295 |
296 | Similarly, __weight gradients are equal to negative of `x1.sum(0) + 2*x2.sum(0)`__
297 |
298 | Let's verify this:
299 |
300 |
301 | ```python
302 | -1 * (x1.sum(0) + 2*x2.sum(0))
303 | ```
304 |
305 |
306 |
307 |
308 | Variable containing:
309 | -0.0289
310 | -1.5594
311 | -0.9259
312 | -0.7791
313 | 0.6483
314 | [torch.FloatTensor of size 5]
315 |
316 |
317 |
318 |
319 | ```python
320 | linear.weight.grad
321 | ```
322 |
323 |
324 |
325 |
326 | Variable containing:
327 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483
328 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483
329 | [torch.FloatTensor of size 2x5]
330 |
331 |
332 |
333 | ---
334 |
335 | ### 3. Using `nn.utils.clip_grad_norm` to prevent exploding gradients
336 |
337 | _After_ the gradients have been backpropagated, the gradients can be clipped to keep them small. This can be done using `nn.utils.clip_grad_norm` like so:
338 |
339 |
340 | ```python
341 | linear.zero_grad()
342 | x = Variable( torch.randn(8,5) )
343 | y = linear(x)
344 | grads = Variable( torch.ones(8,2) )
345 | y.backward(grads)
346 | ```
347 |
348 |
349 | ```python
350 | [p.grad.norm() for p in linear.parameters()]
351 | ```
352 |
353 |
354 |
355 |
356 | [Variable containing:
357 | 8.7584
358 | [torch.FloatTensor of size 1], Variable containing:
359 | 11.3137
360 | [torch.FloatTensor of size 1]]
361 |
362 |
363 |
364 | #### `nn.utils.clip_grad_norm` returns total norm of parameter gradients _before_ they are clipped
365 |
366 |
367 | ```python
368 | nn.utils.clip_grad_norm(linear.parameters(), 3)
369 | ```
370 |
371 |
372 |
373 |
374 | 14.30769861897899
375 |
376 |
377 |
378 | Let's verify this:
379 |
380 |
381 | ```python
382 | (8.7584**2 + 11.3137**2)**0.5 # should be the same as the value returned by the clip_grad_norm above
383 | ```
384 |
385 |
386 |
387 |
388 | 14.307668512025291
389 |
390 |
391 |
392 | The total norm of the gradients is (approximately) 3 now:
393 |
394 |
395 | ```python
396 | [p.grad.norm() for p in linear.parameters()]
397 | ```
398 |
399 |
400 |
401 |
402 | [Variable containing:
403 | 1.8364
404 | [torch.FloatTensor of size 1], Variable containing:
405 | 2.3722
406 | [torch.FloatTensor of size 1]]
407 |
408 |
409 |
410 |
411 | ```python
412 | (1.8364**2 + 2.3722**2)**0.5
413 | ```
414 |
415 |
416 |
417 |
418 | 2.999949632910526
419 |
420 |
421 |
422 | ### Using `detach()`
423 |
424 | Remember that a PyTorch `Variable` carries the history of the graph it is part of: the intermediate values and the backward operations needed for backpropagation. In the example below, `y` requires a gradient and the first operation it will apply during backprop is `AddmmBackward`:
425 |
426 |
427 | ```python
428 | linear.zero_grad()
429 |
430 | x = Variable(torch.randn(8,5))
431 | y = linear(x)
432 | print(y.requires_grad, y.grad_fn)
433 | ```
434 |
435 | True
436 |
437 |
438 | When we create a new Variable, say `y_new` from `y` having the same data as `y`, `y_new` does not get the graph-related history from `y`. It only gets the data from `y`. Hence, in the example below, `y_new` does not require a gradient and its `grad_fn` is `None`:
439 |
440 |
441 | ```python
442 | y_new = Variable(y.data)
443 | print(y_new.requires_grad, y_new.grad_fn)
444 | ```
445 |
446 | False None
447 |
448 |
449 | You can say that `y_new` is _detached_ from the graph. When a backprop operation encounters `y_new`, it will see that its `grad_fn` is `None` and it will not continue backpropagating along the path where `y_new` was encountered.
450 |
451 | One can achieve the same effect without defining a new variable; just call the `detach()` method. There's also a `detach_()` method for in-place operation.
452 |
453 | Let's see both in action:
454 |
455 |
456 | ```python
457 | y_detached = y.detach()
458 | print(y_detached.requires_grad, y_detached.grad_fn)
459 | ```
460 |
461 | False None
462 |
463 |
464 |
465 | ```python
466 | print(y.requires_grad, y.grad_fn)
467 | y.detach_()
468 | print(y.requires_grad, y.grad_fn)
469 |
470 | ```
471 |
472 | True
473 | False None
474 |
475 |
476 | Let's understand it in greater detail.
477 |
478 | ### Another example of `detach`
479 |
480 | Consider two serial operations on a single fully connected layer. Notice that `y1` is detached _before_ `y2` is calculated. Thus, backpropagation from `y2` will stop at `y1`. The gradients of the linear layer are calculated as if `y1` was a leaf variable. The gradient values can be calculated in a fashion similar to what we did in [our simple example](#1-Simple-example)
481 |
482 | 
483 |
484 |
485 | ```python
486 | linear = nn.Linear(4,4)
487 | y0 = Variable( torch.randn(10,4) )
488 | y1 = linear(y0)
489 | y1.detach_()
490 | y2 = linear(y1)
491 | ```
492 |
493 |
494 | ```python
495 | [p.grad for p in linear.parameters()]
496 | ```
497 |
498 |
499 |
500 |
501 | [None, None]
502 |
503 |
504 |
505 |
506 | ```python
507 | y2.backward( torch.ones(10,4) )
508 | [p.grad for p in linear.parameters()]
509 | ```
510 |
511 |
512 |
513 |
514 | [Variable containing:
515 | 0.7047 -1.9666 4.8305 -0.1647
516 | 0.7047 -1.9666 4.8305 -0.1647
517 | 0.7047 -1.9666 4.8305 -0.1647
518 | 0.7047 -1.9666 4.8305 -0.1647
519 | [torch.FloatTensor of size 4x4], Variable containing:
520 | 10
521 | 10
522 | 10
523 | 10
524 | [torch.FloatTensor of size 4]]
525 |
526 |
527 |
528 | Note how each row of the weight gradient is simply the sum of the inputs (here `y1`) over the batch:
529 |
530 |
531 | ```python
532 | y1.sum(0)
533 | ```
534 |
535 |
536 |
537 |
538 | Variable containing:
539 | 0.7047
540 | -1.9666
541 | 4.8305
542 | -0.1647
543 | [torch.FloatTensor of size 4]
544 |
545 |
546 |
547 | ### Not detaching `y1`:
548 |
549 | In this case, backpropagation will continue beyond `y1`:
550 |
551 | 
552 |
553 | Let's verify this. For sake of simplicity, let's just take a look at how bias gradients are calculated.
554 |
555 | When gradients from `y2` are backpropagated, similar to the calculations we did in [our simple example](#1-Simple-example), we have:
556 |
557 | > $\frac{\partial Loss}{\partial bias} = 10$
558 |
559 | When gradients from `y1` are backpropagated through the first application of `linear`, the gradient reaching a neuron $y1_i$ for a single sample is:
560 | 
561 | > $\frac{\partial Loss}{\partial y1_i} = \sum_{j=1}^4 \frac{\partial Loss}{\partial y2_j} \cdot w_{ji} = \sum_{j=1}^4 w_{ji}$
562 | 
563 | and this pass adds to the bias gradient, once per sample (i.e. 10 times):
564 | 
565 | > $\frac{\partial Loss}{\partial b_i} = \frac{\partial Loss}{\partial y1_i} = \sum_{j=1}^4 w_{ji}$
566 | 
567 | Hence, the final bias gradient for component $i$ is equal to:
568 | 
569 | $10 + 10\sum_{j=1}^4 w_{ji}$
570 |
571 | Let's verify this:
572 |
573 |
574 | ```python
575 | y1 = linear(y0)
576 | y2 = linear(y1)
577 | linear.zero_grad()
578 | y2.backward( torch.ones(10,4) )
579 | ```
580 |
581 |
582 | ```python
583 | linear.bias.grad
584 | ```
585 |
586 |
587 |
588 |
589 | Variable containing:
590 | 4.5625
591 | 4.7037
592 | 14.1218
593 | 5.4621
594 | [torch.FloatTensor of size 4]
595 |
596 |
597 |
598 | Now let's calculate $10 + 10\sum_{j=1}^4 w_{ji}$ for each component $i$, i.e. `10 + linear.weight.sum(0) * 10`. The result should be (and is) equal to the bias gradients shown above:
599 |
600 |
601 | ```python
602 | 10 + (linear.weight.sum(0) * 10)
603 | ```
604 |
605 |
606 |
607 |
608 | Variable containing:
609 | 4.5625
610 | 4.7037
611 | 14.1218
612 | 5.4621
613 | [torch.FloatTensor of size 4]
614 |
615 |
616 |
617 | #### When is `detach()` used?
618 |
619 | It can be useful when you don't want to backpropagate beyond a few steps in time, for example, truncated backpropagation through time when training language models.
620 |
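Here is a minimal, hypothetical sketch of that idea (truncated backpropagation through time): the hidden state is detached between chunks, so gradients never flow further back than the current chunk. All names and shapes below are made up for illustration.

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=5, hidden_size=8, batch_first=True)
readout = nn.Linear(8, 1)
optimizer = torch.optim.SGD(list(rnn.parameters()) + list(readout.parameters()), lr=0.01)

hidden = torch.zeros(1, 4, 8)        # (num_layers, batch, hidden_size)
for step in range(10):
    chunk = torch.randn(4, 20, 5)    # a chunk of 20 time steps for a batch of 4 (dummy data)
    target = torch.randn(4, 20, 1)

    hidden = hidden.detach()         # cut the graph here: backprop stops at the chunk boundary
    output, hidden = rnn(chunk, hidden)
    loss = nn.MSELoss()(readout(output), target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```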
--------------------------------------------------------------------------------
/markdowns/basics.md:
--------------------------------------------------------------------------------
1 |
2 | ### Solve the same problem in different ways
3 |
4 | __Problem__: Find square root of `5`.
5 |
6 | __Actual solutions__: `±2.2360679775`
7 |
8 | ### 1. Start with a guess and improve the guess using gradient descent on loss
9 |
10 |
11 | ```python
12 | import torch
13 | ```
14 |
15 |
16 | ```python
17 | guess = torch.tensor(4.0, requires_grad=True)
18 | learning_rate = 0.01
19 | ```
20 |
21 |
22 | ```python
23 | for _ in range(20):
24 | loss = torch.pow(5 - torch.pow(guess, 2), 2)
25 | if guess.grad is not None:
26 | guess.grad.data.zero_()
27 | loss.backward()
28 | guess.data = guess.data - learning_rate * guess.grad.data
29 | print(guess.item())
30 | ```
31 |
32 | 2.240000009536743
33 | 2.2384231090545654
34 | 2.2374794483184814
35 | 2.2369143962860107
36 | 2.2365756034851074
37 | 2.236372470855713
38 | 2.236250638961792
39 | 2.236177682876587
40 | 2.2361338138580322
41 | 2.236107587814331
42 | 2.2360918521881104
43 | 2.2360823154449463
44 | 2.236076593399048
45 | 2.2360732555389404
46 | 2.2360711097717285
47 | 2.236069917678833
48 | 2.2360692024230957
49 | 2.2360687255859375
50 | 2.2360684871673584
51 | 2.2360682487487793
52 |
53 |
54 | ### 2. Parameterize the guess: `shift` the guess
55 |
56 | The idea here is that you do not change the input parameter. Instead, you come up with some variables which interact with the input to create a guess. The variables are then updated using gradient descent. These variables are also called `parameters`.
57 |
58 | Here we use a parameter called `shift`.
59 |
60 | __This is an important shift (no pun intended) in how you think about getting the right answer__:
61 | > __Don't update the guess. Instead, update the variables that interact with the guess. A similar line of thought is employed in the reparameterization trick used in variational inference on neural networks.__
62 |
63 |
64 | ```python
65 | input_val = torch.tensor(4.0)
66 | shift = torch.tensor(1.0, requires_grad=True)
67 | ```
68 |
69 |
70 | ```python
71 | for _ in range(30):
72 | guess = input_val + shift
73 | loss = torch.pow(5 - torch.pow(guess, 2), 2)
74 | if shift.grad is not None:
75 | shift.grad.data.zero_()
76 | loss.backward()
77 | shift.data = shift.data - learning_rate * shift.grad.data
78 | print(guess.item())
79 | ```
80 |
81 | 5.0
82 | 1.0
83 | 1.1600000858306885
84 | 1.3295643329620361
85 | 1.5014641284942627
86 | 1.6663613319396973
87 | 1.8145501613616943
88 | 1.9384772777557373
89 | 2.034804582595825
90 | 2.104766845703125
91 | 2.152751922607422
92 | 2.184238910675049
93 | 2.2042551040649414
94 | 2.216710090637207
95 | 2.2243528366088867
96 | 2.2290022373199463
97 | 2.2318150997161865
98 | 2.233511447906494
99 | 2.234532356262207
100 | 2.2351458072662354
101 | 2.2355144023895264
102 | 2.2357358932495117
103 | 2.235868453979492
104 | 2.235948324203491
105 | 2.2359962463378906
106 | 2.236024856567383
107 | 2.236042022705078
108 | 2.2360525131225586
109 | 2.2360587120056152
110 | 2.236062526702881
111 |
112 |
113 | ### 2.1 Parameterize the guess: `shift` and `scale` the guess
114 |
115 |
116 | ```python
117 | input_val = torch.tensor(4.0)
118 | shift = torch.tensor(1.0, requires_grad=True)
119 | scale = torch.tensor(1.0, requires_grad=True)
120 | ```
121 |
122 | The learning rate of 0.01 is pretty high for this setup, so we reduce it, searching on a log scale. After a couple of runs, we settle on a rate of `0.0005`.
123 |
124 |
125 | ```python
126 | learning_rate = 0.0005
127 | for _ in range(30):
128 | guess = input_val*scale + shift
129 | loss = torch.pow(5 - torch.pow(guess, 2), 2)
130 | if shift.grad is not None:
131 | shift.grad.data.zero_()
132 | if scale.grad is not None:
133 | scale.grad.data.zero_()
134 | loss.backward()
135 | shift.data = shift.data - learning_rate * shift.grad.data
136 | scale.data = scale.data - learning_rate * scale.grad.data
137 | print(guess.item())
138 | ```
139 |
140 | 5.0
141 | 1.5999999046325684
142 | 1.7327359914779663
143 | 1.8504221439361572
144 | 1.9495712518692017
145 | 2.0290589332580566
146 | 2.0899698734283447
147 | 2.134881019592285
148 | 2.1669845581054688
149 | 2.1893954277038574
150 | 2.204770803451538
151 | 2.2151894569396973
152 | 2.222188949584961
153 | 2.2268640995025635
154 | 2.2299740314483643
155 | 2.2320375442504883
156 | 2.2334041595458984
157 | 2.2343082427978516
158 | 2.234905958175659
159 | 2.2353007793426514
160 | 2.2355613708496094
161 | 2.2357337474823
162 | 2.235847234725952
163 | 2.235922336578369
164 | 2.2359719276428223
165 | 2.23600435256958
166 | 2.2360260486602783
167 | 2.2360403537750244
168 | 2.2360496520996094
169 | 2.236055850982666
170 |
171 |
172 | ### 2.2 Note that the loss function has two minima: with a little bit of tweaking, a relatively higher learning rate pushes it to the other minimum.
173 |
174 | Another interesting thing to note is that there are two square roots of 5: `2.2360` and `-2.2360`. If we increase the learning rate to `0.001`, gradient descent is pushed into the other minimum. The solution converges to `-2.2360680103302`.
175 |
176 |
177 | ```python
178 | input_val = torch.tensor(4.0)
179 | shift = torch.tensor(1.0, requires_grad=True)
180 | scale = torch.tensor(1.0, requires_grad=True)
181 |
182 | learning_rate = 0.001
183 | for _ in range(30):
184 | guess = input_val*scale + shift
185 | loss = torch.pow(5 - torch.pow(guess, 2), 2)
186 | if shift.grad is not None:
187 | shift.grad.data.zero_()
188 | if scale.grad is not None:
189 | scale.grad.data.zero_()
190 | loss.backward()
191 | shift.data = shift.data - learning_rate * shift.grad.data
192 | scale.data = scale.data - learning_rate * scale.grad.data
193 | print(guess.item())
194 | ```
195 |
196 | 5.0
197 | -1.8000000715255737
198 | -2.0154242515563965
199 | -2.1439850330352783
200 | -2.202786445617676
201 | -2.2249152660369873
202 | -2.2324423789978027
203 | -2.2349019050598145
204 | -2.235694408416748
205 | -2.235948085784912
206 | -2.236029624938965
207 | -2.236055850982666
208 | -2.2360641956329346
209 | -2.2360668182373047
210 | -2.236067533493042
211 | -2.236067771911621
212 | -2.2360680103302
213 | -2.2360680103302
214 | -2.2360680103302
215 | -2.2360680103302
216 | -2.2360680103302
217 | -2.2360680103302
218 | -2.2360680103302
219 | -2.2360680103302
220 | -2.2360680103302
221 | -2.2360680103302
222 | -2.2360680103302
223 | -2.2360680103302
224 | -2.2360680103302
225 | -2.2360680103302
226 |
227 |
228 | ### 3. Replacing `scale` and `shift` with a neuron
229 |
230 |
231 | ```python
232 | neuron = torch.nn.Linear(1,1)
233 | input_val = torch.tensor(4.0).view(1,-1) # input to a linear layer should be in the form [batch_size, input_size]
234 |
235 | learning_rate = 0.001
236 | for _ in range(30):
237 | guess = neuron(input_val)
238 | loss = torch.pow(5 - torch.pow(guess, 2), 2)
239 | neuron.zero_grad()
240 | loss.backward()
241 | for param in neuron.parameters(): # parameters of a model can be iterated through easily using the `.parameters()` method
242 | param.data = param.data - learning_rate * param.grad.data
243 | print(guess.item())
244 | ```
245 |
246 | 0.2342168092727661
247 | 0.3129768371582031
248 | 0.41730424761772156
249 | 0.5542460680007935
250 | 0.7311122417449951
251 | 0.9531161785125732
252 | 1.2182986736297607
253 | 1.5095584392547607
254 | 1.7888929843902588
255 | 2.0078368186950684
256 | 2.1400814056396484
257 | 2.201209545135498
258 | 2.2243618965148926
259 | 2.232259511947632
260 | 2.234842538833618
261 | 2.23567533493042
262 | 2.2359421253204346
263 | 2.236027717590332
264 | 2.2360548973083496
265 | 2.2360639572143555
266 | 2.2360665798187256
267 | 2.236067533493042
268 | 2.236067771911621
269 | 2.2360680103302
270 | 2.2360680103302
271 | 2.2360680103302
272 | 2.2360680103302
273 | 2.2360680103302
274 | 2.2360680103302
275 | 2.2360680103302
276 |
277 |
278 | ### 4. Using the optimizer to do parameter updates for you
279 |
280 | We tell the optimizer two things:
281 | a) which parameters need to be updated
282 | b) what the learning rate is
283 |
284 | Then it does all the hard work for you:
285 |
286 |
287 | ```python
288 | neuron = torch.nn.Linear(1,1)
289 | optimizer = torch.optim.SGD(neuron.parameters(), lr=0.001)
290 |
291 | input_val = torch.tensor(4.0).view(1,-1) # input to a linear layer should be in the form [batch_size, input_size]
292 |
293 | learning_rate = 0.001
294 | for _ in range(30):
295 | guess = neuron(input_val)
296 | loss = torch.pow(5 - torch.pow(guess, 2), 2)
297 | neuron.zero_grad()
298 | loss.backward()
299 | optimizer.step()
300 | print(guess.item())
301 | ```
302 |
303 | 1.4186859130859375
304 | 1.7068755626678467
305 | 1.949059247970581
306 | 2.108257293701172
307 | 2.187859058380127
308 | 2.2195885181427
309 | 2.230670928955078
310 | 2.234327793121338
311 | 2.2355096340179443
312 | 2.235889196395874
313 | 2.236010789871216
314 | 2.2360496520996094
315 | 2.2360620498657227
316 | 2.2360661029815674
317 | 2.236067295074463
318 | 2.236067771911621
319 | 2.236067771911621
320 | 2.2360680103302
321 | 2.2360680103302
322 | 2.2360680103302
323 | 2.2360680103302
324 | 2.2360680103302
325 | 2.2360680103302
326 | 2.2360680103302
327 | 2.2360680103302
328 | 2.2360680103302
329 | 2.2360680103302
330 | 2.2360680103302
331 | 2.2360680103302
332 | 2.2360680103302
333 |
334 |
335 | ### 5. Using the loss function to calculate the loss automatically
336 |
337 | In some cases calculating the loss may not be as trivial as it is in this case. In those scenarios, we can use PyTorch's in-built loss functions.
338 |
339 | Here we will be using `MSELoss`:
340 |
341 |
342 | ```python
343 | neuron = torch.nn.Linear(1,1)
344 | optimizer = torch.optim.SGD(neuron.parameters(), lr=0.001)
345 | loss_function = torch.nn.MSELoss()
346 |
347 | input_val = torch.tensor(4.0).view(1,-1) # input to a linear layer should be in the form [batch_size, input_size]
348 |
349 | learning_rate = 0.001
350 | for _ in range(30):
351 | guess = neuron(input_val)
352 | predicted_output = torch.pow(guess, 2)
353 | actual_output = torch.tensor(5.0).view(1,-1)
354 | loss = loss_function(predicted_output, actual_output)
355 | neuron.zero_grad()
356 | loss.backward()
357 | optimizer.step()
358 | print(guess.item())
359 | ```
360 |
361 | 3.8363184928894043
362 | 1.3013595342636108
363 | 1.593956470489502
364 | 1.860517978668213
365 | 2.0551581382751465
366 | 2.1636502742767334
367 | 2.2105278968811035
368 | 2.2275986671447754
369 | 2.2333250045776367
370 | 2.235186815261841
371 | 2.235785722732544
372 | 2.2359776496887207
373 | 2.236039161682129
374 | 2.2360587120056152
375 | 2.236064910888672
376 | 2.236067295074463
377 | 2.236067533493042
378 | 2.236067771911621
379 | 2.2360680103302
380 | 2.2360680103302
381 | 2.2360680103302
382 | 2.2360680103302
383 | 2.2360680103302
384 | 2.2360680103302
385 | 2.2360680103302
386 | 2.2360680103302
387 | 2.2360680103302
388 | 2.2360680103302
389 | 2.2360680103302
390 | 2.2360680103302
391 |
392 |
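As a side note: since the optimizer was constructed from `neuron.parameters()`, `optimizer.zero_grad()` clears exactly the same gradients as `neuron.zero_grad()`, and the training step is more commonly written as:

```python
optimizer.zero_grad()   # same effect as neuron.zero_grad() here
loss.backward()
optimizer.step()
```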
--------------------------------------------------------------------------------
/markdowns/batchnorm.md:
--------------------------------------------------------------------------------
1 | [Introduction](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#introduction)
2 |
3 | [An intuitive example of internal covariate shift](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#an-intuitive-example-of-internal-covariate-shift)
4 |
5 | [Solution to internal covariate shift](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#solution-to-internal-covariate-shift)
6 |
7 | [Batch Normalization may not always be optimal for learning](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-may-not-always-be-optimal-for-learning)
8 |
9 | [Batch Normalization with backpropagation](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-with-backpropagation)
10 |
11 | [Batch Normalization for 2D data](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-for-2d-data)
12 |
13 | [Batch Normalization for 3D inputs](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-for-3d-inputs)
14 |
15 | [Batch Normalization for images (or any 4D input)](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-for-images-or-any-4d-input)
16 |
17 | ### Introduction
18 |
19 | Let's start with a discussion of what problem were the authors of the [original paper](https://arxiv.org/abs/1502.03167) dealing with when they came up with the idea of Batch Normalization:
20 |
21 | > [...] the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change. This slows down the training by requiring lower learning rates and careful parameter initialization, and makes it notoriously hard to train models with saturating nonlinearities. We refer to this phenomenon as internal covariate shift, and address the problem by normalizing layer inputs.
22 |
23 | ### An intuitive example of internal covariate shift
24 | Let's first talk about covariate shift. A covariate shift occurs when:
25 |
26 | i) given the same observation X = x, the conditional distributions of Y in training and test sets are the same BUT
27 |
28 | ii) marginal distribution of X in training set is different from marginal distribution of X in test set.
29 |
30 | In other words:
31 |
Ptrain[Y|X=x] = Ptest[Y|X=x] BUT
Ptrain[X] ≠ Ptest[X]
32 |
33 | Now let's talk about the second layer inside a network with just two layers. The layer takes in an input X (the output of the first layer) and spits out an output. The output depends on X as well as the layer's own parameters θ. It can be thought of as a single-layer neural network trying to learn P[Y|X,θ]. Now if the distribution of X keeps changing during training (because the parameters of the first layer keep changing), the layer has to keep re-learning P[Y|X,θ] on a shifting input distribution. This is the internal covariate shift problem. Learning may be neither easy nor fast if the network finds itself spending more time in the saturated extremities of non-linear functions like sigmoid or tanh. Quoting from the paper:
34 |
35 | > Consider a network computing
36 | `l = F2(F1(u, Θ1), Θ2)`
37 | where F1 and F2 are arbitrary transformations, and the parameters Θ1,Θ2 are to be learned so as to minimize the loss l. Learning Θ2 can be viewed as if the inputs x = F1 (u, Θ1 ) are fed into the sub-network
38 | l = F2(x,Θ2). For example, a gradient descent step
39 | Θ2 <- Θ2 - (α/m) Σᵢ ∂F2(xᵢ,Θ2)/∂Θ2
40 | (for batch size m and learning rate α) is exactly equivalent to that for a stand-alone network F2 with input x. Therefore, the input distribution properties that make training more efficient – such as having the same distribution between the training and test data – apply to training the sub-network as well. As such it is advantageous for the distribution of x to remain fixed over time. Then, Θ2 does not have to readjust to compensate for the change in the distribution of x.
41 |
42 | So the solution is to normalize the inputs.
43 |
44 |
45 | ### Solution to internal covariate shift
46 |
47 | If a neuron is able to see the entire range of values across the entire training set that it is going to get as subsequent inputs, training can be made faster by normalizing the distribution (i.e. making its mean zero and variance one). The authors call it _input whitening_:
48 |
49 | > It has been long known (LeCun et al., 1998b; Wiesler & Ney, 2011) that the network training converges faster if its inputs are whitened – i.e., linearly transformed to have zero means and unit variances, and decorrelated. As each layer observes the inputs produced by the layers below, it would be advantageous to achieve the same whitening of the inputs of each layer. By whitening the inputs to each layer, we would take a step towards achieving the fixed distributions of inputs that would remove the ill effects of the internal covariate shift.
50 |
51 | However, for a large training set, it is not computationally feasible to normalize inputs for each neuron. Hence, the normalization is carried out on a per-batch basis:
52 |
53 | > In the batch setting where each training step is based on the entire training set, we would use the whole set to normalize activations. However, this is impractical when using stochastic optimization. Therefore, we make the second simplification: since we use mini-batches in stochastic gradient training, each mini-batch produces estimates of the mean and variance of each activation. This way, the statistics used for normalization can fully participate in the gradient backpropagation.
54 |
55 | ### Batch Normalization may not always be optimal for learning
56 |
57 | Batch normalization may lead to, say, inputs always lying in the linear range of the sigmoid function. Hence the normalized inputs are scaled and shifted by parameters γ and β. Optimal values of these parameters can be __learnt__ using backpropagation.
58 |
59 | > Note that simply normalizing each input of a layer may change what the layer can represent. For instance, normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity. To address this, we make sure that the transformation inserted in the network can represent the identity transform. To accomplish this, we introduce, for each activation x(k) , a pair of parameters γ(k), β(k), which scale and shift the normalized value:
60 | y(k) = γ(k)xnormalized(k) + β(k).
61 |
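In PyTorch, γ and β are the `weight` and `bias` attributes of a BatchNorm layer, created as learnable parameters when `affine=True` (the default). A quick look, assuming the imports from the code sections below:

```python
bn = nn.BatchNorm1d(100)                 # affine=True by default
print(bn.weight.shape, bn.bias.shape)    # one gamma and one beta per feature
print(bn.weight.requires_grad)           # True: learnt via backpropagation
```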
62 | ### Batch Normalization with backpropagation
63 |
64 | If backpropagation ignores batch normalization, it may undo the modification due to batch normalization. Hence, backprop needs to incorporate batch normalization while calculating gradients. [This link](https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html) provides a nice explanation of how to derive backpropagation gradients for batch normalization. Let's see it in action now:
65 |
66 | ### Batch Normalization for 2D data
67 |
68 | Let's start with 1D normalization.
69 | ```python
70 | import numpy as np
71 | import torch
72 | import torch.nn as nn
73 | from torch.autograd import Variable
74 | ```
75 |
76 | Let's define X and calculate its mean and variance. Note that X has size (20,100), which means it has 20 samples, each with 100 features (or dimensions). For normalizing, we need to look across all samples. In other words, for each dimension we normalize across 20 values (each value coming from a different sample).
77 |
78 | ```python
79 | X = torch.randn(20,100) * 5 + 10
80 | X = Variable(X)
81 |
82 | mu = torch.mean(X[:,1])
83 | var_ = torch.var(X[:,1], unbiased=False)
84 | sigma = torch.sqrt(var_ + 1e-5)
85 | x = (X[:,1] - mu)/sigma
86 | ```
87 |
88 | Note that in the line above we set `unbiased = False` while calculating `var_`. This is to prevent [Bessel's correction](https://en.wikipedia.org/wiki/Bessel's_correction). In other words, we want to divide by N and not N-1 while calculating the variance. `x` now matches the corresponding column of the batch-normalized output. Also note that in the code below we set `affine = False`. This is to prevent the creation of the parameters γ(k), β(k), which would scale and shift the normalized data again. In practice (and by default), `affine` is set to `True` so that these parameters are learnt during training.
89 |
90 | ```python
91 | B = nn.BatchNorm1d(100, affine=False)
92 | y = B(X)
93 | print(x.data / y[:,1].data)
94 | ```
95 |
96 | Output:
97 | ```
98 | 1.0000
99 | 1.0000
100 | 1.0000
101 | ...
102 | ```
103 | This also works for 3D input.
104 |
105 | ### Batch Normalization for 3D inputs
106 |
107 | Let's define some 3D data with mean 4 and variance 4:
108 | ```python
109 | X3 = torch.randn(150,20,100) * 2 + 4
110 | X3 = Variable(X3)
111 | B2 = nn.BatchNorm1d(20)
112 | ```
113 |
114 | Note that here we did not set `affine = False`. Instead, we can manually set those values to what we want. To preserve normalization, we want γ(k) = 1 and β(k) = 0. These values are stored in parameters `weight` and `bias` of the BatchNorm variable.
115 |
116 | ```python
117 | B2.weight.data.fill_(1)
118 | B2.bias.data.fill_(0)
119 | Y = B2(X3)
120 |
121 | #Manual calculation
122 | mu = X3[:,0,:].mean()
123 | sigma = torch.sqrt(torch.var(X3[:,0,:], unbiased=False) + 1e-5)
124 | X_normalized = (X3[:,0,:] - mu)/sigma
125 | ```
126 |
127 | In the above example, `X_normalized` has the same values as `Y[:,0,:]`.
128 |
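A quick check of that claim (reusing `X_normalized` and `Y` from above):

```python
# the difference is only floating-point noise
print((X_normalized.data - Y[:,0,:].data).abs().max())
```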
129 | ### Batch Normalization for images (or any 4D input)
130 |
131 | A batch of RGB images has four dimensions: (B,C,X,Y) or (B,X,Y,C), where B is the batch size, C is the number of channels, and X, Y are spatial locations. In the 2D example above, we only had (B,N), where N was the number of features, so it was pretty straightforward to figure out the axis of normalization.
132 |
133 | Along which axis do we normalize here? Normalizing all values across all batch samples at once is not the solution, because each batch sample contains several channels and each channel behaves like a separate feature with its own distribution.
134 | 
135 | __Hint__: We want to choose an axis that enables a neuron to look through all samples in the batch.
136 | 
137 | Batch normalization for images is done along the channel axis: statistics are computed per channel, across the batch and the spatial dimensions.
138 |
139 | ```python
140 | X = torch.randn(5,25,100,100) * 2 + 4
141 | X = Variable(X)
142 | B = nn.BatchNorm2d(25, affine=False)
143 | Y = B(X)
144 | ```
145 |
146 | Here `Y[:,i,:,:]` is the same as
147 | `((X[:,i,:,:] - X[:,i,:,:].data.mean())/((X[:,i,:,:].data.var(unbiased=False) + 1e-5)**0.5))` for all valid values of `i`.
148 |
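A quick way to verify this for one channel (a sketch reusing `X` and `Y` from the block above):

```python
i = 7   # any channel index in [0, 25)
xi = X[:,i,:,:].data
manual = (xi - xi.mean()) / ((xi.var(unbiased=False) + 1e-5) ** 0.5)
print((manual - Y[:,i,:,:].data).abs().max())   # ~0, i.e. only floating-point noise
```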
--------------------------------------------------------------------------------
/markdowns/dropout.md:
--------------------------------------------------------------------------------
1 | [1. Theory, motivation and food for thought](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#1-theory-motivation-and-food-for-thought)
[2. Dropout in action](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#2-dropout-in-action)
2 | * [Importing items we need](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#importing-items-we-need)
3 | * [Create a dropout layer with p = 0.6](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#create-a-dropout-layer-with-p--06)
4 | * [Create an input with every element equal to 1](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#create-an-input-with-every-element-equal-to-1)
5 | * [Let's see what the output looks like](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#lets-see-what-the-output-looks-like)
6 |
7 | [2.1 Why?](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#21-why)
8 | * [Rule for scaling the input](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#rule-for-scaling-the-input)
9 | * [Dropout + Linear Layer](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#dropout--linear-layer)
10 | * [Create input, dropout and linear layer](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#create-input-dropout-and-linear-layer)
11 | * [Let's look at the output of passing the input through dropout layer:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#lets-look-at-the-output-of-passing-the-input-through-dropout-layer)
12 |
13 | [2.2 Randomness of sampling subnetworks](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#22-randomness-of-sampling-subnetworks)
14 |
15 | [2.3 Dropout + Backpropagation](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#23-dropout--backpropagation)
16 | * [Calculate weight gradients for each training case individually and store in a variable. We use `p=0.9` here.](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#calculate-weight-gradients-for-each-training-case-individually-and-store-in-a-variable-we-use-p09-here)
17 | * [Calculate gradients per training case](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#calculate-gradients-per-training-case)
18 | * [Let's take a look at the average of `dWs`:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#lets-take-a-look-at-the-average-of-dws)
19 | * [Now let's take a look at the gradient of the whole batch. It should be the same as `dWs/n` above.](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#now-lets-take-a-look-at-the-gradient-of-the-whole-batch-it-should-be-the-same-as-dwsn-above)
20 |
21 | ### 1. Theory, motivation and food for thought
22 |
23 | The [original paper](http://www.jmlr.org/papers/volume15/srivastava14a.old/srivastava14a.pdf) on Dropout is fairly easy to understand and is one of the more interesting research papers I have read. Before I jump into the what and how of Dropout, I think it's important to appreciate _why_ it was needed and what motivated the people involved to come up with this idea.
24 |
25 | The below snippet from the original paper gives a good idea:
26 |
27 | > Combining several models is most
28 | helpful when the individual models are different from each other and in order to make
29 | neural net models different, they should either have different architectures or be trained
30 | on different data. Training many different architectures is hard because finding optimal
31 | hyperparameters for each architecture is a daunting task and training each large network
32 | requires a lot of computation. Moreover, large networks normally require large amounts of
33 | training data and there may not be enough data available to train different networks on
34 | different subsets of the data. Even if one was able to train many different large networks,
35 | using them all at test time is infeasible in applications where it is important to respond
36 | quickly.
37 |
38 | So what is Dropout?
39 |
40 | > The term “dropout” refers to dropping out units (hidden and
41 | visible) in a neural network. By dropping a unit out, we mean temporarily removing it from
42 | the network, along with all its incoming and outgoing connections, as shown in Figure 1.
43 | The choice of which units to drop is random.
44 |
45 | __Caveat__: The part `By dropping a unit out, we mean temporarily removing it from
46 | the network, along with all its incoming and outgoing connections` can be a bit misleading. While theoretically correct, it gives the impression that the network architecture changes randomly over time. But the implementation of Dropout does not change the network architecture at every iteration. Instead, the output of each node chosen to be dropped out is set to zero regardless of what the input was. This also results in the gradients of the weights attached to that node becoming zero during backpropagation (we will soon see this in action).
47 |
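In other words, during training (inverted) dropout amounts to multiplying the input element-wise by a random binary mask. Here is a minimal sketch of that idea, using nothing beyond plain `torch` operations:

```python
import torch

p = 0.6
x = torch.ones(5, 2)
mask = (torch.rand_like(x) > p).float()   # 0 = "dropped" node, 1 = kept node
out = x * mask / (1 - p)                  # dropped outputs become 0; kept ones are rescaled (see section 2.1)
```
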
48 | Before we jump into the code, I would recommend you read Section 2 of the original paper, which gives possible explanations for why Dropout works so well. Here's a snippet:
49 |
50 | > One possible explanation for the superiority of sexual reproduction is that, over the long
51 | term, the criterion for natural selection may not be individual fitness but rather mix-ability
52 | of genes. The ability of a set of genes to be able to work well with another random set of
53 | genes makes them more robust. Since a gene cannot rely on a large set of partners to be
54 | present at all times, it must learn to do something useful on its own or in collaboration with
55 | a small number of other genes. According to this theory, the role of sexual reproduction
58 | is not just to allow useful new genes to spread throughout the population, but also to
59 | facilitate this process by reducing complex co-adaptations that would reduce the chance of
60 | a new gene improving the fitness of an individual. Similarly, each hidden unit in a neural
61 | network trained with dropout must learn to work with a randomly chosen sample of other
62 | units. This should make each hidden unit more robust and drive it towards creating useful
63 | features on its own without relying on other hidden units to correct its mistakes. However,
64 | the hidden units within a layer will still learn to do different things from each other.
65 |
66 |
67 | ### 2. Dropout in action
68 |
69 | Think of Dropout as a layer. The output of the layer has the same size (or dimensions) as the input. However, a fraction `p` of the elements in the output is set to zero, while the remaining fraction `1-p` of the elements is identical (not really, but we will see why in a while) to the corresponding input. We can choose any value between 0 and 1 for `p`.
70 |
71 | #### Importing items we need
72 | ```python
73 | import torch
74 | import torch.nn as nn
75 | from torch.autograd import Variable
76 | ```
77 |
78 | #### Create a dropout layer with p = 0.6
79 | ```python
80 | p = 0.6
81 | do = nn.Dropout(p)
82 | ```
83 |
84 | #### Create an input with every element equal to 1
85 | ```python
86 | X = torch.ones(5,2)
87 | print(X)
88 | ```
89 | Output:
90 | ```
91 | 1 1
92 | 1 1
93 | 1 1
94 | 1 1
95 | 1 1
96 | [torch.FloatTensor of size 5x2]
97 | ```
98 |
99 | #### Let's see what the output looks like
100 | ```python
101 | do(X)
102 | ```
103 | Output:
104 | ```
105 | Variable containing:
106 | 2.5000 2.5000
107 | 0.0000 2.5000
108 | 2.5000 0.0000
109 | 0.0000 0.0000
110 | 2.5000 2.5000
111 | [torch.FloatTensor of size 5x2]
112 | ```
113 |
114 | ### 2.1 Why?
115 | We see that some of the nodes (__approximately__ a fraction 0.6 of them) have been set to zero in the output. But the remaining ones are not equal to 1.
116 |
117 | Why? Because we scaled the input values linearly. Why? Because we want the expected output of each node to be the same with and without dropout.
118 |
119 | Why? Because during testing or evaluation dropout is not used.
120 |
121 | Why? Because the overall learned network can be thought of as a _combination of several models_, where each _model_ is one of the subnetworks sampled by the dropout layer during training. Using dropout during testing or evaluation would mean we are still using a subset of the possible models rather than a combination of them, which is suboptimal.
122 |
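In PyTorch, "not using dropout at test time" simply corresponds to putting the module in evaluation mode. A quick check, reusing the `do` layer and the input `X` defined above:

```python
do.eval()        # evaluation mode: dropout becomes the identity function
print(do(X))     # all ones -- exactly the input, no zeroing and no scaling
do.train()       # switch back to training mode for the rest of the article
```
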
123 | #### Rule for scaling the input
124 | Let's consider a node `n` with value `x` that is subjected to a dropout layer 100 times. Its output will be `0` approximately `p*100` times and `x_out` approximately `(1-p)*100` times. Its expected output is:
125 |
126 | `p*0 + (1-p)*x_out` which is equal to `(1-p)*x_out`
127 |
128 | We want it to be equal to `x`. Hence `x_out = x/(1-p)`.
129 |
130 | In the example above, `p = 0.6`. Hence, the output is `1/(1-0.6) = 2.5`.
131 |
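We can check this scaling rule empirically by averaging the dropout output over many passes; the average should come back close to the all-ones input. A small sketch, reusing `do` and `X` from above:

```python
n_trials = 10000
avg = sum(do(X) for _ in range(n_trials)) / n_trials
print(avg)   # every entry is close to 1.0
```
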
132 | #### Dropout + Linear Layer
133 | Let's create a super small network that looks like this:
134 |
135 | ```
136 | Input Size: (5X10)
137 | ↓
138 | Dropout p = 0.9
139 | ↓
140 | Linear Layer Size: (10X5)
141 | ↓
142 | Output (5X5)
143 | ```
144 |
145 | #### Create input, dropout and linear layer
146 | ```python
147 | #input
148 | inputs = torch.ones(5,10)
149 | inputs = Variable(inputs)
150 |
151 | #dropout
152 | p = 0.9
153 | do = nn.Dropout(p)
154 |
155 | #linear
156 | fc = nn.Linear(10,5)
157 | ```
158 |
159 | #### Let's look at the output of passing the input through dropout layer:
160 |
161 | ```python
162 | out = do(inputs)
163 | print(out)
164 | ```
165 | Output:
166 | ```
167 | Variable containing:
168 | 0 0 0 0 0 0 0 0 0 0
169 | 0 0 0 0 0 0 0 0 0 0
170 | 0 0 0 0 0 10 0 0 0 0
171 | 0 0 0 10 0 0 0 0 0 0
172 | 0 0 0 0 0 0 0 0 0 0
173 | [torch.FloatTensor of size 5x10]
174 | ```
175 |
176 | Since `p = 0.9`, about 90% of the nodes have been set to 0 in the output. The remaining ones have been scaled by `1/(1-0.9) = 10`.
177 |
178 | If you run the above code again and again, the output will be different each time. The output I got by running the above code again is:
179 |
180 | ```
181 | Variable containing:
182 | 0 10 0 0 0 0 0 0 0 0
183 | 0 0 0 0 10 0 0 0 0 0
184 | 0 0 0 0 0 0 0 0 0 0
185 | 0 0 0 0 0 0 0 0 0 0
186 | 0 0 0 0 0 0 0 0 0 0
187 | [torch.FloatTensor of size 5x10]
188 | ```
189 |
190 | ### 2.2 Randomness of sampling subnetworks
191 |
192 | Thus, in a way, we are sampling random subnetworks by randomly shutting some of the nodes off. But _how random_ are they actually? We see that most of the rows tend to be the same (full of zeros). There's not much randomness going on there.
193 |
194 | Let's try setting `p=0.1` and try again:
195 |
196 | ```python
197 | p = 0.1
198 | do = nn.Dropout(p)
199 | out = do(inputs)
200 | print(out)
201 | ```
202 | Output:
203 | ```
204 | Variable containing:
205 | 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 0.0000 1.1111 1.1111
206 | 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 0.0000 1.1111
207 | 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111
208 | 1.1111 1.1111 1.1111 1.1111 0.0000 1.1111 1.1111 1.1111 1.1111 1.1111
209 | 1.1111 0.0000 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111
210 | [torch.FloatTensor of size 5x10]
211 | ```
212 |
213 | Run the above code a few times and you will see there's not much randomness going on here either. In this case, most rows tend to be full of the same value: `1.1111`.
214 |
215 | How do we get maximum randomness in our subnetworks? By setting `p=0.5`. If you are familiar with information theory, this is the value at which the entropy of the Bernoulli keep/drop decision is maximized.
216 |
217 | ```python
218 | p = 0.5
219 | do = nn.Dropout(p)
220 | out = do(inputs)
221 | print(out)
222 | ```
223 | Output:
224 | ```
225 | Variable containing:
226 | 2 0 2 2 2 2 0 2 2 2
227 | 2 2 2 2 0 2 0 2 2 2
228 | 0 0 2 2 2 2 2 0 2 2
229 | 0 2 2 2 2 2 2 0 0 0
230 | 0 2 2 0 2 2 2 0 0 0
231 | [torch.FloatTensor of size 5x10]
232 | ```
233 |
234 | It is common practice to use dropout with `p=0.5` on hidden layers while training a neural network.
235 |
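A quick numerical aside on the entropy claim above (not part of the original notebook): the entropy of the Bernoulli keep/drop decision, `-p*log(p) - (1-p)*log(1-p)`, peaks at `p=0.5`:

```python
import math

for p in [0.1, 0.3, 0.5, 0.7, 0.9]:
    entropy = -p * math.log(p) - (1 - p) * math.log(1 - p)
    print(p, round(entropy, 3))
# prints 0.325, 0.611, 0.693, 0.611, 0.325 -- maximal at p = 0.5
```
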
236 | ### 2.3 Dropout + Backpropagation
237 |
238 | Training a network with dropout layer is pretty straightforward. Quoting from the original paper again:
239 |
240 | > Dropout neural networks can be trained using stochastic gradient descent in a manner similar to standard neural nets. The only difference is that for each training case in a mini-batch, we sample a thinned network by dropping out units. Forward and backpropagation for that training case are done only on this thinned network. The gradients for each parameter are averaged over the training cases in each mini-batch. Any training case which does not use a parameter contributes a gradient of zero for that parameter.
241 |
242 | Here we put to use the linear layer we created in the last section.
243 |
244 | First, let's take a look at the weight and bias parameters of our linear layer:
245 | ```python
246 | print(fc.weight)
247 | print(fc.bias)
248 | ```
249 | Output:
250 | ```
251 | Parameter containing:
252 | 0.0738 0.3089 -0.0683 -0.2662 0.2217 -0.2005 0.0918 -0.0003 -0.0241 -0.0386
253 | -0.0407 -0.1743 0.0462 0.1506 0.2435 -0.2997 -0.1420 -0.1419 -0.1131 -0.2221
254 | -0.0406 0.2344 0.1432 -0.2777 -0.1128 0.0976 -0.1798 -0.0479 0.2498 -0.0814
255 | -0.0947 0.2826 -0.0856 0.2716 -0.1775 0.2035 -0.3161 -0.2716 0.0440 -0.1010
256 | 0.0102 -0.0008 -0.0904 0.2708 -0.0478 -0.1248 -0.0073 0.2026 -0.2273 -0.0355
257 | [torch.FloatTensor of size 5x10]
258 |
259 | Parameter containing:
260 | 0.1163
261 | -0.2310
262 | -0.0791
263 | 0.1346
264 | -0.2195
265 | [torch.FloatTensor of size 5]
266 | ```
267 |
268 | We will create a dummy target and use it to calculate cross entropy loss. We will then use the loss value to backpropagate and take a look at how the gradients look.
269 | ```python
270 | target = torch.LongTensor([0,4,2,1,4])
271 | target = Variable(target)
272 | print(target)
273 | ```
274 | Output:
275 | ```
276 | Variable containing:
277 | 0
278 | 4
279 | 2
280 | 1
281 | 4
282 | [torch.LongTensor of size 5]
283 | ```
284 |
285 | #### Calculate weight gradients for each training case individually and store in a variable. We use `p=0.9` here.
286 | ```python
287 | do = nn.Dropout(0.9)
288 | out = do(inputs)
289 | print(out)
290 | ```
291 | Output:
292 | ```
293 | Variable containing:
294 | 0 10 0 0 0 0 10 0 10 0
295 | 0 0 10 0 0 0 0 0 0 0
296 | 0 10 0 0 0 0 0 10 0 0
297 | 0 0 0 0 0 0 0 0 0 0
298 | 0 0 0 0 0 0 0 0 0 0
299 | [torch.FloatTensor of size 5x10]
300 | ```
301 |
302 | #### Calculate gradients per training case
303 | We initialize a zero tensor `dWs` and keep adding the gradients from each training case to it.
304 | ```python
305 | dWs = torch.zeros_like(fc.weight)
306 | for i in range(out.size(0)):
307 |     i_ = out[i].view(1,-1)   # i-th training case as a 1x10 mini-batch
308 |     t = target[i].view(1)    # its target as a 1-element tensor
309 |     o = fc(i_)
310 |     if fc.weight.grad is not None: fc.weight.grad.zero_()  # grad is None before the first backward call
311 |     loss = nn.CrossEntropyLoss()(o, t)
312 |     loss.backward()
313 |     dWs += fc.weight.grad    # accumulate the per-case weight gradient
314 | ```
315 |
316 | #### Let's take a look at the average of `dWs`:
317 | ```python
318 | n = out.size(0)   # number of training cases in the mini-batch (5 here)
319 | print(dWs/n)
320 | ```
321 | Output:
322 | ```
323 | Variable containing:
324 | 0.0000 0.6987 0.1744 0.0000 0.0000 0.0000 -0.5988 1.2975 -0.5988 0.0000
325 | 0.0000 0.0021 0.3873 0.0000 0.0000 0.0000 0.0003 0.0018 0.0003 0.0000
326 | 0.0000 -1.1258 1.1889 0.0000 0.0000 0.0000 0.5595 -1.6853 0.5595 0.0000
327 | 0.0000 0.1041 0.1495 0.0000 0.0000 0.0000 0.0367 0.0674 0.0367 0.0000
328 | 0.0000 0.3209 -1.9001 0.0000 0.0000 0.0000 0.0022 0.3187 0.0022 0.0000
329 | [torch.FloatTensor of size 5x10]
330 | ```
331 |
332 | #### Now let's take a look at the gradient of the whole batch. It should be the same as `dWs/n` above.
333 |
334 | ```python
335 | out = fc(out)
336 | criterion = nn.CrossEntropyLoss()
337 | loss = criterion(out, target)
338 | fc.weight.grad.zero_()
339 | loss.backward()
340 | print(fc.weight.grad)
341 | ```
342 | Output:
343 | ```
344 | Variable containing:
345 | 0.0000 0.6987 0.1744 0.0000 0.0000 0.0000 -0.5988 1.2975 -0.5988 0.0000
346 | 0.0000 0.0021 0.3873 0.0000 0.0000 0.0000 0.0003 0.0018 0.0003 0.0000
347 | 0.0000 -1.1258 1.1889 0.0000 0.0000 0.0000 0.5595 -1.6853 0.5595 0.0000
348 | 0.0000 0.1041 0.1495 0.0000 0.0000 0.0000 0.0367 0.0674 0.0367 0.0000
349 | 0.0000 0.3209 -1.9001 0.0000 0.0000 0.0000 0.0022 0.3187 0.0022 0.0000
350 | [torch.FloatTensor of size 5x10]
351 | ```
352 |
353 | We get the same output as `dWs/n`.
354 |
--------------------------------------------------------------------------------
/markdowns/einstein_summation.md:
--------------------------------------------------------------------------------
1 |
2 | ### Einstein summation
3 |
4 | Einstein summation is a convenient and intuitive notation for multiplying and/or adding tensors. The posts below do a pretty good job of explaining how it works. This notebook is more of a practice session to get used to applying the Einstein summation notation.
5 |
6 | - [Einstein Summation in Numpy](https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/)
7 | - [A basic introduction to NumPy's einsum](https://ajcr.net/Basic-guide-to-einsum/)
8 |
9 | The examples here have been taken from the tables in the second link above.
10 |
11 |
12 | ```python
13 | import numpy as np
14 | import torch
15 | ```
16 |
17 |
18 | ```python
19 | v = np.arange(6)
20 | # array([0, 1, 2, 3, 4, 5])
21 |
22 | w = np.arange(6) + 6
23 | # array([ 6, 7, 8, 9, 10, 11])
24 |
25 | A = np.arange(6).reshape(2,3)
26 | # array([[0, 1, 2],
27 | # [3, 4, 5]])
28 |
29 | B = (np.arange(6) + 6).reshape(3,2)
30 | # array([[ 6, 7],
31 | # [ 8, 9],
32 | # [10, 11]])
33 |
34 | C = np.arange(9).reshape(3,3)
35 | # array([[0, 1, 2],
36 | # [3, 4, 5],
37 | # [6, 7, 8]])
38 |
39 |
40 | v_torch = torch.from_numpy(v)
41 | w_torch = torch.from_numpy(w)
42 | A_torch = torch.from_numpy(A)
43 | B_torch = torch.from_numpy(B)
44 | C_torch = torch.from_numpy(C)
45 | ```
46 |
47 | ## 1. Vectors
48 |
49 | ### 1.1 Return a view of itself
50 |
51 |
52 | ```python
53 | np.einsum('i', v)
54 | ```
55 |
56 |
57 |
58 |
59 | array([0, 1, 2, 3, 4, 5])
60 |
61 |
62 |
63 |
64 | ```python
65 | torch.einsum('i', v_torch)
66 | ```
67 |
68 |
69 |
70 |
71 | tensor([0, 1, 2, 3, 4, 5])
72 |
73 |
74 |
75 | ### 1.2 Sum up elements of a vector
76 |
77 |
78 | ```python
79 | np.einsum('i->', v)
80 | ```
81 |
82 |
83 |
84 |
85 | 15
86 |
87 |
88 |
89 |
90 | ```python
91 | torch.einsum('i->', v_torch)
92 | ```
93 |
94 |
95 |
96 |
97 | tensor(15)
98 |
99 |
100 |
101 |
102 | ```python
103 | v.sum(), v_torch.sum()
104 | ```
105 |
106 |
107 |
108 |
109 | (15, tensor(15))
110 |
111 |
112 |
113 | ### 1.3 Element-wise operations
114 |
115 |
116 | ```python
117 | np.einsum('i,i->i', v, w)
118 | ```
119 |
120 |
121 |
122 |
123 | array([ 0, 7, 16, 27, 40, 55])
124 |
125 |
126 |
127 |
128 | ```python
129 | v * w
130 | ```
131 |
132 |
133 |
134 |
135 | array([ 0, 7, 16, 27, 40, 55])
136 |
137 |
138 |
139 |
140 | ```python
141 | torch.einsum('i,i->i', v_torch, w_torch)
142 | ```
143 |
144 |
145 |
146 |
147 | tensor([ 0, 7, 16, 27, 40, 55])
148 |
149 |
150 |
151 | #### 1.3.1 One can even multiply (element-wise) three or more vectors in the same manner
152 |
153 |
154 | ```python
155 | np.einsum('i,i,i->i', v, w, w)
156 | ```
157 |
158 |
159 |
160 |
161 | array([ 0, 49, 128, 243, 400, 605])
162 |
163 |
164 |
165 |
166 | ```python
167 | v*w*w
168 | ```
169 |
170 |
171 |
172 |
173 | array([ 0, 49, 128, 243, 400, 605])
174 |
175 |
176 |
177 |
178 | ```python
179 | torch.einsum('i,i,i->i', v_torch, w_torch, w_torch)
180 | ```
181 |
182 |
183 |
184 |
185 | tensor([ 0, 49, 128, 243, 400, 605])
186 |
187 |
188 |
189 | ### 1.4 Inner product
190 |
191 |
192 | ```python
193 | np.einsum('i,i->', v,w)
194 | ```
195 |
196 |
197 |
198 |
199 | 145
200 |
201 |
202 |
203 |
204 | ```python
205 | np.dot(v,w)
206 | ```
207 |
208 |
209 |
210 |
211 | 145
212 |
213 |
214 |
215 |
216 | ```python
217 | torch.einsum('i,i->',v_torch,w_torch)
218 | ```
219 |
220 |
221 |
222 |
223 | tensor(145)
224 |
225 |
226 |
227 | ### 1.5 Outer product
228 |
229 |
230 | ```python
231 | np.einsum('i,j->ij',v,w)
232 | ```
233 |
234 |
235 |
236 |
237 | array([[ 0, 0, 0, 0, 0, 0],
238 | [ 6, 7, 8, 9, 10, 11],
239 | [12, 14, 16, 18, 20, 22],
240 | [18, 21, 24, 27, 30, 33],
241 | [24, 28, 32, 36, 40, 44],
242 | [30, 35, 40, 45, 50, 55]])
243 |
244 |
245 |
246 |
247 | ```python
248 | np.outer(v,w)
249 | ```
250 |
251 |
252 |
253 |
254 | array([[ 0, 0, 0, 0, 0, 0],
255 | [ 6, 7, 8, 9, 10, 11],
256 | [12, 14, 16, 18, 20, 22],
257 | [18, 21, 24, 27, 30, 33],
258 | [24, 28, 32, 36, 40, 44],
259 | [30, 35, 40, 45, 50, 55]])
260 |
261 |
262 |
263 |
264 | ```python
265 | torch.einsum('i,j->ij',v_torch, w_torch)
266 | ```
267 |
268 |
269 |
270 |
271 | tensor([[ 0, 0, 0, 0, 0, 0],
272 | [ 6, 7, 8, 9, 10, 11],
273 | [12, 14, 16, 18, 20, 22],
274 | [18, 21, 24, 27, 30, 33],
275 | [24, 28, 32, 36, 40, 44],
276 | [30, 35, 40, 45, 50, 55]])
277 |
278 |
279 |
280 | #### 1.5.1 Transpose is just as easy
281 |
282 |
283 | ```python
284 | np.einsum('i,j->ji',v,w)
285 | ```
286 |
287 |
288 |
289 |
290 | array([[ 0, 6, 12, 18, 24, 30],
291 | [ 0, 7, 14, 21, 28, 35],
292 | [ 0, 8, 16, 24, 32, 40],
293 | [ 0, 9, 18, 27, 36, 45],
294 | [ 0, 10, 20, 30, 40, 50],
295 | [ 0, 11, 22, 33, 44, 55]])
296 |
297 |
298 |
299 | ## 2. Matrices
300 |
301 | ### 2.1 Do nothing
302 |
303 |
304 | ```python
305 | np.einsum('ij',B)
306 | ```
307 |
308 |
309 |
310 |
311 | array([[ 6, 7],
312 | [ 8, 9],
313 | [10, 11]])
314 |
315 |
316 |
317 |
318 | ```python
319 | torch.einsum('ij', B_torch)
320 | ```
321 |
322 |
323 |
324 |
325 | tensor([[ 6, 7],
326 | [ 8, 9],
327 | [10, 11]])
328 |
329 |
330 |
331 | ### 2.2 Transpose
332 |
333 |
334 | ```python
335 | A
336 | ```
337 |
338 |
339 |
340 |
341 | array([[0, 1, 2],
342 | [3, 4, 5]])
343 |
344 |
345 |
346 |
347 | ```python
348 | np.einsum('ij->ji', A)
349 | ```
350 |
351 |
352 |
353 |
354 | array([[0, 3],
355 | [1, 4],
356 | [2, 5]])
357 |
358 |
359 |
360 |
361 | ```python
362 | torch.einsum('ij->ji', A_torch)
363 | ```
364 |
365 |
366 |
367 |
368 | tensor([[0, 3],
369 | [1, 4],
370 | [2, 5]])
371 |
372 |
373 |
374 |
375 | ```python
376 | print(A.T)
377 | print('---')
378 | print(A_torch.t())
379 | ```
380 |
381 | [[0 3]
382 | [1 4]
383 | [2 5]]
384 | ---
385 | tensor([[0, 3],
386 | [1, 4],
387 | [2, 5]])
388 |
389 |
390 | ### 2.3 Diagonal of a square matrix
391 |
392 |
393 | ```python
394 | C
395 | ```
396 |
397 |
398 |
399 |
400 | array([[0, 1, 2],
401 | [3, 4, 5],
402 | [6, 7, 8]])
403 |
404 |
405 |
406 |
407 | ```python
408 | np.einsum('ii->i',C)
409 | ```
410 |
411 |
412 |
413 |
414 | array([0, 4, 8])
415 |
416 |
417 |
418 |
419 | ```python
420 | torch.einsum('ii->i', C_torch)
421 | ```
422 |
423 |
424 |
425 |
426 | tensor([0, 4, 8])
427 |
428 |
429 |
430 |
431 | ```python
432 | np.diag(C), torch.diag(C_torch)
433 | ```
434 |
435 |
436 |
437 |
438 | (array([0, 4, 8]), tensor([0, 4, 8]))
439 |
440 |
441 |
442 | ### 2.4 Trace of a matrix
443 |
444 |
445 | ```python
446 | np.einsum('ii->',C)
447 | ```
448 |
449 |
450 |
451 |
452 | 12
453 |
454 |
455 |
456 |
457 | ```python
458 | torch.einsum('ii->', C_torch)
459 | ```
460 |
461 |
462 |
463 |
464 | tensor(12)
465 |
466 |
467 |
468 |
469 | ```python
470 | np.trace(C), torch.trace(C_torch)
471 | ```
472 |
473 |
474 |
475 |
476 | (12, tensor(12))
477 |
478 |
479 |
480 | ### 2.5 Sum of matrix
481 |
482 |
483 | ```python
484 | np.einsum('ij->',A)
485 | ```
486 |
487 |
488 |
489 |
490 | 15
491 |
492 |
493 |
494 |
495 | ```python
496 | torch.einsum('ij->',A_torch)
497 | ```
498 |
499 |
500 |
501 |
502 | tensor(15)
503 |
504 |
505 |
506 |
507 | ```python
508 | A.sum(), A_torch.sum()
509 | ```
510 |
511 |
512 |
513 |
514 | (15, tensor(15))
515 |
516 |
517 |
518 | ### 2.6 Sum of matrix along axes
519 |
520 |
521 | ```python
522 | np.einsum('ij->j',B)
523 | ```
524 |
525 |
526 |
527 |
528 | array([24, 27])
529 |
530 |
531 |
532 |
533 | ```python
534 | B.sum(0)
535 | ```
536 |
537 |
538 |
539 |
540 | array([24, 27])
541 |
542 |
543 |
544 |
545 | ```python
546 | torch.einsum('ij->j', B_torch)
547 | ```
548 |
549 |
550 |
551 |
552 | tensor([24, 27])
553 |
554 |
555 |
556 |
557 | ```python
558 | torch.einsum('ij->i', B_torch)
559 | ```
560 |
561 |
562 |
563 |
564 | tensor([13, 17, 21])
565 |
566 |
567 |
568 |
569 | ```python
570 | B_torch.sum(1)
571 | ```
572 |
573 |
574 |
575 |
576 | tensor([13, 17, 21])
577 |
578 |
579 |
580 | ### 2.7 Element-wise multiplication
581 |
582 |
583 | ```python
584 | np.einsum('ij,ij->ij',A,B.T)
585 | ```
586 |
587 |
588 |
589 |
590 | array([[ 0, 8, 20],
591 | [21, 36, 55]])
592 |
593 |
594 |
595 |
596 | ```python
597 | A * B.T
598 | ```
599 |
600 |
601 |
602 |
603 | array([[ 0, 8, 20],
604 | [21, 36, 55]])
605 |
606 |
607 |
608 | #### 2.7.1 Element-wise multiplication can be done in various ways just by permuting the indices
609 |
610 |
611 | ```python
612 | torch.einsum('ij,ji->ij', C_torch.t(), C_torch)
613 | ```
614 |
615 |
616 |
617 |
618 | tensor([[ 0, 9, 36],
619 | [ 1, 16, 49],
620 | [ 4, 25, 64]])
621 |
622 |
623 |
624 |
625 | ```python
626 | C_torch.t() * C_torch.t()
627 | ```
628 |
629 |
630 |
631 |
632 | tensor([[ 0, 9, 36],
633 | [ 1, 16, 49],
634 | [ 4, 25, 64]])
635 |
636 |
637 |
638 | #### 2.7.2 The below two operations are also the same
639 |
640 |
641 | ```python
642 | torch.einsum('ij,ji->ij', C_torch, C_torch)
643 | ```
644 |
645 |
646 |
647 |
648 | tensor([[ 0, 3, 12],
649 | [ 3, 16, 35],
650 | [12, 35, 64]])
651 |
652 |
653 |
654 |
655 | ```python
656 | C_torch * C_torch.t()
657 | ```
658 |
659 |
660 |
661 |
662 | tensor([[ 0, 3, 12],
663 | [ 3, 16, 35],
664 | [12, 35, 64]])
665 |
666 |
667 |
668 | ### 2.8 Matrix multiplication
669 |
670 |
671 | ```python
672 | np.einsum('ij,jk->ik',A,B)
673 | ```
674 |
675 |
676 |
677 |
678 | array([[ 28, 31],
679 | [100, 112]])
680 |
681 |
682 |
683 |
684 | ```python
685 | np.einsum('ij,jk',A,B)
686 | ```
687 |
688 |
689 |
690 |
691 | array([[ 28, 31],
692 | [100, 112]])
693 |
694 |
695 |
696 |
697 | ```python
698 | np.dot(A,B)
699 | ```
700 |
701 |
702 |
703 |
704 | array([[ 28, 31],
705 | [100, 112]])
706 |
707 |
708 |
709 |
710 | ```python
711 | torch.einsum('ij,jk->ik', A_torch, B_torch)
712 | ```
713 |
714 |
715 |
716 |
717 | tensor([[ 28, 31],
718 | [100, 112]])
719 |
720 |
721 |
722 |
723 | ```python
724 | torch.einsum('ij,jk', A_torch, B_torch)
725 | ```
726 |
727 |
728 |
729 |
730 | tensor([[ 28, 31],
731 | [100, 112]])
732 |
733 |
734 |
735 | ### 2.9 Inner product of two matrices
736 |
737 |
738 | ```python
739 | A.shape, B.shape, C.shape
740 | ```
741 |
742 |
743 |
744 |
745 | ((2, 3), (3, 2), (3, 3))
746 |
747 |
748 |
749 |
750 | ```python
751 | print(A)
752 | print('---')
753 | print(C)
754 | ```
755 |
756 | [[0 1 2]
757 | [3 4 5]]
758 | ---
759 | [[0 1 2]
760 | [3 4 5]
761 | [6 7 8]]
762 |
763 |
764 |
765 | ```python
766 | np.einsum('ij,kj->ik', A, C)
767 | ```
768 |
769 |
770 |
771 |
772 | array([[ 5, 14, 23],
773 | [14, 50, 86]])
774 |
775 |
776 |
777 |
778 | ```python
779 | torch.einsum('ij,kj->ik', A_torch, C_torch)
780 | ```
781 |
782 |
783 |
784 |
785 | tensor([[ 5, 14, 23],
786 | [14, 50, 86]])
787 |
788 |
789 |
790 |
791 | ```python
792 | i,j = A.shape
793 | k,j2 = C.shape
794 | assert j == j2
795 |
796 | result = np.empty((i,k))
797 |
798 | for index_i in range(i):
799 | for index_k in range(k):
800 | total = 0
801 | for index_j in range(j):
802 | total += A[index_i, index_j] * C[index_k, index_j]
803 | result[index_i, index_k] = total
804 |
805 | result
806 | ```
807 |
808 |
809 |
810 |
811 | array([[ 5., 14., 23.],
812 | [14., 50., 86.]])
813 |
814 |
815 |
816 | ## 3. Higher-order tensors
817 |
818 | ### 3.1 Each row of A multiplied (element-wise) by C
819 |
820 |
821 | ```python
822 | np.einsum('ij,kj->ikj', A, C)
823 | ```
824 |
825 |
826 |
827 |
828 | array([[[ 0, 1, 4],
829 | [ 0, 4, 10],
830 | [ 0, 7, 16]],
831 |
832 | [[ 0, 4, 10],
833 | [ 9, 16, 25],
834 | [18, 28, 40]]])
835 |
836 |
837 |
838 |
839 | ```python
840 | A
841 | ```
842 |
843 |
844 |
845 |
846 | array([[0, 1, 2],
847 | [3, 4, 5]])
848 |
849 |
850 |
851 |
852 | ```python
853 | C
854 | ```
855 |
856 |
857 |
858 |
859 | array([[0, 1, 2],
860 | [3, 4, 5],
861 | [6, 7, 8]])
862 |
863 |
864 |
865 |
866 | ```python
867 | torch.einsum('ij,kj->ikj', A_torch, C_torch)
868 | ```
869 |
870 |
871 |
872 |
873 | tensor([[[ 0, 1, 4],
874 | [ 0, 4, 10],
875 | [ 0, 7, 16]],
876 |
877 | [[ 0, 4, 10],
878 | [ 9, 16, 25],
879 | [18, 28, 40]]])
880 |
881 |
882 |
883 |
884 | ```python
885 | i,j = A.shape
886 | k,j2 = C.shape
887 | assert j == j2
888 |
889 | result = np.empty((i,k,j))
890 |
891 | for index_i in range(i):
892 | for index_k in range(k):
893 | for index_j in range(j):
894 | result[index_i, index_k, index_j] = A[index_i, index_j] * C[index_k, index_j]
895 |
896 | result
897 | ```
898 |
899 |
900 |
901 |
902 | array([[[ 0., 1., 4.],
903 | [ 0., 4., 10.],
904 | [ 0., 7., 16.]],
905 |
906 | [[ 0., 4., 10.],
907 | [ 9., 16., 25.],
908 | [18., 28., 40.]]])
909 |
910 |
911 |
912 | ### 3.2 Each element of first matrix multiplied (element-wise) by second matrix
913 |
914 |
915 | ```python
916 | A
917 | ```
918 |
919 |
920 |
921 |
922 | array([[0, 1, 2],
923 | [3, 4, 5]])
924 |
925 |
926 |
927 |
928 | ```python
929 | B
930 | ```
931 |
932 |
933 |
934 |
935 | array([[ 6, 7],
936 | [ 8, 9],
937 | [10, 11]])
938 |
939 |
940 |
941 |
942 | ```python
943 | np.einsum('ij,kl->ijkl', A, B)
944 | ```
945 |
946 |
947 |
948 |
949 | array([[[[ 0, 0],
950 | [ 0, 0],
951 | [ 0, 0]],
952 |
953 | [[ 6, 7],
954 | [ 8, 9],
955 | [10, 11]],
956 |
957 | [[12, 14],
958 | [16, 18],
959 | [20, 22]]],
960 |
961 |
962 | [[[18, 21],
963 | [24, 27],
964 | [30, 33]],
965 |
966 | [[24, 28],
967 | [32, 36],
968 | [40, 44]],
969 |
970 | [[30, 35],
971 | [40, 45],
972 | [50, 55]]]])
973 |
974 |
975 |
976 |
977 | ```python
978 | i,j = A.shape
979 | k,l = B.shape
980 |
981 | result = np.empty((i,j,k,l))
982 |
983 | for index_i in range(i):
984 | for index_j in range(j):
985 | for index_k in range(k):
986 | for index_l in range(l):
987 | result[index_i, index_j, index_k, index_l] = A[index_i, index_j] * B[index_k, index_l]
988 |
989 | result
990 | ```
991 |
992 |
993 |
994 |
995 | array([[[[ 0., 0.],
996 | [ 0., 0.],
997 | [ 0., 0.]],
998 |
999 | [[ 6., 7.],
1000 | [ 8., 9.],
1001 | [10., 11.]],
1002 |
1003 | [[12., 14.],
1004 | [16., 18.],
1005 | [20., 22.]]],
1006 |
1007 |
1008 | [[[18., 21.],
1009 | [24., 27.],
1010 | [30., 33.]],
1011 |
1012 | [[24., 28.],
1013 | [32., 36.],
1014 | [40., 44.]],
1015 |
1016 | [[30., 35.],
1017 | [40., 45.],
1018 | [50., 55.]]]])
1019 |
1020 |
1021 |
1022 | ## 4. Some examples in PyTorch
1023 |
1024 | ### 4.1 Batch multiplication
1025 |
1026 |
1027 | ```python
1028 | m1 = torch.randn(4,5,3)
1029 | m2 = torch.randn(4,3,5)
1030 | torch.bmm(m1,m2)
1031 | ```
1032 |
1033 |
1034 |
1035 |
1036 | tensor([[[ 8.2181e-01, -5.8035e-01, 2.2078e+00, 1.4295e+00, 1.8635e+00],
1037 | [ 8.4052e-01, -1.0589e-01, 1.4207e+00, 9.9271e-01, 2.3920e+00],
1038 | [-1.5352e-02, -5.3438e-01, 3.2493e+00, 1.4200e+00, 3.1127e+00],
1039 | [-1.3939e+00, 1.3775e+00, -2.2805e+00, -1.9652e+00, 5.8474e-01],
1040 | [ 2.0536e+00, -6.2420e-01, 2.3070e+00, 2.0755e+00, 2.6713e+00]],
1041 |
1042 | [[ 7.0778e-01, 1.4530e-01, 1.9873e+00, 2.1278e+00, -3.3463e-01],
1043 | [ 1.5298e-01, -1.7556e+00, -2.0336e+00, -3.3895e+00, 2.8165e-03],
1044 | [-4.9915e-01, 2.9698e-02, -9.9656e-01, -5.8711e-01, 3.3424e-01],
1045 | [-7.3538e-02, -6.1606e-01, -1.1962e+00, -1.9709e+00, -1.5120e-02],
1046 | [-8.9944e-01, -2.8037e-01, -2.4217e+00, -2.2450e+00, 5.3909e-01]],
1047 |
1048 | [[ 7.2014e-01, 1.4626e+00, -5.0004e-01, 1.5049e+00, -4.8354e-01],
1049 | [-1.4394e+00, 2.9822e-01, 1.2814e+00, 2.6319e+00, 3.4886e+00],
1050 | [ 4.2215e-01, 1.8446e+00, -5.1419e-03, 1.3402e-01, 1.1296e-01],
1051 | [-2.9111e-01, 7.2741e-01, 1.2039e-01, 4.1217e+00, 1.5959e+00],
1052 | [ 4.1003e-01, -1.7096e-01, -2.7092e-01, -2.1486e+00, -1.2508e+00]],
1053 |
1054 | [[ 8.9113e-01, 1.7029e+00, 1.9923e-01, -3.2940e-01, -9.3853e-01],
1055 | [ 7.0222e+00, 1.5919e-01, -2.6844e+00, 1.0894e+00, 1.2812e+00],
1056 | [ 6.5914e-02, -1.0917e+00, -1.0884e-01, 6.8298e-01, 1.0476e+00],
1057 | [ 3.3019e+00, 2.1760e+00, -1.3713e+00, -1.1536e+00, -1.7116e+00],
1058 | [-4.7698e+00, -1.7244e-01, 1.6248e+00, -9.6778e-01, -1.0414e+00]]])
1059 |
1060 |
1061 |
1062 | This is equivalent to:
1063 |
1064 |
1065 | ```python
1066 | torch.einsum('bij,bjk->bik', m1, m2)
1067 | ```
1068 |
1069 |
1070 |
1071 |
1072 | tensor([[[ 8.2181e-01, -5.8035e-01, 2.2078e+00, 1.4295e+00, 1.8635e+00],
1073 | [ 8.4052e-01, -1.0589e-01, 1.4207e+00, 9.9271e-01, 2.3920e+00],
1074 | [-1.5352e-02, -5.3438e-01, 3.2493e+00, 1.4200e+00, 3.1127e+00],
1075 | [-1.3939e+00, 1.3775e+00, -2.2805e+00, -1.9652e+00, 5.8474e-01],
1076 | [ 2.0536e+00, -6.2420e-01, 2.3070e+00, 2.0755e+00, 2.6713e+00]],
1077 |
1078 | [[ 7.0778e-01, 1.4530e-01, 1.9873e+00, 2.1278e+00, -3.3463e-01],
1079 | [ 1.5298e-01, -1.7556e+00, -2.0336e+00, -3.3895e+00, 2.8165e-03],
1080 | [-4.9915e-01, 2.9698e-02, -9.9656e-01, -5.8711e-01, 3.3424e-01],
1081 | [-7.3538e-02, -6.1606e-01, -1.1962e+00, -1.9709e+00, -1.5120e-02],
1082 | [-8.9944e-01, -2.8037e-01, -2.4217e+00, -2.2450e+00, 5.3909e-01]],
1083 |
1084 | [[ 7.2014e-01, 1.4626e+00, -5.0004e-01, 1.5049e+00, -4.8354e-01],
1085 | [-1.4394e+00, 2.9822e-01, 1.2814e+00, 2.6319e+00, 3.4886e+00],
1086 | [ 4.2215e-01, 1.8446e+00, -5.1419e-03, 1.3402e-01, 1.1296e-01],
1087 | [-2.9111e-01, 7.2741e-01, 1.2039e-01, 4.1217e+00, 1.5959e+00],
1088 | [ 4.1003e-01, -1.7096e-01, -2.7092e-01, -2.1486e+00, -1.2508e+00]],
1089 |
1090 | [[ 8.9113e-01, 1.7029e+00, 1.9923e-01, -3.2940e-01, -9.3853e-01],
1091 | [ 7.0222e+00, 1.5919e-01, -2.6844e+00, 1.0894e+00, 1.2812e+00],
1092 | [ 6.5914e-02, -1.0917e+00, -1.0884e-01, 6.8298e-01, 1.0476e+00],
1093 | [ 3.3019e+00, 2.1760e+00, -1.3713e+00, -1.1536e+00, -1.7116e+00],
1094 | [-4.7698e+00, -1.7244e-01, 1.6248e+00, -9.6778e-01, -1.0414e+00]]])
1095 |
1096 |
1097 |
1098 | ### The official implementation of [MoCo paper](https://arxiv.org/pdf/1911.05722.pdf) uses `einsum` for calculating dot products with positive and negative examples [as shown here](https://github.com/facebookresearch/moco/blob/master/moco/builder.py#L143-L146):
1099 |
1100 |
1101 | ```python
1102 | # positive logits: Nx1
1103 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
1104 | # negative logits: NxK
1105 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
1106 | ```
1107 |
--------------------------------------------------------------------------------
/markdowns/groups.md:
--------------------------------------------------------------------------------
1 |
2 | ### Groups in convolution
3 |
4 | The easiest way to understand how groups work is to read the explanation, then look at the visuals and the code. There will not be much text here explaining what is going on.
5 |
6 | In the code shown below, every element of `x` has value `(c % 16) + 1`, where `c` is the index along the second (channel) axis.
7 |
8 |
9 | ```python
10 | import torch
11 | import torch.nn as nn
12 | from itertools import product
13 |
14 | x = torch.ones(2,16,10,10)
15 | for n,c,h,w in product( *[range(i) for i in x.size()] ):
16 | val = (c % 16) + 1
17 | x[n,c,h,w] = val
18 | ```
19 |
20 | ### Official documentation:
21 |
22 | * `groups` controls the connections between inputs and outputs.
23 | * `in_channels` and `out_channels` must both be divisible by `groups`. For example,
24 |
25 | * At groups=1, all inputs are convolved to all outputs.
26 | * At groups=2, the operation becomes equivalent to having two conv
27 | layers side by side, each seeing half the input channels,
28 | and producing half the output channels, and both subsequently
29 | concatenated.
30 | * At groups= `in_channels`, each input channel is convolved with
31 | its own set of filters
32 |
33 | Here is how it can be explained visually:
34 |
35 | 
36 |
37 | Define convolutional layers with number of groups = 1,2 and 4:
38 |
39 | ```python
40 | conv_g1 = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=1, groups=1)
41 | conv_g2 = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=1, groups=2)
42 | conv_g4 = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=1, groups=4)
43 | conv_nets = [conv_g1, conv_g2, conv_g4]
44 | ```
45 |
46 | Can you guess what the size of the weights looks like?
47 |
48 | ```python
49 | _ = [print(conv.weight.size()) for conv in conv_nets]
50 | ```
51 |
52 | torch.Size([4, 16, 1, 1])
53 | torch.Size([4, 8, 1, 1])
54 | torch.Size([4, 4, 1, 1])
55 |
56 |
57 |
58 | ```python
59 | _ = [print(conv.bias.size()) for conv in conv_nets]
60 | ```
61 |
62 | torch.Size([4])
63 | torch.Size([4])
64 | torch.Size([4])
65 |
66 | Initialize the three layers so that their weights are constant (1, 2 and 3 respectively) and their biases are zero:
67 |
68 | ```python
69 | for index, conv in enumerate(conv_nets, 1):
70 | nn.init.constant_(conv.weight, index)
71 | nn.init.constant_(conv.bias, 0)
72 | ```
73 |
74 |
75 | ```python
76 | _ = [print(conv.weight.mean().item()) for conv in conv_nets]
77 | ```
78 |
79 | 1.0
80 | 2.0
81 | 3.0
82 |
83 |
84 |
85 | ```python
86 | _ = [print(conv.bias.mean().item()) for conv in conv_nets]
87 | ```
88 |
89 | 0.0
90 | 0.0
91 | 0.0
92 |
93 | Regardless of the number of groups, the output size will always be the same:
94 |
95 | ```python
96 | _ = [print(conv(x).size()) for conv in conv_nets]
97 | ```
98 |
99 | torch.Size([2, 4, 10, 10])
100 | torch.Size([2, 4, 10, 10])
101 | torch.Size([2, 4, 10, 10])
102 |
103 | But the values will be different:
104 |
105 | ```python
106 | _ = [print(conv(x).mean().item()) for conv in conv_nets]
107 | ```
108 |
109 | 136.0
110 | 136.0
111 | 102.0
112 |
113 |
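Why these particular numbers? Each input channel `c` holds the constant value `c + 1`, i.e. the values 1 through 16. With `groups=1` and all weights equal to 1, every output element is the sum of all 16 channel values, `1 + 2 + ... + 16 = 136`. With `groups=2` and weights equal to 2, the first two output channels see channels 1 to 8 (`2 * 36 = 72`) and the last two see channels 9 to 16 (`2 * 100 = 200`), so the mean is again 136. With `groups=4` and weights equal to 3, the four output channels see blocks with sums 10, 26, 42 and 58, giving `30, 78, 126, 174` and a mean of 102.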
114 |
115 | ```python
116 | y = conv_g4(x)
117 | ```
118 |
119 | 
120 |
121 |
122 | ```python
123 | for i in range(4):
124 | print(y[:,i].mean(1))
125 | ```
126 |
127 | tensor([[ 30., 30., 30., 30., 30., 30., 30., 30., 30., 30.],
128 | [ 30., 30., 30., 30., 30., 30., 30., 30., 30., 30.]])
129 | tensor([[ 78., 78., 78., 78., 78., 78., 78., 78., 78., 78.],
130 | [ 78., 78., 78., 78., 78., 78., 78., 78., 78., 78.]])
131 | tensor([[ 126., 126., 126., 126., 126., 126., 126., 126., 126.,
132 | 126.],
133 | [ 126., 126., 126., 126., 126., 126., 126., 126., 126.,
134 | 126.]])
135 | tensor([[ 174., 174., 174., 174., 174., 174., 174., 174., 174.,
136 | 174.],
137 | [ 174., 174., 174., 174., 174., 174., 174., 174., 174.,
138 | 174.]])
139 |
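As a quick check (a small addition, not in the original notebook), the four per-group values can be reproduced directly from `x`: each output channel of `conv_g4` sees one block of 4 input channels, multiplied by the constant weight 3 (the bias is 0).

```python
for i in range(4):
    block = x[:, 4*i : 4*(i+1)]               # the 4 input channels of group i
    print((block * 3).sum(1).mean().item())   # 30.0, 78.0, 126.0, 174.0
```
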
--------------------------------------------------------------------------------
/markdowns/hooks.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ```python
4 | import torch
5 | import torch.nn as nn
6 | ```
7 |
8 | ### Let's create a series of basic operations and calculate gradient
9 |
10 |
11 | ```python
12 | x = torch.tensor([1.,2.,3.], requires_grad=True)
13 | y = x*x
14 | z = torch.sum(y)
15 | z
16 | ```
17 |
18 |
19 |
20 |
21 | tensor(14., grad_fn=)
22 |
23 |
24 |
25 |
26 | ```python
27 | y.requires_grad, z.requires_grad
28 | ```
29 |
30 |
31 |
32 |
33 | (True, True)
34 |
35 |
36 |
37 |
38 | ```python
39 | x.grad, y.grad, z.grad
40 | ```
41 |
42 |
43 |
44 |
45 | (None, None, None)
46 |
47 |
48 |
49 |
50 | ```python
51 | z.backward()
52 | ```
53 |
54 |
55 | ```python
56 | x.grad, y.grad, z.grad
57 | ```
58 |
59 |
60 |
61 |
62 | (tensor([2., 4., 6.]), None, None)
63 |
64 |
65 |
66 | ### You can register a hook to a PyTorch `Module` or `Tensor`
67 |
68 |
69 | ```python
70 | x = torch.tensor([1.,2.,3.], requires_grad=True)
71 | y = x*x
72 | z = torch.sum(y)
73 | z
74 | ```
75 |
76 |
77 |
78 |
79 | tensor(14., grad_fn=)
80 |
81 |
82 |
83 | #### `register_hook` registers a backward hook
84 |
85 |
86 | ```python
87 | h = y.register_hook(lambda grad: print(grad))
88 | ```
89 |
90 |
91 | ```python
92 | z.backward()
93 | ```
94 |
95 | tensor([1., 1., 1.])
96 |
97 |
98 |
99 | ```python
100 | x.grad, y.grad, z.grad
101 | ```
102 |
103 |
104 |
105 |
106 | (tensor([2., 4., 6.]), None, None)
107 |
108 |
109 |
110 | #### You can also use a hook to manipulate gradients. The hook should have the following signature:
111 |
112 | `hook(grad) -> Tensor or None`
113 |
114 |
115 | ```python
116 | x = torch.tensor([1.,2.,3.], requires_grad=True)
117 | y = x*x
118 | z = torch.sum(y)
119 | z
120 | ```
121 |
122 |
123 |
124 |
125 | tensor(14., grad_fn=)
126 |
127 |
128 |
129 | #### Let's add random noise to gradient of `y`
130 |
131 |
132 | ```python
133 | def add_random_noise_and_print(grad):
134 | noise = torch.randn(grad.size())
135 | print('Noise:', noise)
136 | return grad + noise
137 |
138 | h = y.register_hook(add_random_noise_and_print)
139 | ```
140 |
141 |
142 | ```python
143 | x.grad, y.grad, z.grad
144 | ```
145 |
146 |
147 |
148 |
149 | (None, None, None)
150 |
151 |
152 |
153 |
154 | ```python
155 | z.backward()
156 | ```
157 |
158 | Noise: tensor([1.1173, 1.2854, 0.0611])
159 |
160 |
161 |
162 | ```python
163 | x.grad, y.grad, z.grad
164 | ```
165 |
166 |
167 |
168 |
169 | (tensor([4.2346, 9.1415, 6.3665]), None, None)
170 |
171 |
172 |
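A small side note (not in the original notebook): `register_hook` returns a handle, which is why it was assigned to `h` above. Calling `remove()` on the handle detaches the hook, so it no longer fires in later backward passes; this is useful when a hook is only needed temporarily, e.g. for debugging:

```python
h.remove()   # the noise-adding hook is detached and will not fire again
```
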
173 | ### Forward and backward hooks
174 |
175 | There are three hooks listed below. They are available only for `nn.Module`s
176 |
177 | * __`register_forward_pre_hook`__: function is called BEFORE forward call
178 |     * Signature: `hook(module, input) -> None`
179 |
180 |
181 | * __`register_forward_hook`__: function is called AFTER forward call
182 | * Signature: `hook(module, input, output) -> None`
183 |
184 |
185 | * __`register_backward_hook`__: function is called AFTER gradients wrt module input are computed
186 | * Signature: `hook(module, grad_input, grad_output) -> Tensor or None`
187 |
188 | #### `register_forward_pre_hook`
189 |
190 |
191 | ```python
192 | linear = nn.Linear(10,1)
193 | ```
194 |
195 |
196 | ```python
197 | [param for param in linear.parameters()]
198 | ```
199 |
200 |
201 |
202 |
203 | [Parameter containing:
204 | tensor([[-0.3009, 0.0351, -0.2786, 0.1136, -0.2712, 0.0183, -0.2881, -0.1555,
205 | -0.3108, 0.0767]], requires_grad=True), Parameter containing:
206 | tensor([0.0377], requires_grad=True)]
207 |
208 |
209 |
210 |
211 | ```python
212 | h = linear.register_forward_pre_hook(lambda _, inp: print(inp))
213 | ```
214 |
215 |
216 | ```python
217 | x = torch.ones([8,10])
218 | ```
219 |
220 |
221 | ```python
222 | y = linear(x)
223 | ```
224 |
225 | (tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
226 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
227 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
228 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
229 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
230 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
231 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
232 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),)
233 |
234 |
235 | #### `register_forward_hook`
236 |
237 |
238 | ```python
239 | linear = nn.Linear(10,1)
240 | [param for param in linear.parameters()]
241 | ```
242 |
243 |
244 |
245 |
246 | [Parameter containing:
247 | tensor([[ 0.0978, -0.1878, 0.0189, 0.3040, 0.1120, 0.1977, 0.2137, -0.2841,
248 | -0.0718, 0.2079]], requires_grad=True), Parameter containing:
249 | tensor([0.2796], requires_grad=True)]
250 |
251 |
252 |
253 |
254 | ```python
255 | h = linear.register_forward_hook(lambda _, inp, out: print(inp, '\n\n', out))
256 | ```
257 |
258 |
259 | ```python
260 | y = linear(x)
261 | ```
262 |
263 | (tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
264 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
265 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
266 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
267 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
268 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
269 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
270 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),)
271 |
272 | tensor([[0.8878],
273 | [0.8878],
274 | [0.8878],
275 | [0.8878],
276 | [0.8878],
277 | [0.8878],
278 | [0.8878],
279 | [0.8878]], grad_fn=)
280 |
281 |
282 | Just to verify, the result above can also be computed manually like so:
283 |
284 |
285 | ```python
286 | [param for param in linear.parameters()][0].sum() + [param for param in linear.parameters()][1]
287 | ```
288 |
289 |
290 |
291 |
292 | tensor([0.8878], grad_fn=)
293 |
294 |
295 |
296 | #### `register_backward_hook`
297 |
298 |
299 | ```python
300 | linear = nn.Linear(3,1)
301 | [param for param in linear.parameters()]
302 | ```
303 |
304 |
305 |
306 |
307 | [Parameter containing:
308 | tensor([[0.5395, 0.2303, 0.5583]], requires_grad=True), Parameter containing:
309 | tensor([-0.3510], requires_grad=True)]
310 |
311 |
312 |
313 |
314 | ```python
315 | def print_sizes(module, grad_inp, grad_out):
316 | print(len(grad_inp), len(grad_out))
317 | print('===')
318 | print('Grad input sizes:', [i.size() for i in grad_inp if i is not None])
319 | print('Grad output sizes:', [i.size() for i in grad_out if i is not None])
320 | print('===')
321 | print('Grad_input 0:', grad_inp[0])
322 | print('Grad_input 1:', grad_inp[1])
323 | print('Grad_input 2:', grad_inp[2])
324 | print('===')
325 | print(grad_out)
326 |
327 | h = linear.register_backward_hook(print_sizes)
328 | ```
329 |
330 |
331 | ```python
332 | x = torch.ones([8,3]) * 10
333 | y = linear(x)
334 | ```
335 |
336 |
337 | ```python
338 | y.backward(torch.ones_like(y) * 1.5)
339 | ```
340 |
341 | 3 1
342 | ===
343 | Grad input sizes: [torch.Size([1]), torch.Size([3, 1])]
344 | Grad output sizes: [torch.Size([8, 1])]
345 | ===
346 | Grad_input 0: tensor([12.])
347 | Grad_input 1: None
348 | Grad_input 2: tensor([[120.],
349 | [120.],
350 | [120.]])
351 | ===
352 | (tensor([[1.5000],
353 | [1.5000],
354 | [1.5000],
355 | [1.5000],
356 | [1.5000],
357 | [1.5000],
358 | [1.5000],
359 | [1.5000]]),)
360 |
361 |
362 | `register_backward_hook` can be used to tweak the `grad_input` values after they have been calculated. To do so, the hook function should return a new gradient with respect to the input, which will be used in place of `grad_input` in subsequent computations.
363 |
364 | In the example below, we add random noise to the gradient of bias using the hook:
365 |
366 |
367 | ```python
368 | linear = nn.Linear(3,1)
369 | [param for param in linear.parameters()]
370 | ```
371 |
372 |
373 |
374 |
375 | [Parameter containing:
376 | tensor([[-0.5438, 0.5539, 0.5210]], requires_grad=True),
377 | Parameter containing:
378 | tensor([-0.4839], requires_grad=True)]
379 |
380 |
381 |
382 |
383 | ```python
384 | def add_noise_and_print(module, grad_inp, grad_out):
385 | noise = torch.randn(grad_inp[0].size())
386 | print('Noise:', noise)
387 | print('===')
388 | print('Grad input sizes:', [i.size() for i in grad_inp if i is not None])
389 | print('Grad output sizes:', [i.size() for i in grad_out if i is not None])
390 | print('===')
391 | print('Grad_input 0:', grad_inp[0])
392 | print('Grad_input 1:', grad_inp[1])
393 | print('Grad_input 2:', grad_inp[2])
394 | print('===')
395 | print(grad_out)
396 | return (grad_inp[0] + noise, None, grad_inp[2])
397 |
398 | h = linear.register_backward_hook(add_noise_and_print)
399 | ```
400 |
401 |
402 | ```python
403 | x = torch.ones([8,3]) * 10
404 | y = linear(x)
405 | ```
406 |
407 |
408 | ```python
409 | y.backward(torch.ones_like(y) * 1.5)
410 | ```
411 |
412 | Noise: tensor([0.7553])
413 | ===
414 | Grad input sizes: [torch.Size([1]), torch.Size([3, 1])]
415 | Grad output sizes: [torch.Size([8, 1])]
416 | ===
417 | Grad_input 0: tensor([12.])
418 | Grad_input 1: None
419 | Grad_input 2: tensor([[120.],
420 | [120.],
421 | [120.]])
422 | ===
423 | (tensor([[1.5000],
424 | [1.5000],
425 | [1.5000],
426 | [1.5000],
427 | [1.5000],
428 | [1.5000],
429 | [1.5000],
430 | [1.5000]]),)
431 |
432 |
433 | `linear.bias.grad` was originally `12.0` but a random noise `Noise: tensor([0.7553])` was added to it:
434 |
435 | ```
436 | 12 + 0.7553 = 12.7553
437 | ```
438 |
439 |
440 | ```python
441 | linear.bias.grad
442 | ```
443 |
444 |
445 |
446 |
447 | tensor([12.7553])
448 |
449 |
450 |
451 |
452 | ```python
453 | linear.weight.grad
454 | ```
455 |
456 |
457 |
458 |
459 | tensor([[120., 120., 120.]])
460 |
461 |
462 |
--------------------------------------------------------------------------------
/markdowns/register_buffer.md:
--------------------------------------------------------------------------------
1 |
2 | ## `register_buffer`
3 |
4 | The documentation says:
5 |
6 | > This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the persistent state.
7 |
8 | > Buffers can be accessed as attributes using given names.
9 |
10 | As an example we will just implement a simple running average using `register_buffer`:
11 |
12 |
13 | ```python
14 | import torch
15 | import torch.nn as nn
16 | ```
17 |
18 | `running_average` and `count` are registered as buffers inside the `Model` definition. When an object is created from the `Model` class, it will have `running_average` and `count` as attributes.
19 |
20 | In the `forward` method, you can even define how these attributes will be updated once the model is called.
21 |
22 |
23 | ```python
24 | class Model(nn.Module):
25 | def __init__(self):
26 | super(Model, self).__init__()
27 | self.register_buffer('running_average', torch.Tensor([0.0]))
28 | self.register_buffer('count', torch.Tensor([0.0]))
29 |
30 | def forward(self, x):
31 | # self.count keeps a count of how many times the model was called
32 | self.count += 1
33 |
34 | # self.running_average keeps the running average in memory
35 | self.running_average = self.running_average.mul(self.count-1).add(x).div(self.count)
36 | return x
37 | ```
38 |
39 | ### Note that items registered as buffers are not considered "parameters" of the model. Hence they will not show up under `model.parameters()`:
40 |
41 |
42 | ```python
43 | model = Model()
44 | list(model.parameters())
45 | ```
46 |
47 |
48 |
49 |
50 | []
51 |
52 |
53 |
54 | ### However they do show up in the `state_dict`. What this means is that when you save the `state_dict`, these values will be saved (and later retrieved) as well:
55 |
56 |
57 | ```python
58 | model.state_dict()
59 | ```
60 |
61 |
62 |
63 |
64 | OrderedDict([('running_average', tensor([0.])), ('count', tensor([0.]))])
65 |
66 |
67 |
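As a quick illustration of that claim (a small sketch, not part of the original notebook; the file name is arbitrary), the buffers survive a save/load round trip of the `state_dict`:

```python
torch.save(model.state_dict(), 'model_with_buffers.pt')

model2 = Model()
model2.load_state_dict(torch.load('model_with_buffers.pt'))
print(model2.running_average, model2.count)   # same buffer values as in `model`
```
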
68 | ### Now let's just call the model a couple of times with different values and see how these values change:
69 |
70 |
71 | ```python
72 | model(10)
73 | ```
74 |
75 |
76 |
77 |
78 | 10
79 |
80 |
81 |
82 |
83 | ```python
84 | model.count, model.running_average
85 | ```
86 |
87 |
88 |
89 |
90 | (tensor([1.]), tensor([10.]))
91 |
92 |
93 |
94 |
95 | ```python
96 | model(10)
97 | model.count, model.running_average
98 | ```
99 |
100 |
101 |
102 |
103 | (tensor([2.]), tensor([10.]))
104 |
105 |
106 |
107 |
108 | ```python
109 | model(5)
110 | model.count, model.running_average
111 | ```
112 |
113 |
114 |
115 |
116 | (tensor([3.]), tensor([8.3333]))
117 |
118 |
119 |
120 |
121 | ```python
122 | model(15)
123 | model.count, model.running_average
124 | ```
125 |
126 |
127 |
128 |
129 | (tensor([4.]), tensor([10.]))
130 |
131 |
132 |
133 |
134 | ```python
135 | model(1)
136 | model.count, model.running_average
137 | ```
138 |
139 |
140 |
141 |
142 | (tensor([5.]), tensor([8.2000]))
143 |
144 |
145 |
--------------------------------------------------------------------------------
/markdowns/weightnorm.md:
--------------------------------------------------------------------------------
1 |
2 | Weight normalization was introduced by OpenAI in their paper https://arxiv.org/abs/1602.07868
3 |
4 | It is now available in PyTorch as built-in functionality. It helps in speeding up gradient descent and has been used in WaveNet.
5 |
6 | The basic idea is pretty straightforward: for each neuron, the weight vector `w` is broken down into two components: its magnitude (`g`) and its direction (the unit vector in the direction of `v`, i.e. `v / ||v||`). Thus we have:
7 |
8 | > `w = g * v / ||v||` where `w` and `v` are vectors and `g` is a scalar
9 |
10 |
11 | Instead of optimizing `w` directly, optimizing on `g` and `v` separately has been found to be faster and more accurate.
12 |
13 |
14 | ```python
15 | import torch
16 | import torch.nn as nn
17 | ```
18 |
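Before diving into the PyTorch machinery, here is a tiny sanity check of the decomposition itself (an illustrative sketch with arbitrary values for `v` and `g`): after reparameterization the norm of `w` equals `|g|`, so the magnitude and the direction are decoupled.

```python
v = torch.randn(5)        # direction parameter of a single neuron
g = torch.tensor(3.0)     # magnitude parameter
w = g * v / v.norm()      # reconstructed weight vector
print(w.norm())           # tensor(3.) -- ||w|| equals |g|, independent of v
```
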
19 | ### There are two things to observe about weight normalization:
20 |
21 | 1) It increases the number of parameters (by the number of neurons)
22 |
23 | 2) It adds a forward pre-hook to the module
24 |
25 |
26 | ```python
27 | get_num_of_params = lambda model: sum([p.numel() for p in model.parameters()])
28 | get_param_name_and_size = lambda model: [(name, param.size()) for (name, param) in model.named_parameters()]
29 | ```
30 |
31 |
32 | ```python
33 | linear = nn.Linear(5,3)
34 | ```
35 |
36 |
37 | ```python
38 | linear._forward_pre_hooks #no hooks present
39 | ```
40 |
41 |
42 |
43 |
44 | OrderedDict()
45 |
46 |
47 |
48 |
49 | ```python
50 | get_num_of_params(linear), get_param_name_and_size(linear)
51 | ```
52 |
53 |
54 |
55 |
56 | (18, [('weight', torch.Size([3, 5])), ('bias', torch.Size([3]))])
57 |
58 |
59 |
60 | ### Weight normalization in PyTorch can be done by calling the `nn.utils.weight_norm` function.
61 |
62 | By default, it normalizes the `weight` of a module:
63 |
64 |
65 | ```python
66 | _ = nn.utils.weight_norm(linear)
67 | ```
68 |
69 | The number of parameters increased by 3 (we have 3 neurons here). Also, the parameter `name` (here, `weight`) is replaced by two parameters, `name_g` and `name_v`:
70 |
71 |
72 | ```python
73 | get_num_of_params(linear), get_param_name_and_size(linear)
74 | ```
75 |
76 |
77 |
78 |
79 | (21,
80 | [('bias', torch.Size([3])),
81 | ('weight_g', torch.Size([3, 1])),
82 | ('weight_v', torch.Size([3, 5]))])
83 |
84 |
85 |
86 |
87 | ```python
88 | linear.weight_g.data
89 | ```
90 |
91 |
92 |
93 |
94 | tensor([[0.5075],
95 | [0.4952],
96 | [0.6064]])
97 |
98 |
99 |
100 |
101 | ```python
102 | linear.weight_v.data
103 | ```
104 |
105 |
106 |
107 |
108 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
109 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
110 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]])
111 |
112 |
113 |
114 | Also note that `linear` module now has a forward pre-hook added to it:
115 |
116 |
117 | ```python
118 | linear._forward_pre_hooks
119 | ```
120 |
121 |
122 |
123 |
124 | OrderedDict([(0, )])
125 |
126 |
127 |
128 | The original `weight` is also present but is no longer part of the `parameters()` of the `linear` module. Also note that `weight_v.data` is the same as `weight.data`:
129 |
130 |
131 | ```python
132 | linear.weight.data
133 | ```
134 |
135 |
136 |
137 |
138 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
139 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
140 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]])
141 |
142 |
143 |
144 |
145 | ```python
146 | linear.weight_v.data
147 | ```
148 |
149 |
150 |
151 |
152 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
153 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
154 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]])
155 |
156 |
157 |
158 | ### How is `weight_g` calculated? We will look at it in a moment
159 |
160 |
161 | ```python
162 | linear.weight_g.data
163 | ```
164 |
165 |
166 |
167 |
168 | tensor([[0.5075],
169 | [0.4952],
170 | [0.6064]])
171 |
172 |
173 |
174 | ### We can get `name` back from `name_g` and `name_v` using the `torch._weight_norm` function. We will also look at it in a moment
175 |
176 |
177 | ```python
178 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data)
179 | ```
180 |
181 |
182 |
183 |
184 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
185 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
186 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]])
187 |
188 |
189 |
190 | ### Let's look at `norm_except_dim` function
191 |
192 | Pretty self-explanatory: it calculates the norm of a tensor over all dimensions except the one provided to the function. We can calculate any Lp norm and omit any dimension we want.
193 |
194 | A few examples should make it clear.
195 |
196 |
197 | ```python
198 | ones = torch.ones(5,5,5)
199 | ones.size()
200 | ```
201 |
202 |
203 |
204 |
205 | torch.Size([5, 5, 5])
206 |
207 |
208 |
209 | If we omit dimension `0`, each slice along that dimension has 25 elements, each of value `1`. Thus their L2 norm is `sqrt(25) = 5`:
210 |
211 |
212 | ```python
213 | norm = 2
214 | dim = 0
215 | y = torch.norm_except_dim(ones, norm, dim)
216 | print('y.size():', y.size())
217 | print('y:', y)
218 | ```
219 |
220 | y.size(): torch.Size([5, 1, 1])
221 | y: tensor([[[5.]],
222 |
223 | [[5.]],
224 |
225 | [[5.]],
226 |
227 | [[5.]],
228 |
229 | [[5.]]])
230 |
231 |
232 | Similarly, omitting dim = `0` and calculating the L3 norm gives `25 ** (1/3) = 2.9240`:
233 |
234 |
235 | ```python
236 | norm = 3
237 | dim = 0
238 | y = torch.norm_except_dim(ones, norm, dim)
239 | print('y.size():', y.size())
240 | print('y:', y)
241 | ```
242 |
243 | y.size(): torch.Size([5, 1, 1])
244 | y: tensor([[[2.9240]],
245 |
246 | [[2.9240]],
247 |
248 | [[2.9240]],
249 |
250 | [[2.9240]],
251 |
252 | [[2.9240]]])
253 |
254 |
255 | Omitting dim = `2` changes the shape of the output:
256 |
257 |
258 | ```python
259 | norm = 2
260 | dim = 2
261 | y = torch.norm_except_dim(ones, norm, dim)
262 | print('y.size():', y.size())
263 | print('y:', y)
264 | ```
265 |
266 | y.size(): torch.Size([1, 1, 5])
267 | y: tensor([[[5., 5., 5., 5., 5.]]])
268 |
269 |
270 | Omitting dim = `-1` is the same as not omitting anything at all:
271 |
272 |
273 | ```python
274 | torch.norm_except_dim(ones, 2, -1)
275 | ```
276 |
277 |
278 |
279 |
280 | tensor(11.1803)
281 |
282 |
283 |
284 |
285 | ```python
286 | ones.norm()
287 | ```
288 |
289 |
290 |
291 |
292 | tensor(11.1803)
293 |
294 |
295 |
296 | ### By default `nn.utils.weight_norm` calls `torch.norm_except_dim` with `dim=0`. This is how we get `weight_g`:
297 |
298 |
299 | ```python
300 | torch.norm_except_dim(linear.weight.data, 2, 0)
301 | ```
302 |
303 |
304 |
305 |
306 | tensor([[0.5075],
307 | [0.4952],
308 | [0.6064]])
309 |
310 |
311 |
312 | It is the same as doing the below operation:
313 |
314 |
315 | ```python
316 | linear.weight.data.norm(dim=1)
317 | ```
318 |
319 |
320 |
321 |
322 | tensor([0.5075, 0.4952, 0.6064])
323 |
324 |
325 |
326 | ### Let's look at `_weight_norm` function
327 |
328 | This function is used to calculate `name` from `name_v` and `name_g`. Let's see how it works:
329 |
330 |
331 | ```python
332 | v = torch.randn(5,3)*10 + 4
333 | g = torch.randn(5,1)
334 | ```
335 |
336 | For the given values of `g` and `v`, `torch._weight_norm(v,g,0)` is basically the same as `g * v/v.norm(dim=1,keepdim=True)`:
337 |
338 |
339 | ```python
340 | torch._weight_norm(v,g,0)
341 | ```
342 |
343 |
344 |
345 |
346 | tensor([[-0.1452, -0.2595, -0.2552],
347 | [ 0.3853, 0.1218, 0.6896],
348 | [-0.0612, -0.0946, -0.0837],
349 | [-0.1722, 0.1708, 0.4013],
350 | [-0.0354, 0.0114, -0.0197]])
351 |
352 |
353 |
354 |
355 | ```python
356 | g * v/v.norm(dim=1,keepdim=True)
357 | ```
358 |
359 |
360 |
361 |
362 | tensor([[-0.1452, -0.2595, -0.2552],
363 | [ 0.3853, 0.1218, 0.6896],
364 | [-0.0612, -0.0946, -0.0837],
365 | [-0.1722, 0.1708, 0.4013],
366 | [-0.0354, 0.0114, -0.0197]])
367 |
368 |
369 |
370 | Similarly, `torch._weight_norm(v,g,1)` is basically the same as `g * v/v.norm(dim=0,keepdim=True)`
371 |
372 |
373 | ```python
374 | torch._weight_norm(v,g,1)
375 | ```
376 |
377 |
378 |
379 |
380 | tensor([[-0.2282, -0.2974, -0.2645],
381 | [ 0.3056, 0.0704, 0.3607],
382 | [-0.0779, -0.0878, -0.0702],
383 | [-0.0785, 0.0568, 0.1207],
384 | [-0.0178, 0.0042, -0.0065]])
385 |
386 |
387 |
388 |
389 | ```python
390 | g * v/v.norm(dim=0,keepdim=True)
391 | ```
392 |
393 |
394 |
395 |
396 | tensor([[-0.2282, -0.2974, -0.2645],
397 | [ 0.3056, 0.0704, 0.3607],
398 | [-0.0779, -0.0878, -0.0702],
399 | [-0.0785, 0.0568, 0.1207],
400 | [-0.0178, 0.0042, -0.0065]])
401 |
402 |
403 |
404 | And `torch._weight_norm(v,g,-1)` is basically the same as `g * v/v.norm()`
405 |
406 |
407 | ```python
408 | torch._weight_norm(v,g,-1)
409 | ```
410 |
411 |
412 |
413 |
414 | tensor([[-0.1003, -0.1792, -0.1762],
415 | [ 0.1343, 0.0424, 0.2403],
416 | [-0.0342, -0.0529, -0.0468],
417 | [-0.0345, 0.0342, 0.0804],
418 | [-0.0078, 0.0025, -0.0043]])
419 |
420 |
421 |
422 |
423 | ```python
424 | g * v/v.norm()
425 | ```
426 |
427 |
428 |
429 |
430 | tensor([[-0.1003, -0.1792, -0.1762],
431 | [ 0.1343, 0.0424, 0.2403],
432 | [-0.0342, -0.0529, -0.0468],
433 | [-0.0345, 0.0342, 0.0804],
434 | [-0.0078, 0.0025, -0.0043]])
435 |
436 |
437 |
438 | For `linear` module, this is how we get `weight` from `weight_v` and `weight_g` (notice `dim=0`):
439 |
440 |
441 | ```python
442 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data, 0)
443 | ```
444 |
445 |
446 |
447 |
448 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
449 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
450 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]])
451 |
452 |
453 |
454 |
455 | ```python
456 | linear.weight.data
457 | ```
458 |
459 |
460 |
461 |
462 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
463 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
464 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]])
465 |
466 |
467 |
468 | ### But what is the point of the forward pre-hook?
469 |
470 |
471 | ```python
472 | linear._forward_pre_hooks
473 | ```
474 |
475 |
476 |
477 |
478 | OrderedDict([(0, )])
479 |
480 |
481 |
482 |
483 | ```python
484 | hook = linear._forward_pre_hooks[0]
485 | ```
486 |
487 | #### Let's first see what the hook does: it recomputes the value of `weight` from `weight_g` and `weight_v`.
488 |
489 |
490 | ```python
491 | hook.compute_weight(linear)
492 | ```
493 |
494 |
495 |
496 |
497 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
498 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
499 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]], grad_fn=)
500 |
501 |
502 |
503 | Let's say you are training `linear` on a dataset with `batch_size` = 8. After back-propagation and weight update, the values of `weight_g` and `weight_v` will be different:
504 |
505 |
506 | ```python
507 | batch_size = 8
508 | x = torch.randn(batch_size, 5)
509 | ```
510 |
511 |
512 | ```python
513 | y = linear(x)
514 | ```
515 |
516 |
517 | ```python
518 | loss = (y-1).sum()
519 | ```
520 |
521 |
522 | ```python
523 | loss.backward()
524 | ```
525 |
526 |
527 | ```python
528 | linear.weight
529 | ```
530 |
531 |
532 |
533 |
534 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
535 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
536 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]], grad_fn=)
537 |
538 |
539 |
540 |
541 | ```python
542 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data, 0)
543 | ```
544 |
545 |
546 |
547 |
548 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
549 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
550 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]])
551 |
552 |
553 |
554 |
555 | ```python
556 | for param in linear.parameters():
557 | param.data = param.data - (param.grad.data*0.01)
558 | ```
559 |
560 | Weights `weight_v` and `weight_g` changed. Hence `weight` should now be equal to:
561 |
562 |
563 | ```python
564 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data, 0)
565 | ```
566 |
567 |
568 |
569 |
570 | tensor([[ 0.2946, 0.1498, -0.0491, -0.0748, 0.3667],
571 | [ 0.0258, -0.2643, 0.1122, -0.3771, -0.0488],
572 | [ 0.2303, 0.4506, 0.0095, 0.2294, -0.3068]])
573 |
574 |
575 |
576 | But it's not:
577 |
578 |
579 | ```python
580 | linear.weight
581 | ```
582 |
583 |
584 |
585 |
586 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820],
587 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354],
588 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]], grad_fn=)
589 |
590 |
591 |
592 | This is why we need a hook. The hook updates `linear.weight` by recomputing it from `weight_v` and `weight_g` (effectively `torch._weight_norm(linear.weight_v, linear.weight_g, 0)`) during the next forward pass:
593 |
594 |
595 | ```python
596 | _ = linear(x)
597 | ```
598 |
599 | `linear.weight` is updated thanks to the hook:
600 |
601 |
602 | ```python
603 | linear.weight
604 | ```
605 |
606 |
607 |
608 |
609 | tensor([[ 0.2946, 0.1498, -0.0491, -0.0748, 0.3667],
610 | [ 0.0258, -0.2643, 0.1122, -0.3771, -0.0488],
611 | [ 0.2303, 0.4506, 0.0095, 0.2294, -0.3068]], grad_fn=)
612 |
613 |
614 |
--------------------------------------------------------------------------------