├── .gitignore ├── README.md ├── assets ├── detached_y1.png ├── groups_1.png ├── groups_2.png └── undetached_y1.png └── markdowns ├── IN_LN_manual_calculations.md ├── Understanding losses.md ├── backprop.md ├── basics.md ├── batchnorm.md ├── dropout.md ├── einstein_summation.md ├── groups.md ├── hooks.md ├── register_buffer.md └── weightnorm.md /.gitignore: -------------------------------------------------------------------------------- 1 | create_toc.py 2 | toc.md 3 | .ipynb_checkpoints/** 4 | __pycache__/** 5 | *.ipynb 6 | **/.DS_Store 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Understanding nuts and bolts of neural networks with PyTorch 2 | 3 | __(Work in progress)__ 4 | 5 | This is a series of articles in an attempt to understand the witchcraft that is the neural networks (using PyTorch). 6 | 7 | ### Basic 8 | 9 | * [__Gradient descent and various ways to implement it in PyTorch__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/basics.md) 10 | 11 | * [__Dropout__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/dropout.md) 12 | 13 | * [__Batch normalization__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/batchnorm.md) 14 | 15 | * [__Backpropagation and detach__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/backprop.md) 16 | 17 | * [__Losses in PyTorch__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/Understanding%20losses.md) 18 | 19 | * [__register_buffer__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/register_buffer.md) 20 | 21 | ### Intermediate 22 | 23 | * [__Groups in convolution__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/groups.md) 24 | 25 | * [__Hooks__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/hooks.md) 26 | 27 | * [__Weight normalization__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/weightnorm.md) 28 | 29 | * [__Einstein summation__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/einstein_summation.md) 30 | 31 | * [__Instance and layer normalizations__](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/markdowns/IN_LN_manual_calculations.md) -------------------------------------------------------------------------------- /assets/detached_y1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/detached_y1.png -------------------------------------------------------------------------------- /assets/groups_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/groups_1.png -------------------------------------------------------------------------------- /assets/groups_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/groups_2.png 
-------------------------------------------------------------------------------- /assets/undetached_y1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinsis/understanding-neuralnetworks-pytorch/2960bfed7da6a0c4415ee849c5d8f40208226b0e/assets/undetached_y1.png -------------------------------------------------------------------------------- /markdowns/IN_LN_manual_calculations.md: -------------------------------------------------------------------------------- 1 | ```python 2 | import torch 3 | import torch.nn as nn 4 | ``` 5 | 6 | ### 3D tensors 7 | 8 | 9 | ```python 10 | x = torch.arange(2*3*4).view(2,3,4).float() 11 | y = nn.InstanceNorm1d(3)(x) 12 | z = nn.LayerNorm(4, elementwise_affine=False)(x) 13 | ``` 14 | 15 | 16 | ```python 17 | num = x - x.mean(-1, keepdim=True) 18 | den = (num.pow(2).mean(-1, keepdim=True) + 1e-5).sqrt() 19 | num/den 20 | ``` 21 | 22 | 23 | 24 | 25 | tensor([[[-1.3416, -0.4472, 0.4472, 1.3416], 26 | [-1.3416, -0.4472, 0.4472, 1.3416], 27 | [-1.3416, -0.4472, 0.4472, 1.3416]], 28 | 29 | [[-1.3416, -0.4472, 0.4472, 1.3416], 30 | [-1.3416, -0.4472, 0.4472, 1.3416], 31 | [-1.3416, -0.4472, 0.4472, 1.3416]]]) 32 | 33 | 34 | 35 | 36 | ```python 37 | z 38 | ``` 39 | 40 | 41 | 42 | 43 | tensor([[[-1.3416, -0.4472, 0.4472, 1.3416], 44 | [-1.3416, -0.4472, 0.4472, 1.3416], 45 | [-1.3416, -0.4472, 0.4472, 1.3416]], 46 | 47 | [[-1.3416, -0.4472, 0.4472, 1.3416], 48 | [-1.3416, -0.4472, 0.4472, 1.3416], 49 | [-1.3416, -0.4472, 0.4472, 1.3416]]]) 50 | 51 | 52 | 53 | 54 | ```python 55 | y 56 | ``` 57 | 58 | 59 | 60 | 61 | tensor([[[-1.3416, -0.4472, 0.4472, 1.3416], 62 | [-1.3416, -0.4472, 0.4472, 1.3416], 63 | [-1.3416, -0.4472, 0.4472, 1.3416]], 64 | 65 | [[-1.3416, -0.4472, 0.4472, 1.3416], 66 | [-1.3416, -0.4472, 0.4472, 1.3416], 67 | [-1.3416, -0.4472, 0.4472, 1.3416]]]) 68 | 69 | 70 | 71 | However, `y` and `z` are not identical (I don't know why): 72 | 73 | 74 | ```python 75 | (y - z).abs().max() 76 | ``` 77 | 78 | 79 | 80 | 81 | tensor(1.1921e-06) 82 | 83 | 84 | 85 | ### 4D tensors 86 | 87 | 88 | ```python 89 | x = torch.arange(2*3*4*5).view(2,3,4,5).float() 90 | y = nn.InstanceNorm2d(3)(x) 91 | z = nn.LayerNorm(5, elementwise_affine=False)(x) 92 | z2 = nn.LayerNorm([4,5], elementwise_affine=False)(x) 93 | ``` 94 | 95 | 96 | ```python 97 | num = x - x.view(2,3,-1).mean(-1, keepdim=True).unsqueeze(-1) 98 | den = (num.view(2,3,-1).pow(2).mean() + 1e-5).sqrt() 99 | (num/den)[0,0] 100 | ``` 101 | 102 | 103 | 104 | 105 | tensor([[-1.6475, -1.4741, -1.3007, -1.1272, -0.9538], 106 | [-0.7804, -0.6070, -0.4336, -0.2601, -0.0867], 107 | [ 0.0867, 0.2601, 0.4336, 0.6070, 0.7804], 108 | [ 0.9538, 1.1272, 1.3007, 1.4741, 1.6475]]) 109 | 110 | 111 | 112 | 113 | ```python 114 | y[0,0] 115 | ``` 116 | 117 | 118 | 119 | 120 | tensor([[-1.6475, -1.4741, -1.3007, -1.1272, -0.9538], 121 | [-0.7804, -0.6070, -0.4336, -0.2601, -0.0867], 122 | [ 0.0867, 0.2601, 0.4336, 0.6070, 0.7804], 123 | [ 0.9538, 1.1272, 1.3007, 1.4741, 1.6475]]) 124 | 125 | 126 | 127 | 128 | ```python 129 | z2[0,0] 130 | ``` 131 | 132 | 133 | 134 | 135 | tensor([[-1.6475, -1.4741, -1.3007, -1.1272, -0.9538], 136 | [-0.7804, -0.6070, -0.4336, -0.2601, -0.0867], 137 | [ 0.0867, 0.2601, 0.4336, 0.6070, 0.7804], 138 | [ 0.9538, 1.1272, 1.3007, 1.4741, 1.6475]]) 139 | 140 | 141 | 142 | 143 | ```python 144 | num = x - x.mean(-1, keepdim=True) 145 | den = (num.pow(2).mean(-1, keepdim=True) + 1e-5).sqrt() 146 | (num/den)[0,0] 147 | ``` 148 | 149 | 150 | 
151 | 152 | tensor([[-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 153 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 154 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 155 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142]]) 156 | 157 | 158 | 159 | 160 | ```python 161 | z[0,0] 162 | ``` 163 | 164 | 165 | 166 | 167 | tensor([[-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 168 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 169 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 170 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142]]) 171 | 172 | 173 | 174 | 175 | ```python 176 | num = x - x.mean(-1, keepdim=True) 177 | den = (x.var(-1, keepdim=True, unbiased=False) + 1e-5).sqrt() 178 | (num/den)[0,0] 179 | ``` 180 | 181 | 182 | 183 | 184 | tensor([[-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 185 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 186 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142], 187 | [-1.4142, -0.7071, 0.0000, 0.7071, 1.4142]]) 188 | 189 | 190 | 191 | 192 | ```python 193 | 194 | ``` 195 | -------------------------------------------------------------------------------- /markdowns/Understanding losses.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ```python 4 | import torch 5 | import torch.nn as nn 6 | ``` 7 | 8 | ### 1. LogSoftmax 9 | 10 | 11 | ```python 12 | x = torch.randn(8, 5) 13 | nn.LogSoftmax(dim=1)(x) 14 | ``` 15 | 16 | 17 | 18 | 19 | tensor([[-3.3124, -2.3836, -1.1339, -1.4844, -1.1303], 20 | [-1.9481, -1.3503, -2.9377, -1.4468, -1.1712], 21 | [-2.1393, -1.8147, -4.4871, -1.4645, -0.7404], 22 | [-3.1521, -2.6411, -1.9489, -2.7951, -0.3821], 23 | [-1.6078, -1.3891, -2.0410, -3.2716, -0.9610], 24 | [-1.2204, -2.6508, -1.1108, -2.0820, -1.7130], 25 | [-2.0168, -0.9498, -1.7237, -1.5997, -2.3051], 26 | [-2.3434, -1.8218, -2.2064, -0.5169, -3.3294]]) 27 | 28 | 29 | 30 | ### Same as taking the softmax first and then taking log of each value 31 | * Softmax: `x.exp() / x.exp().sum(dim=1, keepdim=True)` 32 | * Log: `().log()` 33 | 34 | 35 | ```python 36 | (x.exp() / x.exp().sum(dim=1, keepdim=True)).log() 37 | ``` 38 | 39 | 40 | 41 | 42 | tensor([[-3.3124, -2.3836, -1.1339, -1.4844, -1.1303], 43 | [-1.9481, -1.3503, -2.9377, -1.4468, -1.1712], 44 | [-2.1393, -1.8147, -4.4871, -1.4645, -0.7404], 45 | [-3.1521, -2.6411, -1.9489, -2.7951, -0.3821], 46 | [-1.6078, -1.3891, -2.0410, -3.2716, -0.9610], 47 | [-1.2204, -2.6508, -1.1108, -2.0820, -1.7130], 48 | [-2.0168, -0.9498, -1.7237, -1.5997, -2.3051], 49 | [-2.3434, -1.8218, -2.2064, -0.5169, -3.3294]]) 50 | 51 | 52 | 53 | ### 2. NLLLoss 54 | 55 | > The negative log likelihood loss. It is useful to train a classification problem with C classes.
56 | If provided, the optional argument weight should be a 1D Tensor assigning weight to each of the classes. This is particularly useful when you have an unbalanced training set. 57 | 58 | ### Simply put, (log of probability * -1) 59 | 60 | 61 | ```python 62 | targets = torch.tensor([4,2,3,4,0,0,1,2]) 63 | softmax_x = x.exp() / x.exp().sum(dim=1,keepdim=True) 64 | log_likelihood_x = softmax_x.log() 65 | print(log_likelihood_x.size(), targets.size()) 66 | nn.NLLLoss()(log_likelihood_x, targets) 67 | ``` 68 | 69 | torch.Size([8, 5]) torch.Size([8]) 70 | 71 | 72 | 73 | 74 | 75 | tensor(1.4874) 76 | 77 | 78 | 79 | 80 | ```python 81 | nn.NLLLoss()(nn.LogSoftmax(dim=1)(x), targets) 82 | ``` 83 | 84 | 85 | 86 | 87 | tensor(1.4874) 88 | 89 | 90 | 91 | ### Manual calculation 92 | 93 | 94 | ```python 95 | loss = 0 96 | for i in range(8): 97 | loss += nn.LogSoftmax(dim=1)(x)[i][targets[i]] 98 | 99 | loss/8 100 | ``` 101 | 102 | 103 | 104 | 105 | tensor(-1.4874) 106 | 107 | 108 | 109 | ### 3. CrossEntropyLoss 110 | 111 | This criterion combines `nn.LogSoftmax()` and `nn.NLLLoss()` in one single class. 112 | 113 | 114 | ```python 115 | nn.CrossEntropyLoss()(x, targets) 116 | ``` 117 | 118 | 119 | 120 | 121 | tensor(1.4874) 122 | 123 | 124 | 125 | ### 4. BCELoss 126 | 127 | > Creates a criterion that measures the Binary Cross Entropy between the target and the output. 128 | 129 | `CrossEntropyLoss` measures how close is the probability of true class close to 1. It does not consider what other possibilities are. Thus, it works well with `Softmax`. 130 | 131 | When you want to measure how close are the probabilities of true classes close to 1 and how close are the probabilities of non-true classes close to 0, using `BCELoss` makes sense. 132 | 133 | 134 | ```python 135 | x = (torch.randn(4,3)).sigmoid_() 136 | x 137 | ``` 138 | 139 | 140 | 141 | 142 | tensor([[0.4929, 0.8562, 0.5976], 143 | [0.3183, 0.7167, 0.5629], 144 | [0.4342, 0.5078, 0.7811], 145 | [0.6997, 0.6790, 0.7381]]) 146 | 147 | 148 | 149 | 150 | ```python 151 | targets = torch.FloatTensor([[0,1,1],[1,1,0],[0,0,0],[1,0,0]]) 152 | targets 153 | ``` 154 | 155 | 156 | 157 | 158 | tensor([[0., 1., 1.], 159 | [1., 1., 0.], 160 | [0., 0., 0.], 161 | [1., 0., 0.]]) 162 | 163 | 164 | 165 | 166 | ```python 167 | nn.BCELoss()(x, targets) 168 | ``` 169 | 170 | 171 | 172 | 173 | tensor(0.7738) 174 | 175 | 176 | 177 | ### Manual calculation 178 | 179 | 180 | ```python 181 | def loss_per_class(p,t): 182 | if t == 0: 183 | return -1 * (1-t) * torch.log(1-p) 184 | else: 185 | return -1 * t * torch.log(p) 186 | 187 | loss = 0 188 | for index in range(x.size(0)): 189 | predicted = x[index] 190 | true = targets[index] 191 | loss += torch.FloatTensor([loss_per_class(p,t) for p,t in zip(predicted, true)]).sum() 192 | 193 | loss / (4*3) 194 | ``` 195 | 196 | 197 | 198 | 199 | tensor(0.7738) 200 | 201 | 202 | 203 | ### 5. KLDivLoss 204 | 205 | As with `NLLLoss`, the input given is expected to contain _log-probabilities_. However, unlike `NLLLoss`, input is not restricted to a 2D Tensor, because the criterion is applied element-wise. The targets are given as _probabilities_ (i.e. without taking the logarithm). 206 | 207 | This criterion expects a target Tensor of the same size as the input Tensor. 
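For reference, the element-wise quantity being computed here is `target * (log(target) - input)`, where `input` already holds log-probabilities; the default reduction then averages this over every element (which is why the manual sum further down gets divided by the total number of elements). A minimal vectorized sketch of the same formula (the variable names are only illustrative):

```python
import torch
import torch.nn as nn

log_probs = nn.LogSoftmax(dim=1)(torch.randn(8, 5))  # log-probability inputs
probs_target = torch.randn(8, 5).sigmoid()           # probability targets

# Element-wise KL term: target * (log(target) - log_prob_input)
elementwise = probs_target * (probs_target.log() - log_probs)

# Averaging over all 8*5 elements reproduces the default reduction
print(elementwise.mean())
print(nn.KLDivLoss(reduction='mean')(log_probs, probs_target))
```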
208 | 209 | 210 | ```python 211 | x = torch.randn(8,5) 212 | targets = torch.randn_like(x).sigmoid_() 213 | nn.KLDivLoss()(nn.LogSoftmax(dim=1)(x), targets) 214 | ``` 215 | 216 | 217 | 218 | 219 | tensor(0.8478) 220 | 221 | 222 | 223 | ### Manual calculation 224 | 225 | 226 | ```python 227 | loss = 0 228 | for i in range(x.size(0)): 229 | for j in range(x.size(1)): 230 | loss += targets[i,j] * (torch.log(targets[i,j]) - x[i,j] + x[i,:].exp().sum().log()) 231 | ``` 232 | 233 | 234 | ```python 235 | nn.KLDivLoss(reduction='sum')(nn.LogSoftmax(dim=1)(x), targets), loss 236 | ``` 237 | 238 | 239 | 240 | 241 | (tensor(33.9101), tensor(33.9101)) 242 | 243 | 244 | 245 | 246 | ```python 247 | nn.KLDivLoss()(nn.LogSoftmax(dim=1)(x), targets), loss / (x.size(0) * x.size(1)) 248 | ``` 249 | 250 | 251 | 252 | 253 | (tensor(0.8478), tensor(0.8478)) 254 | 255 | 256 | -------------------------------------------------------------------------------- /markdowns/backprop.md: -------------------------------------------------------------------------------- 1 | [1. Simple example](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#1-simple-example) 2 | 3 | * [Define a batch of 8 inputs](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#define-a-batch-of-8-inputs) 4 | 5 | * [Note on values of gradients:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#note-on-values-of-gradients) 6 | 7 | * [`x.sum(0)` is same as the weight gradients:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#xsum0-is-same-as-the-weight-gradients) 8 | 9 | [2. Losses calculated from two different batches in two different ways](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#2-losses-calculated-from-two-different-batches-in-two-different-ways) 10 | 11 | * [A note on values of gradients](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#a-note-on-values-of-gradients) 12 | 13 | [ 3. Using `nn.utils.clip_grad_norm` to prevent exploding gradients](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#3-using-nnutilsclip_grad_norm-to-prevent-exploding-gradients) 14 | 15 | * [`nn.utils.clip_grad_norm` returns total norm of parameter gradients _before_ they are clipped](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#nnutilsclip_grad_norm-returns-total-norm-of-parameter-gradients-before-they-are-clipped) 16 | 17 | [4. Using `detach()`](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#using-detach) 18 | 19 | * [Another example of `detach`](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#another-example-of-detach) 20 | 21 | * [Not detaching `y1`:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#not-detaching-y1) 22 | 23 | * [When is `detach()` used?](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/backprop.md#when-is-detach-used) 24 | 25 | ### 1. Simple example 26 | 27 | Consider a fully connected layer with 5 input neurons and 2 output neurons. 
Let's feed it a batch of 8 inputs, each with dimension 5: 28 | 29 | 30 | ```python 31 | import torch 32 | from torch.autograd import Variable 33 | import torch.nn as nn 34 | ``` 35 | 36 | Since no gradient has been backpropogated, its weights are all `None`: 37 | 38 | 39 | ```python 40 | linear = nn.Linear(5,2) 41 | print('=== Weights and biases ===') 42 | _ = [print(p.data) for p in linear.parameters()] 43 | print('=== Gradients ===') 44 | print([p.grad for p in linear.parameters()]) 45 | ``` 46 | 47 | === Weights and biases === 48 | 49 | -0.3933 0.3731 0.1929 -0.3506 0.4152 50 | -0.2833 -0.2681 -0.0346 0.2800 -0.2322 51 | [torch.FloatTensor of size 2x5] 52 | 53 | 54 | 0.4141 55 | 0.2196 56 | [torch.FloatTensor of size 2] 57 | 58 | === Gradients === 59 | [None, None] 60 | 61 | 62 | #### Define a batch of 8 inputs 63 | 64 | 65 | ```python 66 | x = Variable(torch.randn(8,5)) 67 | y = linear(x) 68 | y 69 | ``` 70 | 71 | 72 | 73 | 74 | Variable containing: 75 | 1.0078 -0.4197 76 | 1.7015 -0.4344 77 | -0.1489 0.5115 78 | 0.0158 -0.3754 79 | 2.0604 -0.3813 80 | 0.0680 0.7753 81 | 1.1815 -0.1078 82 | 1.5967 0.0897 83 | [torch.FloatTensor of size 8x2] 84 | 85 | 86 | 87 | Let's assume the gradient of loss with respect to `y` is 1. We can backpropogate these gradients using `y.backward`: 88 | 89 | 90 | ```python 91 | grads = Variable(torch.ones(8,2)) 92 | y.backward(grads) 93 | ``` 94 | 95 | The parameters of the linear layer (weights and biases) now have non-None values: 96 | 97 | 98 | ```python 99 | print('=== Weights and biases ===') 100 | _ = [print(p.data) for p in linear.parameters()] 101 | print('=== Gradients ===') 102 | print([p.grad for p in linear.parameters()]) 103 | ``` 104 | 105 | === Weights and biases === 106 | 107 | -0.3933 0.3731 0.1929 -0.3506 0.4152 108 | -0.2833 -0.2681 -0.0346 0.2800 -0.2322 109 | [torch.FloatTensor of size 2x5] 110 | 111 | 112 | 0.4141 113 | 0.2196 114 | [torch.FloatTensor of size 2] 115 | 116 | === Gradients === 117 | [Variable containing: 118 | -0.6444 3.9640 3.1489 -1.0274 3.5395 119 | -0.6444 3.9640 3.1489 -1.0274 3.5395 120 | [torch.FloatTensor of size 2x5] 121 | , Variable containing: 122 | 8 123 | 8 124 | [torch.FloatTensor of size 2] 125 | ] 126 | 127 | 128 | #### Note on values of gradients: 129 | 130 | To make it easy to understand, let's consider only on neuron. 131 | 132 | For a given linear neuron which does the following operation ... 133 | 134 | > $y = \sum_{i=1}^n w_i x_i + b $ 135 | 136 | ... the derivatives wrt bias and weights are given as: 137 | 138 | > $\frac{\partial Loss}{\partial b} = \frac{\partial Loss}{\partial y}$ 139 | 140 | and 141 | 142 | > $\frac{\partial Loss}{\partial w_i} = \frac{\partial Loss}{\partial y} * x_i$ 143 | 144 | Also note that these operations take place _per input sample_. In our case, there are __eight__ samples and the above operations take place eight times. For each sample, we have $\frac{\partial Loss}{\partial y} = 1$. Hence, we do the following operation eight times: 145 | 146 | $\frac{\partial Loss}{\partial b} = 1$ 147 | 148 | Each time, the new gradient gets added to the existing gradient. Hence, the final value of $\frac{\partial Loss}{\partial b}$ is equal to 8. 
149 | 150 | Similarly, for each sample input `i`, we have: 151 | 152 | $\frac{\partial Loss}{\partial w_i} = \frac{\partial Loss}{\partial y} * x_i = x_i$ 153 | 154 | After backpropagating for each of the eight samples, the final value of $\frac{\partial Loss}{\partial w_i}$ is 155 | $\sum_{i=1}^8 x_i$ 156 | 157 | Let's verify this: 158 | 159 | #### `x.sum(0)` is same as the weight gradients: 160 | 161 | 162 | ```python 163 | x.sum(0) 164 | ``` 165 | 166 | 167 | 168 | 169 | Variable containing: 170 | -0.6444 171 | 3.9640 172 | 3.1489 173 | -1.0274 174 | 3.5395 175 | [torch.FloatTensor of size 5] 176 | 177 | 178 | 179 | 180 | ```python 181 | linear.weight.grad 182 | ``` 183 | 184 | 185 | 186 | 187 | Variable containing: 188 | -0.6444 3.9640 3.1489 -1.0274 3.5395 189 | -0.6444 3.9640 3.1489 -1.0274 3.5395 190 | [torch.FloatTensor of size 2x5] 191 | 192 | 193 | 194 | Before we move on, make sure you've understood everything up until this point. 195 | 196 | --- 197 | 198 | ### 2. Losses calculated from two different batches in two different ways 199 | 200 | Consider two batches `x1` and `x2` of sizes 2 and 3 respectively: 201 | 202 | 203 | ```python 204 | x1 = Variable(torch.randn(2,5)) 205 | x2 = Variable(torch.randn(3,5)) 206 | y1, y2 = linear(x1), linear(x2) 207 | print(y1.data) 208 | print(y2.data) 209 | ``` 210 | 211 | 212 | 1.3945 0.9929 213 | 0.3436 -0.1620 214 | [torch.FloatTensor of size 2x2] 215 | 216 | 217 | -0.1487 0.1283 218 | 1.3610 0.0088 219 | -0.3215 0.2808 220 | [torch.FloatTensor of size 3x2] 221 | 222 | 223 | 224 | Let's assume the target values for `y1` and `y2` are all 1. But we want to give the loss from `y2` twice the weight: 225 | 226 | 227 | ```python 228 | loss = (1 - y1).sum() + 2*(1 - y2).sum() 229 | print(loss) 230 | ``` 231 | 232 | Variable containing: 233 | 10.8135 234 | [torch.FloatTensor of size 1] 235 | 236 | 237 | 238 | Make sure the gradients of our layer are reset to zero before we backprop again: 239 | 240 | 241 | ```python 242 | linear.zero_grad() 243 | [p.grad for p in linear.parameters()] 244 | ``` 245 | 246 | 247 | 248 | 249 | [Variable containing: 250 | 0 0 0 0 0 251 | 0 0 0 0 0 252 | [torch.FloatTensor of size 2x5], Variable containing: 253 | 0 254 | 0 255 | [torch.FloatTensor of size 2]] 256 | 257 | 258 | 259 | 260 | ```python 261 | loss.backward() 262 | print('=== Gradients ===') 263 | print([p.grad for p in linear.parameters()]) 264 | ``` 265 | 266 | === Gradients === 267 | [Variable containing: 268 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483 269 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483 270 | [torch.FloatTensor of size 2x5] 271 | , Variable containing: 272 | -8 273 | -8 274 | [torch.FloatTensor of size 2] 275 | ] 276 | 277 | 278 | #### A note on values of gradients 279 | 280 | Here, we have: 281 | 282 | > $\frac{\partial Loss}{\partial y1} = -1$. This operation takes place for each sample in `x1` (i.e. 2 times) 283 | 284 | and 285 | 286 | > $\frac{\partial Loss}{\partial y2} = -2$. This operation takes place for each sample in `x2` (i.e. 
3 times) 287 | 288 | For each case we have: 289 | 290 | > $\frac{\partial Loss}{\partial b} = \frac{\partial Loss}{\partial y}$ 291 | 292 | Thus the final value of $\frac{\partial Loss}{\partial b}$ is: 293 | 294 | $2 * -1 + 3 * -2 = -8$ 295 | 296 | Similarly, __weight gradients are equal to negative of `x1.sum(0) + 2*x2.sum(0)`__ 297 | 298 | Let's verify this: 299 | 300 | 301 | ```python 302 | -1 * (x1.sum(0) + 2*x2.sum(0)) 303 | ``` 304 | 305 | 306 | 307 | 308 | Variable containing: 309 | -0.0289 310 | -1.5594 311 | -0.9259 312 | -0.7791 313 | 0.6483 314 | [torch.FloatTensor of size 5] 315 | 316 | 317 | 318 | 319 | ```python 320 | linear.weight.grad 321 | ``` 322 | 323 | 324 | 325 | 326 | Variable containing: 327 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483 328 | -0.0289 -1.5594 -0.9259 -0.7791 0.6483 329 | [torch.FloatTensor of size 2x5] 330 | 331 | 332 | 333 | --- 334 | 335 | ### 3. Using `nn.utils.clip_grad_norm` to prevent exploding gradients 336 | 337 | _After_ the gradients have been backpropagated, the gradients can be clipped to keep them small. This can be done using `nn.utils.clip_grad_norm` like so: 338 | 339 | 340 | ```python 341 | linear.zero_grad() 342 | x = Variable( torch.randn(8,5) ) 343 | y = linear(x) 344 | grads = Variable( torch.ones(8,2) ) 345 | y.backward(grads) 346 | ``` 347 | 348 | 349 | ```python 350 | [p.grad.norm() for p in linear.parameters()] 351 | ``` 352 | 353 | 354 | 355 | 356 | [Variable containing: 357 | 8.7584 358 | [torch.FloatTensor of size 1], Variable containing: 359 | 11.3137 360 | [torch.FloatTensor of size 1]] 361 | 362 | 363 | 364 | #### `nn.utils.clip_grad_norm` returns total norm of parameter gradients _before_ they are clipped 365 | 366 | 367 | ```python 368 | nn.utils.clip_grad_norm(linear.parameters(), 3) 369 | ``` 370 | 371 | 372 | 373 | 374 | 14.30769861897899 375 | 376 | 377 | 378 | Let's verify this: 379 | 380 | 381 | ```python 382 | (8.7584**2 + 11.3137**2)**0.5 # should be the same as the value returned by the clip_grad_norm above 383 | ``` 384 | 385 | 386 | 387 | 388 | 14.307668512025291 389 | 390 | 391 | 392 | The norm of the new gradients is 3 now: 393 | 394 | 395 | ```python 396 | [p.grad.norm() for p in linear.parameters()] 397 | ``` 398 | 399 | 400 | 401 | 402 | [Variable containing: 403 | 1.8364 404 | [torch.FloatTensor of size 1], Variable containing: 405 | 2.3722 406 | [torch.FloatTensor of size 1]] 407 | 408 | 409 | 410 | 411 | ```python 412 | (1.8364**2 + 2.3722**2)**0.5 413 | ``` 414 | 415 | 416 | 417 | 418 | 2.999949632910526 419 | 420 | 421 | 422 | ### Using `detach()` 423 | 424 | Remember that a PyTorch `Variable` contains history of a graph (if it is a part of a graph, obviously). This includes intermediate values and backprop operations in the graph. In the example below, `y` requires a gradient and the first operation it will do during backprop is `AddmmBackward` 425 | 426 | 427 | ```python 428 | linear.zero_grad() 429 | 430 | x = Variable(torch.randn(8,5)) 431 | y = linear(x) 432 | print(y.requires_grad, y.grad_fn) 433 | ``` 434 | 435 | True 436 | 437 | 438 | When we create a new Variable, say `y_new` from `y` having the same data as `y`, `y_new` does not get the graph-related history from `y`. It only gets the data from `y`. 
Hence, in the example below, `y_new` does not require a gradient and its `grad_fn` is `None`: 439 | 440 | 441 | ```python 442 | y_new = Variable(y.data) 443 | print(y_new.requires_grad, y_new.grad_fn) 444 | ``` 445 | 446 | False None 447 | 448 | 449 | You can say that `y_new` is _detached_ from the graph. When a backprop operation encounters `y_new`, it will see that its `grad_fn` is `None` and it will not continue backpropagating along the path where `y_new` was encountered. 450 | 451 | One can achieve the same effect without defining a new variable; just call the `detach()` method. There's also a `detach_()` method for in-place operation. 452 | 453 | Let's understanding it in greater detail. 454 | 455 | 456 | ```python 457 | y_detached = y.detach() 458 | print(y_detached.requires_grad, y_detached.grad_fn) 459 | ``` 460 | 461 | False None 462 | 463 | 464 | 465 | ```python 466 | print(y.requires_grad, y.grad_fn) 467 | y.detach_() 468 | print(y.requires_grad, y.grad_fn) 469 | 470 | ``` 471 | 472 | True 473 | False None 474 | 475 | 476 | Let's understand it in greater detail. 477 | 478 | ### Another example of `detach` 479 | 480 | Consider two serial operations on a single fully connected layer. Notice that `y1` is detached _before_ `y2` is calculated. Thus, backpropagation from `y2` will stop at `y1`. The gradients of the linear layer are calculated as if `y1` was a leaf variable. The gradient values can be calculated in a fashion similar to what we did in [our simple example](#1-Simple-example) 481 | 482 | ![](../assets/detached_y1.png) 483 | 484 | 485 | ```python 486 | linear = nn.Linear(4,4) 487 | y0 = Variable( torch.randn(10,4) ) 488 | y1 = linear(y0) 489 | y1.detach_() 490 | y2 = linear(y1) 491 | ``` 492 | 493 | 494 | ```python 495 | [p.grad for p in linear.parameters()] 496 | ``` 497 | 498 | 499 | 500 | 501 | [None, None] 502 | 503 | 504 | 505 | 506 | ```python 507 | y2.backward( torch.ones(10,4) ) 508 | [p.grad for p in linear.parameters()] 509 | ``` 510 | 511 | 512 | 513 | 514 | [Variable containing: 515 | 0.7047 -1.9666 4.8305 -0.1647 516 | 0.7047 -1.9666 4.8305 -0.1647 517 | 0.7047 -1.9666 4.8305 -0.1647 518 | 0.7047 -1.9666 4.8305 -0.1647 519 | [torch.FloatTensor of size 4x4], Variable containing: 520 | 10 521 | 10 522 | 10 523 | 10 524 | [torch.FloatTensor of size 4]] 525 | 526 | 527 | 528 | Note how the weight gradients are simply the sum of $x_is$: 529 | 530 | 531 | ```python 532 | y1.sum(0) 533 | ``` 534 | 535 | 536 | 537 | 538 | Variable containing: 539 | 0.7047 540 | -1.9666 541 | 4.8305 542 | -0.1647 543 | [torch.FloatTensor of size 4] 544 | 545 | 546 | 547 | ### Not detaching `y1`: 548 | 549 | In this case, backpropagation will continue beyond `y1`: 550 | 551 | ![](../assets/undetached_y1.png) 552 | 553 | Let's verify this. For sake of simplicity, let's just take a look at how bias gradients are calculated. 
554 | 555 | When gradients from `y2` are backpropagated, similar to the calculations we did in [our simple example](#1-Simple-example), we have: 556 | 557 | > $\frac{\partial Loss}{\partial bias} = 10$ 558 | 559 | When gradients from `y1` are backpropagated, for a neuron $y1_i$, we have: 560 | 561 | > $\frac{\partial Loss}{\partial y1_i} = \sum_{i=1}^4 \frac{\partial Loss}{\partial w1_i} * w1_i = \sum_{i=1}^4 w1_i$ 562 | 563 | and finally: 564 | 565 | > $\frac{\partial Loss}{\partial bias} = $\frac{\partial Loss}{\partial y1_i}$ 566 | 567 | Hence, final bias gradient is equal to: 568 | 569 | $10 + \sum_{i=1}^4 w1_i$ 570 | 571 | Let's verify this: 572 | 573 | 574 | ```python 575 | y1 = linear(y0) 576 | y2 = linear(y1) 577 | linear.zero_grad() 578 | y2.backward( torch.ones(10,4) ) 579 | ``` 580 | 581 | 582 | ```python 583 | linear.bias.grad 584 | ``` 585 | 586 | 587 | 588 | 589 | Variable containing: 590 | 4.5625 591 | 4.7037 592 | 14.1218 593 | 5.4621 594 | [torch.FloatTensor of size 4] 595 | 596 | 597 | 598 | Now let's calculate $10 + \sum_{i=1}^4 w1_i$. The result should be (and is) equal to the bias gradients shown above: 599 | 600 | 601 | ```python 602 | 10 + (linear.weight.sum(0) * 10) 603 | ``` 604 | 605 | 606 | 607 | 608 | Variable containing: 609 | 4.5625 610 | 4.7037 611 | 14.1218 612 | 5.4621 613 | [torch.FloatTensor of size 4] 614 | 615 | 616 | 617 | #### When is `detach()` used? 618 | 619 | It can useful when you don't want to backpropagate beyond a few steps in time, for example, in training language models. 620 | -------------------------------------------------------------------------------- /markdowns/basics.md: -------------------------------------------------------------------------------- 1 | 2 | ### Solve the same problem in different ways 3 | 4 | __Problem__: Find square root of `5`. 5 | 6 | __Actual solutions__: `±2.2360679775` 7 | 8 | ### 1. Start with a guess and improve the guess using gradient descent on loss 9 | 10 | 11 | ```python 12 | import torch 13 | ``` 14 | 15 | 16 | ```python 17 | guess = torch.tensor(4.0, requires_grad=True) 18 | learning_rate = 0.01 19 | ``` 20 | 21 | 22 | ```python 23 | for _ in range(20): 24 | loss = torch.pow(5 - torch.pow(guess, 2), 2) 25 | if guess.grad is not None: 26 | guess.grad.data.zero_() 27 | loss.backward() 28 | guess.data = guess.data - learning_rate * guess.grad.data 29 | print(guess.item()) 30 | ``` 31 | 32 | 2.240000009536743 33 | 2.2384231090545654 34 | 2.2374794483184814 35 | 2.2369143962860107 36 | 2.2365756034851074 37 | 2.236372470855713 38 | 2.236250638961792 39 | 2.236177682876587 40 | 2.2361338138580322 41 | 2.236107587814331 42 | 2.2360918521881104 43 | 2.2360823154449463 44 | 2.236076593399048 45 | 2.2360732555389404 46 | 2.2360711097717285 47 | 2.236069917678833 48 | 2.2360692024230957 49 | 2.2360687255859375 50 | 2.2360684871673584 51 | 2.2360682487487793 52 | 53 | 54 | ### 2. Parameterize the guess: `shift` the guess 55 | 56 | The idea here is that you do not change the input parameter. Instead, you come up with some variables which interact with the input to create a guess. The variables are then updated using gradient descent. These variables are also called `parameters`. 57 | 58 | Here we use a parameter called `shift`. 59 | 60 | __This is an import shift (no pun intended) in how you think about getting the right answers__: 61 | > __Don't update the guess. Instead update the variables that interact with the guess. 
A similar line of thought employed in reparameterization trick used in variational inference on neuron networks.__ 62 | 63 | 64 | ```python 65 | input_val = torch.tensor(4.0) 66 | shift = torch.tensor(1.0, requires_grad=True) 67 | ``` 68 | 69 | 70 | ```python 71 | for _ in range(30): 72 | guess = input_val + shift 73 | loss = torch.pow(5 - torch.pow(guess, 2), 2) 74 | if shift.grad is not None: 75 | shift.grad.data.zero_() 76 | loss.backward() 77 | shift.data = shift.data - learning_rate * shift.grad.data 78 | print(guess.item()) 79 | ``` 80 | 81 | 5.0 82 | 1.0 83 | 1.1600000858306885 84 | 1.3295643329620361 85 | 1.5014641284942627 86 | 1.6663613319396973 87 | 1.8145501613616943 88 | 1.9384772777557373 89 | 2.034804582595825 90 | 2.104766845703125 91 | 2.152751922607422 92 | 2.184238910675049 93 | 2.2042551040649414 94 | 2.216710090637207 95 | 2.2243528366088867 96 | 2.2290022373199463 97 | 2.2318150997161865 98 | 2.233511447906494 99 | 2.234532356262207 100 | 2.2351458072662354 101 | 2.2355144023895264 102 | 2.2357358932495117 103 | 2.235868453979492 104 | 2.235948324203491 105 | 2.2359962463378906 106 | 2.236024856567383 107 | 2.236042022705078 108 | 2.2360525131225586 109 | 2.2360587120056152 110 | 2.236062526702881 111 | 112 | 113 | ### 2.1 Parameterize the guess: `shift` and `scale` the guess 114 | 115 | 116 | ```python 117 | input_val = torch.tensor(4.0) 118 | shift = torch.tensor(1.0, requires_grad=True) 119 | scale = torch.tensor(1.0, requires_grad=True) 120 | ``` 121 | 122 | The learning rate of 0.01 is pretty high for this operation. Thus we reduce it (in log scale). After a couple of runs, we decide on the rate of `0.0005`. 123 | 124 | 125 | ```python 126 | learning_rate = 0.0005 127 | for _ in range(30): 128 | guess = input_val*scale + shift 129 | loss = torch.pow(5 - torch.pow(guess, 2), 2) 130 | if shift.grad is not None: 131 | shift.grad.data.zero_() 132 | if scale.grad is not None: 133 | scale.grad.data.zero_() 134 | loss.backward() 135 | shift.data = shift.data - learning_rate * shift.grad.data 136 | scale.data = scale.data - learning_rate * scale.grad.data 137 | print(guess.item()) 138 | ``` 139 | 140 | 5.0 141 | 1.5999999046325684 142 | 1.7327359914779663 143 | 1.8504221439361572 144 | 1.9495712518692017 145 | 2.0290589332580566 146 | 2.0899698734283447 147 | 2.134881019592285 148 | 2.1669845581054688 149 | 2.1893954277038574 150 | 2.204770803451538 151 | 2.2151894569396973 152 | 2.222188949584961 153 | 2.2268640995025635 154 | 2.2299740314483643 155 | 2.2320375442504883 156 | 2.2334041595458984 157 | 2.2343082427978516 158 | 2.234905958175659 159 | 2.2353007793426514 160 | 2.2355613708496094 161 | 2.2357337474823 162 | 2.235847234725952 163 | 2.235922336578369 164 | 2.2359719276428223 165 | 2.23600435256958 166 | 2.2360260486602783 167 | 2.2360403537750244 168 | 2.2360496520996094 169 | 2.236055850982666 170 | 171 | 172 | ### 2.3 Note that the loss function has two minima: with a little bit of tweaking, a relatively higher learning rate pushes it to another minimum. 173 | 174 | Another interesting thing to note is that there are two solutions to square root of 5: `2.2360` and `-2.2360`. If we increase the learning rate to `0.001`, it pushes it into another minima. The solution converges to `-2.2360680103302`. 
175 | 176 | 177 | ```python 178 | input_val = torch.tensor(4.0) 179 | shift = torch.tensor(1.0, requires_grad=True) 180 | scale = torch.tensor(1.0, requires_grad=True) 181 | 182 | learning_rate = 0.001 183 | for _ in range(30): 184 | guess = input_val*scale + shift 185 | loss = torch.pow(5 - torch.pow(guess, 2), 2) 186 | if shift.grad is not None: 187 | shift.grad.data.zero_() 188 | if scale.grad is not None: 189 | scale.grad.data.zero_() 190 | loss.backward() 191 | shift.data = shift.data - learning_rate * shift.grad.data 192 | scale.data = scale.data - learning_rate * scale.grad.data 193 | print(guess.item()) 194 | ``` 195 | 196 | 5.0 197 | -1.8000000715255737 198 | -2.0154242515563965 199 | -2.1439850330352783 200 | -2.202786445617676 201 | -2.2249152660369873 202 | -2.2324423789978027 203 | -2.2349019050598145 204 | -2.235694408416748 205 | -2.235948085784912 206 | -2.236029624938965 207 | -2.236055850982666 208 | -2.2360641956329346 209 | -2.2360668182373047 210 | -2.236067533493042 211 | -2.236067771911621 212 | -2.2360680103302 213 | -2.2360680103302 214 | -2.2360680103302 215 | -2.2360680103302 216 | -2.2360680103302 217 | -2.2360680103302 218 | -2.2360680103302 219 | -2.2360680103302 220 | -2.2360680103302 221 | -2.2360680103302 222 | -2.2360680103302 223 | -2.2360680103302 224 | -2.2360680103302 225 | -2.2360680103302 226 | 227 | 228 | ### 3. Replacing `scale` and `shift` with a neuron 229 | 230 | 231 | ```python 232 | neuron = torch.nn.Linear(1,1) 233 | input_val = torch.tensor(4.0).view(1,-1) # input to a linear layer should be in the form [batch_size, input_size] 234 | 235 | learning_rate = 0.001 236 | for _ in range(30): 237 | guess = neuron(input_val) 238 | loss = torch.pow(5 - torch.pow(guess, 2), 2) 239 | neuron.zero_grad() 240 | loss.backward() 241 | for param in neuron.parameters(): # parameters of a model can be iterated through easily using the `.parameters()` method 242 | param.data = param.data - learning_rate * param.grad.data 243 | print(guess.item()) 244 | ``` 245 | 246 | 0.2342168092727661 247 | 0.3129768371582031 248 | 0.41730424761772156 249 | 0.5542460680007935 250 | 0.7311122417449951 251 | 0.9531161785125732 252 | 1.2182986736297607 253 | 1.5095584392547607 254 | 1.7888929843902588 255 | 2.0078368186950684 256 | 2.1400814056396484 257 | 2.201209545135498 258 | 2.2243618965148926 259 | 2.232259511947632 260 | 2.234842538833618 261 | 2.23567533493042 262 | 2.2359421253204346 263 | 2.236027717590332 264 | 2.2360548973083496 265 | 2.2360639572143555 266 | 2.2360665798187256 267 | 2.236067533493042 268 | 2.236067771911621 269 | 2.2360680103302 270 | 2.2360680103302 271 | 2.2360680103302 272 | 2.2360680103302 273 | 2.2360680103302 274 | 2.2360680103302 275 | 2.2360680103302 276 | 277 | 278 | ### 5. 
Using the optimizer to do parameter updates for you 279 | 280 | We tell the optimizer two things: 281 | a) what paramters need to be updated 282 | b) what is the learning rate 283 | 284 | Then it does all the hard work for you: 285 | 286 | 287 | ```python 288 | neuron = torch.nn.Linear(1,1) 289 | optimizer = torch.optim.SGD(neuron.parameters(), lr=0.001) 290 | 291 | input_val = torch.tensor(4.0).view(1,-1) # input to a linear layer should be in the form [batch_size, input_size] 292 | 293 | learning_rate = 0.001 294 | for _ in range(30): 295 | guess = neuron(input_val) 296 | loss = torch.pow(5 - torch.pow(guess, 2), 2) 297 | neuron.zero_grad() 298 | loss.backward() 299 | optimizer.step() 300 | print(guess.item()) 301 | ``` 302 | 303 | 1.4186859130859375 304 | 1.7068755626678467 305 | 1.949059247970581 306 | 2.108257293701172 307 | 2.187859058380127 308 | 2.2195885181427 309 | 2.230670928955078 310 | 2.234327793121338 311 | 2.2355096340179443 312 | 2.235889196395874 313 | 2.236010789871216 314 | 2.2360496520996094 315 | 2.2360620498657227 316 | 2.2360661029815674 317 | 2.236067295074463 318 | 2.236067771911621 319 | 2.236067771911621 320 | 2.2360680103302 321 | 2.2360680103302 322 | 2.2360680103302 323 | 2.2360680103302 324 | 2.2360680103302 325 | 2.2360680103302 326 | 2.2360680103302 327 | 2.2360680103302 328 | 2.2360680103302 329 | 2.2360680103302 330 | 2.2360680103302 331 | 2.2360680103302 332 | 2.2360680103302 333 | 334 | 335 | ### 6. Using the loss function to calculate the loss automatically 336 | 337 | In some cases calculating the loss may not be as trivial as it is in this case. In those scenarios, we can use PyTorch's in-built loss functions. 338 | 339 | Here we will be using `MSELoss`: 340 | 341 | 342 | ```python 343 | neuron = torch.nn.Linear(1,1) 344 | optimizer = torch.optim.SGD(neuron.parameters(), lr=0.001) 345 | loss_function = torch.nn.MSELoss() 346 | 347 | input_val = torch.tensor(4.0).view(1,-1) # input to a linear layer should be in the form [batch_size, input_size] 348 | 349 | learning_rate = 0.001 350 | for _ in range(30): 351 | guess = neuron(input_val) 352 | predicted_output = torch.pow(guess, 2) 353 | actual_output = torch.tensor(5.0).view(1,-1) 354 | loss = loss_function(predicted_output, actual_output) 355 | neuron.zero_grad() 356 | loss.backward() 357 | optimizer.step() 358 | print(guess.item()) 359 | ``` 360 | 361 | 3.8363184928894043 362 | 1.3013595342636108 363 | 1.593956470489502 364 | 1.860517978668213 365 | 2.0551581382751465 366 | 2.1636502742767334 367 | 2.2105278968811035 368 | 2.2275986671447754 369 | 2.2333250045776367 370 | 2.235186815261841 371 | 2.235785722732544 372 | 2.2359776496887207 373 | 2.236039161682129 374 | 2.2360587120056152 375 | 2.236064910888672 376 | 2.236067295074463 377 | 2.236067533493042 378 | 2.236067771911621 379 | 2.2360680103302 380 | 2.2360680103302 381 | 2.2360680103302 382 | 2.2360680103302 383 | 2.2360680103302 384 | 2.2360680103302 385 | 2.2360680103302 386 | 2.2360680103302 387 | 2.2360680103302 388 | 2.2360680103302 389 | 2.2360680103302 390 | 2.2360680103302 391 | 392 | -------------------------------------------------------------------------------- /markdowns/batchnorm.md: -------------------------------------------------------------------------------- 1 | [Introduction](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#introduction) 2 | 3 | [An intuitive example of internal covariate 
shift](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#an-intuitive-example-of-internal-covariate-shift) 4 | 5 | [Solution to internal covariate shift](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#solution-to-internal-covariate-shift) 6 | 7 | [Batch Normalization may not always be optimal for learning](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-may-not-always-be-optimal-for-learning) 8 | 9 | [Batch Normalization with backpropagation](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-with-backpropagation) 10 | 11 | [Batch Normalization for 2D data](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-for-2d-data) 12 | 13 | [Batch Normalization for 3D inputs](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-for-3d-inputs) 14 | 15 | [Batch Normalization for images (or any 4D input)](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/batchnorm.md#batch-normalization-for-images-or-any-4d-input) 16 | 17 | ### Introduction 18 | 19 | Let's start with a discussion of what problem were the authors of the [original paper](https://arxiv.org/abs/1502.03167) dealing with when they came up with the idea of Batch Normalization: 20 | 21 | > [...] the distribution of each layer’s inputs changes during training, as the parameters of the previous layers change. This slows down the training by requiring lower learning rates and careful parameter initialization, and makes it notoriously hard to train models with saturating nonlinearities. We refer to this phenomenon as internal covariate shift, and address the problem by normalizing layer inputs. 22 | 23 | ### An intuitive example of internal covariate shift 24 | Let's first talk about covariate shift. A covariate shift occurs when: 25 | 26 | i) given the same observation X = x, the conditional distributions of Y in training and test sets are the same BUT 27 |
28 | ii) marginal distribution of X in training set is different from marginal distribution of X in test set. 29 | 30 | In other words: 31 |
Ptrain[Y|X=x] = Ptest[Y|X=x] BUT
Ptrain[X] ≠ Ptest[X] 32 | 33 | Now let's talk about the second layer _l_ inside a network with just two layers. The layer takes in an input and spits out an output. The output depends on the input (X) as well as the parameters of the network (θ). It can be thought of as a single layer neural network trying to learn P[Y|X,θ] where θ is the set of network parameters. Now if the distribution of θ changes during the training process, the network has to re-learn P[Y|X,θnew]. This is the internal covariate shift problem. This may not be easy or fast if the network finds itself spending more time in linear extremities of non-linear functions like sigmoid or tanh. Quoting from the paper: 34 | 35 | > Consider a network computing
36 | `l = F2(F1(u, Θ1), Θ2)`
37 | where F1 and F2 are arbitrary transformations, and the parameters Θ1, Θ2 are to be learned so as to minimize the loss l. Learning Θ2 can be viewed as if the inputs x = F1(u, Θ1) are fed into the sub-network 38 | l = F2(x, Θ2). For example, a gradient descent step
&#13;
39 | Θ2 <- Θ2 - (α/m) Σᵢ₌₁ᵐ ∂F2(xᵢ, Θ2)/∂Θ2
&#13;
40 | (for batch size m and learning rate α) is exactly equivalent to that for a stand-alone network F2 with input x. There- fore, the input distribution properties that make training more efficient – such as having the same distribution be- tween the training and test data – apply to training the sub-network as well. As such it is advantageous for the distribution of x to remain fixed over time. Then, Θ2 does not have to readjust to compensate for the change in the distribution of x. 41 | 42 | So the solution is to normalize the inputs. 43 | 44 | 45 | ### Solution to internal covariate shift 46 | 47 | If a neuron is able to see the entire range of values across the entire training set that it is going to get as subsequent inputs, training can be made faster by normalizing the distribution (i.e. making its mean zero and variance one). The authors call it _input whitening_: 48 | 49 | > It has been long known (LeCun et al., 1998b; Wiesler & Ney, 2011) that the network training converges faster if its inputs are whitened – i.e., linearly transformed to have zero means and unit variances, and decorrelated. As each layer observes the inputs produced by the layers below, it would be advantageous to achieve the same whitening of the inputs of each layer. By whitening the inputs to each layer, we would take a step towards achieving the fixed distributions of inputs that would remove the ill effects of the internal covariate shift. 50 | 51 | However, for a large training set, it is not computationally feasible to normalize inputs for each neuron. Hence, the normalization is carried out on a per-batch basis: 52 | 53 | > In the batch setting where each training step is based on the entire training set, we would use the whole set to normalize activations. However, this is impractical when using stochastic optimization. Therefore, we make the second simplification: since we use mini-batches in stochastic gradient training, each mini-batch produces estimates of the mean and variance of each activation. This way, the statistics used for normalization can fully participate in the gradient backpropagation. 54 | 55 | ### Batch Normalization may not always be optimal for learning 56 | 57 | Batch normalization may lead to, say, inputs always lying in the linear range of sigmoid function. Hence the inputs are shifted and scaled by parameters γ and β. Optimal values of these hyperparameters can be __learnt__ by using backpropagation. 58 | 59 | > Note that simply normalizing each input of a layer may change what the layer can represent. For instance, normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity. To address this, we make sure that the transformation inserted in the network can represent the identity transform. To accomplish this, we introduce, for each activation x(k) , a pair of parameters γ(k), β(k), which scale and shift the normalized value:
60 | y(k) = γ(k)xnormalized(k) + β(k). 61 | 62 | ### Batch Normalization with backpropagation 63 | 64 | If backpropagation ignores batch normalization, it may undo the modification due to batch normalization. Hence, backprop needs to incorporate batch normalization while calculating gradients. [This link](https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html) provides a nice explanation of how to derive backpropagation gradients for batch normalization. Let's see it in action now: 65 | 66 | ### Batch Normalization for 2D data 67 | 68 | Let's start with 1D normalization. 69 | ```python 70 | import numpy as np 71 | import torch 72 | import torch.nn as nn 73 | from torch.autograd import Variable 74 | ``` 75 | 76 | Let's define X and calculate its mean and variance. Note that X has size (20,100) which means it has 20 samples with each sample having 100 features (or dimensions). For normalizing, we need to look across all samples. In other words, we normalize across 20 values for each dimension (with each value coming on a sample). 77 | 78 | ```python 79 | X = torch.randn(20,100) * 5 + 10 80 | X = Variable(X) 81 | 82 | mu = torch.mean(X[:,1]) 83 | var_ = torch.var(X[:,1], unbiased=False) 84 | sigma = torch.sqrt(var_ + 1e-5) 85 | x = (X[:,1] - mu)/sigma 86 | ``` 87 | 88 | Note that in the line above we set `unbiased = False` while calculating `var_`. This is to prevent [Bessel's correction](https://en.wikipedia.org/wiki/Bessel's_correction). In other words, we want to divide by N and not N-1 while calculating the variance. `x` is the same as the result of batch normalization. Also note that in the code below we set `affine = False`. This is to prevent creation of parameters γ(k), β(k) which may scale and shift normalized data again. While training, `affine` is set to `True`. 89 | 90 | ```python 91 | B = nn.BatchNorm1d(100, affine=False) 92 | y = B(X) 93 | print(x.data / y[:,1].data) 94 | ``` 95 | 96 | Output: 97 | ``` 98 | 1.0000 99 | 1.0000 100 | 1.0000 101 | ... 102 | ``` 103 | This also works for 3D input. 104 | 105 | ### Batch Normalization for 3D inputs 106 | 107 | Let's define some 3D data with mean 4 and variance 4: 108 | ```python 109 | X3 = torch.randn(150,20,100) * 2 + 4 110 | X3 = Variable(X3) 111 | B2 = nn.BatchNorm1d(20) 112 | ``` 113 | 114 | Note that here we did not set `affine = False`. Instead, we can manually set those values to what we want. To preserve normalization, we want γ(k) = 1 and β(k) = 0. These values are stored in parameters `weight` and `bias` of the BatchNorm variable. 115 | 116 | ```python 117 | B2.weight.data.fill_(1) 118 | B2.bias.data.fill_(0) 119 | Y = B2(X3) 120 | 121 | #Manual calculation 122 | mu = X3[:,0,:].mean() 123 | sigma = torch.sqrt(torch.var(X3[:,0,:], unbiased=False) + 1e-5) 124 | X_normalized = (X3[:,0,:] - mu)/sigma 125 | ``` 126 | 127 | In the above example, `X_normalized` has the same values as `Y[:,0,:]`. 128 | 129 | ### Batch Normalization for images (or any 4D input) 130 | 131 | A batch of RGB images has four dimensions: (B,C,X,Y) or (B,X,Y,C) where B is batch number, C is channel number, and X, Y are locations. In the last example, we only had (B,N) where N was the number of dimensions so it was pretty straightforward to figure out the axis of normalization. 132 | 133 | Along which axis do we normalize here? Normalizing all values across all batch samples is not the solution here. Why? 
Because each batch sample contains many channels, and each channel is a different feature with its own statistics; normalizing everything together would mix different features. 134 | 135 | __Hint__: We want to choose an axis that enables a neuron to look through all samples in the batch. 136 | 137 | Batch normalization in images is done along the channel axis. 138 | 139 | ```python 140 | X = torch.randn(5,25,100,100) * 2 + 4 141 | X = Variable(X) 142 | B = nn.BatchNorm2d(25, affine=False) 143 | Y = B(X) 144 | ``` 145 | 146 | Here `Y[:,i,:,:]` is the same as 147 | `((X[:,i,:,:] - X[:,i,:,:].data.mean())/((X[:,i,:,:].data.var(unbiased=False) + 1e-5)**0.5))` for all valid values of `i`. 148 | -------------------------------------------------------------------------------- /markdowns/dropout.md: -------------------------------------------------------------------------------- 1 | [1. Theory, motivation and food for thought](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#1-theory-motivation-and-food-for-thought)
&#13;
[2. Dropout in action](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#2-dropout-in-action)
2 | * [Importing items we need](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#importing-items-we-need)
3 | * [Create a dropout layer with p = 0.6](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#create-a-dropout-layer-with-p--06)
4 | * [Create an input with every element equal to 1](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#create-an-input-with-every-element-equal-to-1)
5 | * [Let's see what the output looks like](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#lets-see-what-the-output-looks-like)
6 | 7 | [2.1 Why?](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#21-why)
8 | * [Rule for scaling the input](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#rule-for-scaling-the-input)
9 | * [Dropout + Linear Layer](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#dropout--linear-layer)
10 | * [Create input, dropout and linear layer](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#create-input-dropout-and-linear-layer)
11 | * [Let's look at the output of passing the input through dropout layer:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#lets-look-at-the-output-of-passing-the-input-through-dropout-layer)
12 | 13 | [2.2 Randomness of sampling subnetworks](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#22-randomness-of-sampling-subnetworks)
14 | 15 | [2.3 Dropout + Backpropagation](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#23-dropout--backpropagation)
16 | * [Calculate weight gradients for each training case individually and store in a variable. We use `p=0.9` here.](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#calculate-weight-gradients-for-each-training-case-individually-and-store-in-a-variable-we-use-p09-here)
17 | * [Calculate gradients per training case](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#calculate-gradients-per-training-case)
18 | * [Let's take a look at the average of `dWs`:](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#lets-take-a-look-at-the-average-of-dws)
19 | * [Now let's take a look at the gradient of the whole batch. It should be the same as `dWs/n` above.](https://github.com/vinsis/understanding-neuralnetworks-pytorch/blob/master/dropout.md#now-lets-take-a-look-at-the-gradient-of-the-whole-batch-it-should-be-the-same-as-dwsn-above) 20 | 21 | ### 1. Theory, motivation and food for thought 22 | 23 | The [original paper](http://www.jmlr.org/papers/volume15/srivastava14a.old/srivastava14a.pdf) on Dropout is fairly easy to understand and is one of the more interesting research papers I have read. Before I jump into the what and how of Dropout, I think it's important to appreciate _why_ it was needed and what motivated the people involved to come up with this idea. 24 | 25 | The below snippet from the original paper gives a good idea: 26 | 27 | > Combining several models is most 28 | helpful when the individual models are different from each other and in order to make 29 | neural net models different, they should either have different architectures or be trained 30 | on different data. Training many different architectures is hard because finding optimal 31 | hyperparameters for each architecture is a daunting task and training each large network 32 | requires a lot of computation. Moreover, large networks normally require large amounts of 33 | training data and there may not be enough data available to train different networks on 34 | different subsets of the data. Even if one was able to train many different large networks, 35 | using them all at test time is infeasible in applications where it is important to respond 36 | quickly. 37 | 38 | So what is Dropout? 39 | 40 | > The term “dropout” refers to dropping out units (hidden and 41 | visible) in a neural network. By dropping a unit out, we mean temporarily removing it from 42 | the network, along with all its incoming and outgoing connections, as shown in Figure 1. 43 | The choice of which units to drop is random. 44 | 45 | __Caveat__: The part `By dropping a unit out, we mean temporarily removing it from 46 | the network, along with all its incoming and outgoing connections` can be a bit misleading. While theoretically correct, it gives an impression that the network architecture changes randomly over time. But the implementation of Dropout does not change the network architecture every iteration. Instead, the output of each node chosen to be dropped out is set to zero regardless of what the input was. This also results in back propagation gradients of weights attached to that node becoming zero (we will soon see this in action). 47 | 48 | Before we jump into the code, I would recommend read you read Section 2 of the original paper gives possible explanations for why Dropout works so well. Here's a snippet: 49 | 50 | > One possible explanation for the superiority of sexual reproduction is that, over the long 51 | term, the criterion for natural selection may not be individual fitness but rather mix-ability 52 | of genes. The ability of a set of genes to be able to work well with another random set of 53 | genes makes them more robust. Since a gene cannot rely on a large set of partners to be 54 | present at all times, it must learn to do something useful on its own or in collaboration with 55 | a 56 | small 57 | number of other genes. 
According to this theory, the role of sexual reproduction 58 | is not just to allow useful new genes to spread throughout the population, but also to 59 | facilitate this process by reducing complex co-adaptations that would reduce the chance of 60 | a new gene improving the fitness of an individual. Similarly, each hidden unit in a neural 61 | network trained with dropout must learn to work with a randomly chosen sample of other 62 | units. This should make each hidden unit more robust and drive it towards creating useful 63 | features on its own without relying on other hidden units to correct its mistakes. However, 64 | the hidden units within a layer will still learn to do different things from each other. 65 | 66 | 67 | ### 2. Dropout in action 68 | 69 | Think of Dropout as a layer. The output of the layer has the same size (or dimensions) as the input. However, a fraction `p` of the elements in the output is set to zero, while the remaining fraction `1-p` of elements comes from the corresponding input values (though not unchanged, as we will see in a while). We can choose any value of `p` between 0 and 1. 70 | 71 | #### Importing items we need 72 | ```python 73 | import torch 74 | import torch.nn as nn 75 | from torch.autograd import Variable 76 | ``` 77 | 78 | #### Create a dropout layer with p = 0.6 79 | ```python 80 | p = 0.6 81 | do = nn.Dropout(p) 82 | ``` 83 | 84 | #### Create an input with every element equal to 1 85 | ```python 86 | X = torch.ones(5,2) 87 | print(X) 88 | ``` 89 | Output: 90 | ``` 91 | 1 1 92 | 1 1 93 | 1 1 94 | 1 1 95 | 1 1 96 | [torch.FloatTensor of size 5x2] 97 | ``` 98 | 99 | #### Let's see what the output looks like 100 | ```python 101 | do(X) 102 | ``` 103 | Output: 104 | ``` 105 | Variable containing: 106 | 2.5000 2.5000 107 | 0.0000 2.5000 108 | 2.5000 0.0000 109 | 0.0000 0.0000 110 | 2.5000 2.5000 111 | [torch.FloatTensor of size 5x2] 112 | ``` 113 | 114 | ### 2.1 Why? 115 | We see that some of the nodes (__approximately__ a fraction 0.6) have been set to zero in the output. But the remaining ones are not equal to 1. 116 | 117 | Why? Because we scaled the input values linearly. Why? Because we want the expected output of each node to be the same with and without dropout. 118 | 119 | Why? Because during testing or evaluation dropout is not used. 120 | 121 | Why? Because the overall learned network can be thought of as a _combination of several models_, where each _model_ is a subnetwork sampled by the dropout layer during training. Using dropout during testing or evaluation means we are still using a subset of several possible models and not a combination of them, which is suboptimal. 122 | 123 | #### Rule for scaling the input 124 | Let's consider a node `n` with value `x` that is subjected to a dropout layer 100 times. Its output will be `0` approximately `p*100` times and `x_out` approximately `(1-p)*100` times. Its expected output is: 125 | 126 | `p*0 + (1-p)*x_out`, which is equal to `(1-p)*x_out` 127 | 128 | We want this to be equal to `x`. Hence `x_out = x/(1-p)`. 129 | 130 | In the example above, `p = 0.6`. Hence, the surviving outputs are `1/(1-0.6) = 2.5`.
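To sanity-check the scaling rule, here is a minimal sketch (the masks are random, so the exact numbers will fluctuate slightly from run to run): averaging the dropout output over many independent passes should recover the original input value of `1` in expectation.

```python
import torch
import torch.nn as nn

p = 0.6
do = nn.Dropout(p)
X = torch.ones(5, 2)

# Each element survives with probability 1-p and is scaled by 1/(1-p),
# so the expected value of every output element is p*0 + (1-p)*1/(1-p) = 1.
trials = 10000
mean_out = torch.stack([do(X) for _ in range(trials)]).mean()
print(mean_out)  # close to 1.0

# In evaluation mode dropout is a no-op: the output equals the input exactly.
do.eval()
print(do(X))  # all ones, no zeroing and no scaling
```

Calling `do.eval()` switches the module to evaluation mode, which is why no zeroing or scaling happens at test time.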
131 | 132 | #### Dropout + Linear Layer 133 | Let's create a super small network that looks like this: 134 | 135 | ``` 136 | Input Size: (5X10) 137 | ↓ 138 | Dropout p = 0.9 139 | ↓ 140 | Linear Layer Size: (10X5) 141 | ↓ 142 | Output (5X5) 143 | ``` 144 | 145 | #### Create input, dropout and linear layer 146 | ```python 147 | #input 148 | inputs = torch.ones(5,10) 149 | inputs = Variable(inputs) 150 | 151 | #dropout 152 | p = 0.9 153 | do = nn.Dropout(p) 154 | 155 | #linear 156 | fc = nn.Linear(10,5) 157 | ``` 158 | 159 | #### Let's look at the output of passing the input through dropout layer: 160 | 161 | ```python 162 | out = do(inputs) 163 | print(out) 164 | ``` 165 | Output: 166 | ``` 167 | Variable containing: 168 | 0 0 0 0 0 0 0 0 0 0 169 | 0 0 0 0 0 0 0 0 0 0 170 | 0 0 0 0 0 10 0 0 0 0 171 | 0 0 0 10 0 0 0 0 0 0 172 | 0 0 0 0 0 0 0 0 0 0 173 | [torch.FloatTensor of size 5x10] 174 | ``` 175 | 176 | Since `p = 0.9` about 90% of nodes have been set to 0 in the output. The remaining ones have been scaled by `1/(1-0.9) = 10`. 177 | 178 | If you run the above code again and again, the output will be different each time. The output I got by running the above code again is: 179 | 180 | ``` 181 | Variable containing: 182 | 0 10 0 0 0 0 0 0 0 0 183 | 0 0 0 0 10 0 0 0 0 0 184 | 0 0 0 0 0 0 0 0 0 0 185 | 0 0 0 0 0 0 0 0 0 0 186 | 0 0 0 0 0 0 0 0 0 0 187 | [torch.FloatTensor of size 5x10] 188 | ``` 189 | 190 | ### 2.2 Randomness of sampling subnetworks 191 | 192 | Thus, in a way we are sampling random subsets of network by randomly shutting some of the nodes off. But _how random_ are they actually? We see most of the rows tend to be the same (full of zeros). There's not much of randomness going on there. 193 | 194 | Let's try setting `p=0.1` and try again: 195 | 196 | ```python 197 | p = 0.1 198 | do = nn.Dropout(p) 199 | out = do(inputs) 200 | print(out) 201 | ``` 202 | Output: 203 | ``` 204 | Variable containing: 205 | 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 0.0000 1.1111 1.1111 206 | 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 0.0000 1.1111 207 | 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 208 | 1.1111 1.1111 1.1111 1.1111 0.0000 1.1111 1.1111 1.1111 1.1111 1.1111 209 | 1.1111 0.0000 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 1.1111 210 | [torch.FloatTensor of size 5x10] 211 | ``` 212 | 213 | Run the above code a few times and you will see there's not much of randomness going on. In this case, most rows tend to be full of the same value: `1.1111`. 214 | 215 | How do we get maximum randomness in our subnetworks? By setting `p=0.5`. If you are familiar with information theory, we are maximizing entropy here. 216 | 217 | ```python 218 | p = 0.5 219 | do = nn.Dropout(p) 220 | out = do(inputs) 221 | print(out) 222 | ``` 223 | Output: 224 | ``` 225 | Variable containing: 226 | 2 0 2 2 2 2 0 2 2 2 227 | 2 2 2 2 0 2 0 2 2 2 228 | 0 0 2 2 2 2 2 0 2 2 229 | 0 2 2 2 2 2 2 0 0 0 230 | 0 2 2 0 2 2 2 0 0 0 231 | [torch.FloatTensor of size 5x10] 232 | ``` 233 | 234 | It is a standard practice to use dropout with `p=0.5` while training a neural network. 235 | 236 | ### 2.3 Dropout + Backpropagation 237 | 238 | Training a network with dropout layer is pretty straightforward. Quoting from the original paper again: 239 | 240 | > Dropout neural networks can be trained using stochastic gradient descent in a manner simi- lar to standard neural nets. 
The only difference is that for each training case in a mini-batch, we sample a thinned network by dropping out units. Forward and backpropagation for that training case are done only on this thinned network. The gradients for each parameter are averaged over the training cases in each mini-batch. Any training case which does not use a parameter contributes a gradient of zero for that parameter. 241 | 242 | Here we put to use the linear layer we created in the last section. 243 | 244 | First, let's take a look at the weight and bias parameters of our linear layer: 245 | ```python 246 | print(fc.weight) 247 | print(fc.bias) 248 | ``` 249 | Output: 250 | ``` 251 | Parameter containing: 252 | 0.0738 0.3089 -0.0683 -0.2662 0.2217 -0.2005 0.0918 -0.0003 -0.0241 -0.0386 253 | -0.0407 -0.1743 0.0462 0.1506 0.2435 -0.2997 -0.1420 -0.1419 -0.1131 -0.2221 254 | -0.0406 0.2344 0.1432 -0.2777 -0.1128 0.0976 -0.1798 -0.0479 0.2498 -0.0814 255 | -0.0947 0.2826 -0.0856 0.2716 -0.1775 0.2035 -0.3161 -0.2716 0.0440 -0.1010 256 | 0.0102 -0.0008 -0.0904 0.2708 -0.0478 -0.1248 -0.0073 0.2026 -0.2273 -0.0355 257 | [torch.FloatTensor of size 5x10] 258 | 259 | Parameter containing: 260 | 0.1163 261 | -0.2310 262 | -0.0791 263 | 0.1346 264 | -0.2195 265 | [torch.FloatTensor of size 5] 266 | ``` 267 | 268 | We will create a dummy target and use it to calculate cross entropy loss. We will then use the loss value to backpropagate and take a look at how the gradients look. 269 | ```python 270 | target = torch.LongTensor([0,4,2,1,4]) 271 | target = Variable(target) 272 | print(target) 273 | ``` 274 | Output: 275 | ``` 276 | Variable containing: 277 | 0 278 | 4 279 | 2 280 | 1 281 | 4 282 | [torch.LongTensor of size 5] 283 | ``` 284 | 285 | #### Calculate weight gradients for each training case individually and store in a variable. We use `p=0.9` here. 286 | ```python 287 | do = nn.Dropout(0.9) 288 | out = do(inputs) 289 | print(out) 290 | ``` 291 | Output: 292 | ``` 293 | Variable containing: 294 | 0 10 0 0 0 0 10 0 10 0 295 | 0 0 10 0 0 0 0 0 0 0 296 | 0 10 0 0 0 0 0 10 0 0 297 | 0 0 0 0 0 0 0 0 0 0 298 | 0 0 0 0 0 0 0 0 0 0 299 | [torch.FloatTensor of size 5x10] 300 | ``` 301 | 302 | #### Calculate gradients per training case 303 | We initialize an zero tensor `dWs` and keep adding gradients from each training case to it. 304 | ```python 305 | dWs = torch.zeros_like(fc.weight) 306 | for i in range(out.size(0)): 307 | i_ = out[i].view(1,-1) 308 | t = target[i] 309 | o = fc(i_) 310 | fc.weight.grad.zero_() 311 | loss = nn.CrossEntropyLoss()(o, t) 312 | loss.backward() 313 | dWs += fc.weight.grad 314 | ``` 315 | 316 | #### Let's take a look at the average of `dWs`: 317 | ```python 318 | n = dWs.size(0) 319 | print(dWs/n) 320 | ``` 321 | Output: 322 | ``` 323 | Variable containing: 324 | 0.0000 0.6987 0.1744 0.0000 0.0000 0.0000 -0.5988 1.2975 -0.5988 0.0000 325 | 0.0000 0.0021 0.3873 0.0000 0.0000 0.0000 0.0003 0.0018 0.0003 0.0000 326 | 0.0000 -1.1258 1.1889 0.0000 0.0000 0.0000 0.5595 -1.6853 0.5595 0.0000 327 | 0.0000 0.1041 0.1495 0.0000 0.0000 0.0000 0.0367 0.0674 0.0367 0.0000 328 | 0.0000 0.3209 -1.9001 0.0000 0.0000 0.0000 0.0022 0.3187 0.0022 0.0000 329 | [torch.FloatTensor of size 5x10] 330 | ``` 331 | 332 | #### Now let's take a look at the gradient of the whole batch. It should be the same as `dWs/n` above. 
333 | 334 | ```python 335 | out = fc(out) 336 | criterion = nn.CrossEntropyLoss() 337 | loss = criterion(out, target) 338 | fc.weight.grad.zero_() 339 | loss.backward() 340 | print(fc.weight.grad) 341 | ``` 342 | Output: 343 | ``` 344 | Variable containing: 345 | 0.0000 0.6987 0.1744 0.0000 0.0000 0.0000 -0.5988 1.2975 -0.5988 0.0000 346 | 0.0000 0.0021 0.3873 0.0000 0.0000 0.0000 0.0003 0.0018 0.0003 0.0000 347 | 0.0000 -1.1258 1.1889 0.0000 0.0000 0.0000 0.5595 -1.6853 0.5595 0.0000 348 | 0.0000 0.1041 0.1495 0.0000 0.0000 0.0000 0.0367 0.0674 0.0367 0.0000 349 | 0.0000 0.3209 -1.9001 0.0000 0.0000 0.0000 0.0022 0.3187 0.0022 0.0000 350 | [torch.FloatTensor of size 5x10] 351 | ``` 352 | 353 | We get the same output as `dWs/n`. 354 | -------------------------------------------------------------------------------- /markdowns/einstein_summation.md: -------------------------------------------------------------------------------- 1 | 2 | ### Einstein summation 3 | 4 | Einstein summation is a notation to multiply and/or add tensors which can be convenient and intuitive. The posts below do a pretty good job of explaining how it works. This notebook is more of a practice session to get used to applying einstein summation notation. 5 | 6 | - [Einstein Summation in Numpy](https://obilaniu6266h16.wordpress.com/2016/02/04/einstein-summation-in-numpy/) 7 | - [A basic introduction to NumPy's einsum](https://ajcr.net/Basic-guide-to-einsum/) 8 | 9 | The examples here have been taken from the tables in the second link above. 10 | 11 | 12 | ```python 13 | import numpy as np 14 | import torch 15 | ``` 16 | 17 | 18 | ```python 19 | v = np.arange(6) 20 | # array([0, 1, 2, 3, 4, 5]) 21 | 22 | w = np.arange(6) + 6 23 | # array([ 6, 7, 8, 9, 10, 11]) 24 | 25 | A = np.arange(6).reshape(2,3) 26 | # array([[0, 1, 2], 27 | # [3, 4, 5]]) 28 | 29 | B = (np.arange(6) + 6).reshape(3,2) 30 | # array([[ 6, 7], 31 | # [ 8, 9], 32 | # [10, 11]]) 33 | 34 | C = np.arange(9).reshape(3,3) 35 | # array([[0, 1, 2], 36 | # [3, 4, 5], 37 | # [6, 7, 8]]) 38 | 39 | 40 | v_torch = torch.from_numpy(v) 41 | w_torch = torch.from_numpy(w) 42 | A_torch = torch.from_numpy(A) 43 | B_torch = torch.from_numpy(B) 44 | C_torch = torch.from_numpy(C) 45 | ``` 46 | 47 | ## 1. 
Vectors 48 | 49 | ### 1.1 Return a view of itself 50 | 51 | 52 | ```python 53 | np.einsum('i', v) 54 | ``` 55 | 56 | 57 | 58 | 59 | array([0, 1, 2, 3, 4, 5]) 60 | 61 | 62 | 63 | 64 | ```python 65 | torch.einsum('i', v_torch) 66 | ``` 67 | 68 | 69 | 70 | 71 | tensor([0, 1, 2, 3, 4, 5]) 72 | 73 | 74 | 75 | ### 1.2 Sum up elements of a vector 76 | 77 | 78 | ```python 79 | np.einsum('i->', v) 80 | ``` 81 | 82 | 83 | 84 | 85 | 15 86 | 87 | 88 | 89 | 90 | ```python 91 | torch.einsum('i->', v_torch) 92 | ``` 93 | 94 | 95 | 96 | 97 | tensor(15) 98 | 99 | 100 | 101 | 102 | ```python 103 | v.sum(), v_torch.sum() 104 | ``` 105 | 106 | 107 | 108 | 109 | (15, tensor(15)) 110 | 111 | 112 | 113 | ### 1.3 Element-wise operations 114 | 115 | 116 | ```python 117 | np.einsum('i,i->i', v, w) 118 | ``` 119 | 120 | 121 | 122 | 123 | array([ 0, 7, 16, 27, 40, 55]) 124 | 125 | 126 | 127 | 128 | ```python 129 | v * w 130 | ``` 131 | 132 | 133 | 134 | 135 | array([ 0, 7, 16, 27, 40, 55]) 136 | 137 | 138 | 139 | 140 | ```python 141 | torch.einsum('i,i->i', v_torch, w_torch) 142 | ``` 143 | 144 | 145 | 146 | 147 | tensor([ 0, 7, 16, 27, 40, 55]) 148 | 149 | 150 | 151 | #### 1.3.1 One can even multiply (element-wise) three or more vectors in the same manner 152 | 153 | 154 | ```python 155 | np.einsum('i,i,i->i', v, w, w) 156 | ``` 157 | 158 | 159 | 160 | 161 | array([ 0, 49, 128, 243, 400, 605]) 162 | 163 | 164 | 165 | 166 | ```python 167 | v*w*w 168 | ``` 169 | 170 | 171 | 172 | 173 | array([ 0, 49, 128, 243, 400, 605]) 174 | 175 | 176 | 177 | 178 | ```python 179 | torch.einsum('i,i,i->i', v_torch, w_torch, w_torch) 180 | ``` 181 | 182 | 183 | 184 | 185 | tensor([ 0, 49, 128, 243, 400, 605]) 186 | 187 | 188 | 189 | ### 1.4 Inner product 190 | 191 | 192 | ```python 193 | np.einsum('i,i->', v,w) 194 | ``` 195 | 196 | 197 | 198 | 199 | 145 200 | 201 | 202 | 203 | 204 | ```python 205 | np.dot(v,w) 206 | ``` 207 | 208 | 209 | 210 | 211 | 145 212 | 213 | 214 | 215 | 216 | ```python 217 | torch.einsum('i,i->',v_torch,w_torch) 218 | ``` 219 | 220 | 221 | 222 | 223 | tensor(145) 224 | 225 | 226 | 227 | ### 1.5 Outer product 228 | 229 | 230 | ```python 231 | np.einsum('i,j->ij',v,w) 232 | ``` 233 | 234 | 235 | 236 | 237 | array([[ 0, 0, 0, 0, 0, 0], 238 | [ 6, 7, 8, 9, 10, 11], 239 | [12, 14, 16, 18, 20, 22], 240 | [18, 21, 24, 27, 30, 33], 241 | [24, 28, 32, 36, 40, 44], 242 | [30, 35, 40, 45, 50, 55]]) 243 | 244 | 245 | 246 | 247 | ```python 248 | np.outer(v,w) 249 | ``` 250 | 251 | 252 | 253 | 254 | array([[ 0, 0, 0, 0, 0, 0], 255 | [ 6, 7, 8, 9, 10, 11], 256 | [12, 14, 16, 18, 20, 22], 257 | [18, 21, 24, 27, 30, 33], 258 | [24, 28, 32, 36, 40, 44], 259 | [30, 35, 40, 45, 50, 55]]) 260 | 261 | 262 | 263 | 264 | ```python 265 | torch.einsum('i,j->ij',v_torch, w_torch) 266 | ``` 267 | 268 | 269 | 270 | 271 | tensor([[ 0, 0, 0, 0, 0, 0], 272 | [ 6, 7, 8, 9, 10, 11], 273 | [12, 14, 16, 18, 20, 22], 274 | [18, 21, 24, 27, 30, 33], 275 | [24, 28, 32, 36, 40, 44], 276 | [30, 35, 40, 45, 50, 55]]) 277 | 278 | 279 | 280 | #### 1.5.1 Transpose is just as easy 281 | 282 | 283 | ```python 284 | np.einsum('i,j->ji',v,w) 285 | ``` 286 | 287 | 288 | 289 | 290 | array([[ 0, 6, 12, 18, 24, 30], 291 | [ 0, 7, 14, 21, 28, 35], 292 | [ 0, 8, 16, 24, 32, 40], 293 | [ 0, 9, 18, 27, 36, 45], 294 | [ 0, 10, 20, 30, 40, 50], 295 | [ 0, 11, 22, 33, 44, 55]]) 296 | 297 | 298 | 299 | ## 2. 
Matrices 300 | 301 | ### 2.1 Do nothing 302 | 303 | 304 | ```python 305 | np.einsum('ij',B) 306 | ``` 307 | 308 | 309 | 310 | 311 | array([[ 6, 7], 312 | [ 8, 9], 313 | [10, 11]]) 314 | 315 | 316 | 317 | 318 | ```python 319 | torch.einsum('ij', B_torch) 320 | ``` 321 | 322 | 323 | 324 | 325 | tensor([[ 6, 7], 326 | [ 8, 9], 327 | [10, 11]]) 328 | 329 | 330 | 331 | ### 2.2 Transpose 332 | 333 | 334 | ```python 335 | A 336 | ``` 337 | 338 | 339 | 340 | 341 | array([[0, 1, 2], 342 | [3, 4, 5]]) 343 | 344 | 345 | 346 | 347 | ```python 348 | np.einsum('ij->ji', A) 349 | ``` 350 | 351 | 352 | 353 | 354 | array([[0, 3], 355 | [1, 4], 356 | [2, 5]]) 357 | 358 | 359 | 360 | 361 | ```python 362 | torch.einsum('ij->ji', A_torch) 363 | ``` 364 | 365 | 366 | 367 | 368 | tensor([[0, 3], 369 | [1, 4], 370 | [2, 5]]) 371 | 372 | 373 | 374 | 375 | ```python 376 | print(A.T) 377 | print('---') 378 | print(A_torch.t()) 379 | ``` 380 | 381 | [[0 3] 382 | [1 4] 383 | [2 5]] 384 | --- 385 | tensor([[0, 3], 386 | [1, 4], 387 | [2, 5]]) 388 | 389 | 390 | ### 2.3 Diagonal of a square matrix 391 | 392 | 393 | ```python 394 | C 395 | ``` 396 | 397 | 398 | 399 | 400 | array([[0, 1, 2], 401 | [3, 4, 5], 402 | [6, 7, 8]]) 403 | 404 | 405 | 406 | 407 | ```python 408 | np.einsum('ii->i',C) 409 | ``` 410 | 411 | 412 | 413 | 414 | array([0, 4, 8]) 415 | 416 | 417 | 418 | 419 | ```python 420 | torch.einsum('ii->i', C_torch) 421 | ``` 422 | 423 | 424 | 425 | 426 | tensor([0, 4, 8]) 427 | 428 | 429 | 430 | 431 | ```python 432 | np.diag(C), torch.diag(C_torch) 433 | ``` 434 | 435 | 436 | 437 | 438 | (array([0, 4, 8]), tensor([0, 4, 8])) 439 | 440 | 441 | 442 | ### 2.4 Trace of a matrix 443 | 444 | 445 | ```python 446 | np.einsum('ii->',C) 447 | ``` 448 | 449 | 450 | 451 | 452 | 12 453 | 454 | 455 | 456 | 457 | ```python 458 | torch.einsum('ii->', C_torch) 459 | ``` 460 | 461 | 462 | 463 | 464 | tensor(12) 465 | 466 | 467 | 468 | 469 | ```python 470 | np.trace(C), torch.trace(C_torch) 471 | ``` 472 | 473 | 474 | 475 | 476 | (12, tensor(12)) 477 | 478 | 479 | 480 | ### 2.5 Sum of matrix 481 | 482 | 483 | ```python 484 | np.einsum('ij->',A) 485 | ``` 486 | 487 | 488 | 489 | 490 | 15 491 | 492 | 493 | 494 | 495 | ```python 496 | torch.einsum('ij->',A_torch) 497 | ``` 498 | 499 | 500 | 501 | 502 | tensor(15) 503 | 504 | 505 | 506 | 507 | ```python 508 | A.sum(), A_torch.sum() 509 | ``` 510 | 511 | 512 | 513 | 514 | (15, tensor(15)) 515 | 516 | 517 | 518 | ### 2.6 Sum of matrix along axes 519 | 520 | 521 | ```python 522 | np.einsum('ij->j',B) 523 | ``` 524 | 525 | 526 | 527 | 528 | array([24, 27]) 529 | 530 | 531 | 532 | 533 | ```python 534 | B.sum(0) 535 | ``` 536 | 537 | 538 | 539 | 540 | array([24, 27]) 541 | 542 | 543 | 544 | 545 | ```python 546 | torch.einsum('ij->j', B_torch) 547 | ``` 548 | 549 | 550 | 551 | 552 | tensor([24, 27]) 553 | 554 | 555 | 556 | 557 | ```python 558 | torch.einsum('ij->i', B_torch) 559 | ``` 560 | 561 | 562 | 563 | 564 | tensor([13, 17, 21]) 565 | 566 | 567 | 568 | 569 | ```python 570 | B_torch.sum(1) 571 | ``` 572 | 573 | 574 | 575 | 576 | tensor([13, 17, 21]) 577 | 578 | 579 | 580 | ### 2.7 Element-wise multiplication 581 | 582 | 583 | ```python 584 | np.einsum('ij,ij->ij',A,B.T) 585 | ``` 586 | 587 | 588 | 589 | 590 | array([[ 0, 8, 20], 591 | [21, 36, 55]]) 592 | 593 | 594 | 595 | 596 | ```python 597 | A * B.T 598 | ``` 599 | 600 | 601 | 602 | 603 | array([[ 0, 8, 20], 604 | [21, 36, 55]]) 605 | 606 | 607 | 608 | #### 2.7.1 Element-wise multiplication can be done in various ways just by permuting 
the indices 609 | 610 | 611 | ```python 612 | torch.einsum('ij,ji->ij', C_torch.t(), C_torch) 613 | ``` 614 | 615 | 616 | 617 | 618 | tensor([[ 0, 9, 36], 619 | [ 1, 16, 49], 620 | [ 4, 25, 64]]) 621 | 622 | 623 | 624 | 625 | ```python 626 | C_torch.t() * C_torch.t() 627 | ``` 628 | 629 | 630 | 631 | 632 | tensor([[ 0, 9, 36], 633 | [ 1, 16, 49], 634 | [ 4, 25, 64]]) 635 | 636 | 637 | 638 | #### 2.7.2 The below two operations are also the same 639 | 640 | 641 | ```python 642 | torch.einsum('ij,ji->ij', C_torch, C_torch) 643 | ``` 644 | 645 | 646 | 647 | 648 | tensor([[ 0, 3, 12], 649 | [ 3, 16, 35], 650 | [12, 35, 64]]) 651 | 652 | 653 | 654 | 655 | ```python 656 | C_torch * C_torch.t() 657 | ``` 658 | 659 | 660 | 661 | 662 | tensor([[ 0, 3, 12], 663 | [ 3, 16, 35], 664 | [12, 35, 64]]) 665 | 666 | 667 | 668 | ### 2.8 Matrix multiplication 669 | 670 | 671 | ```python 672 | np.einsum('ij,jk->ik',A,B) 673 | ``` 674 | 675 | 676 | 677 | 678 | array([[ 28, 31], 679 | [100, 112]]) 680 | 681 | 682 | 683 | 684 | ```python 685 | np.einsum('ij,jk',A,B) 686 | ``` 687 | 688 | 689 | 690 | 691 | array([[ 28, 31], 692 | [100, 112]]) 693 | 694 | 695 | 696 | 697 | ```python 698 | np.dot(A,B) 699 | ``` 700 | 701 | 702 | 703 | 704 | array([[ 28, 31], 705 | [100, 112]]) 706 | 707 | 708 | 709 | 710 | ```python 711 | torch.einsum('ij,jk->ik', A_torch, B_torch) 712 | ``` 713 | 714 | 715 | 716 | 717 | tensor([[ 28, 31], 718 | [100, 112]]) 719 | 720 | 721 | 722 | 723 | ```python 724 | torch.einsum('ij,jk', A_torch, B_torch) 725 | ``` 726 | 727 | 728 | 729 | 730 | tensor([[ 28, 31], 731 | [100, 112]]) 732 | 733 | 734 | 735 | ### 2.9 Inner product of two matrices 736 | 737 | 738 | ```python 739 | A.shape, B.shape, C.shape 740 | ``` 741 | 742 | 743 | 744 | 745 | ((2, 3), (3, 2), (3, 3)) 746 | 747 | 748 | 749 | 750 | ```python 751 | print(A) 752 | print('---') 753 | print(C) 754 | ``` 755 | 756 | [[0 1 2] 757 | [3 4 5]] 758 | --- 759 | [[0 1 2] 760 | [3 4 5] 761 | [6 7 8]] 762 | 763 | 764 | 765 | ```python 766 | np.einsum('ij,kj->ik', A, C) 767 | ``` 768 | 769 | 770 | 771 | 772 | array([[ 5, 14, 23], 773 | [14, 50, 86]]) 774 | 775 | 776 | 777 | 778 | ```python 779 | torch.einsum('ij,kj->ik', A_torch, C_torch) 780 | ``` 781 | 782 | 783 | 784 | 785 | tensor([[ 5, 14, 23], 786 | [14, 50, 86]]) 787 | 788 | 789 | 790 | 791 | ```python 792 | i,j = A.shape 793 | k,j2 = C.shape 794 | assert j == j2 795 | 796 | result = np.empty((i,k)) 797 | 798 | for index_i in range(i): 799 | for index_k in range(k): 800 | total = 0 801 | for index_j in range(j): 802 | total += A[index_i, index_j] * C[index_k, index_j] 803 | result[index_i, index_k] = total 804 | 805 | result 806 | ``` 807 | 808 | 809 | 810 | 811 | array([[ 5., 14., 23.], 812 | [14., 50., 86.]]) 813 | 814 | 815 | 816 | ## 3. 
Higher-order tensors 817 | 818 | ### 3.1 Each row of A multiplied (element-wise) by C 819 | 820 | 821 | ```python 822 | np.einsum('ij,kj->ikj', A, C) 823 | ``` 824 | 825 | 826 | 827 | 828 | array([[[ 0, 1, 4], 829 | [ 0, 4, 10], 830 | [ 0, 7, 16]], 831 | 832 | [[ 0, 4, 10], 833 | [ 9, 16, 25], 834 | [18, 28, 40]]]) 835 | 836 | 837 | 838 | 839 | ```python 840 | A 841 | ``` 842 | 843 | 844 | 845 | 846 | array([[0, 1, 2], 847 | [3, 4, 5]]) 848 | 849 | 850 | 851 | 852 | ```python 853 | C 854 | ``` 855 | 856 | 857 | 858 | 859 | array([[0, 1, 2], 860 | [3, 4, 5], 861 | [6, 7, 8]]) 862 | 863 | 864 | 865 | 866 | ```python 867 | torch.einsum('ij,kj->ikj', A_torch, C_torch) 868 | ``` 869 | 870 | 871 | 872 | 873 | tensor([[[ 0, 1, 4], 874 | [ 0, 4, 10], 875 | [ 0, 7, 16]], 876 | 877 | [[ 0, 4, 10], 878 | [ 9, 16, 25], 879 | [18, 28, 40]]]) 880 | 881 | 882 | 883 | 884 | ```python 885 | i,j = A.shape 886 | k,j2 = C.shape 887 | assert j == j2 888 | 889 | result = np.empty((i,k,j)) 890 | 891 | for index_i in range(i): 892 | for index_k in range(k): 893 | for index_j in range(j): 894 | result[index_i, index_k, index_j] = A[index_i, index_j] * C[index_k, index_j] 895 | 896 | result 897 | ``` 898 | 899 | 900 | 901 | 902 | array([[[ 0., 1., 4.], 903 | [ 0., 4., 10.], 904 | [ 0., 7., 16.]], 905 | 906 | [[ 0., 4., 10.], 907 | [ 9., 16., 25.], 908 | [18., 28., 40.]]]) 909 | 910 | 911 | 912 | ### 3.2 Each element of first matrix multiplied (element-wise) by second matrix 913 | 914 | 915 | ```python 916 | A 917 | ``` 918 | 919 | 920 | 921 | 922 | array([[0, 1, 2], 923 | [3, 4, 5]]) 924 | 925 | 926 | 927 | 928 | ```python 929 | B 930 | ``` 931 | 932 | 933 | 934 | 935 | array([[ 6, 7], 936 | [ 8, 9], 937 | [10, 11]]) 938 | 939 | 940 | 941 | 942 | ```python 943 | np.einsum('ij,kl->ijkl', A, B) 944 | ``` 945 | 946 | 947 | 948 | 949 | array([[[[ 0, 0], 950 | [ 0, 0], 951 | [ 0, 0]], 952 | 953 | [[ 6, 7], 954 | [ 8, 9], 955 | [10, 11]], 956 | 957 | [[12, 14], 958 | [16, 18], 959 | [20, 22]]], 960 | 961 | 962 | [[[18, 21], 963 | [24, 27], 964 | [30, 33]], 965 | 966 | [[24, 28], 967 | [32, 36], 968 | [40, 44]], 969 | 970 | [[30, 35], 971 | [40, 45], 972 | [50, 55]]]]) 973 | 974 | 975 | 976 | 977 | ```python 978 | i,j = A.shape 979 | k,l = B.shape 980 | 981 | result = np.empty((i,j,k,l)) 982 | 983 | for index_i in range(i): 984 | for index_j in range(j): 985 | for index_k in range(k): 986 | for index_l in range(l): 987 | result[index_i, index_j, index_k, index_l] = A[index_i, index_j] * B[index_k, index_l] 988 | 989 | result 990 | ``` 991 | 992 | 993 | 994 | 995 | array([[[[ 0., 0.], 996 | [ 0., 0.], 997 | [ 0., 0.]], 998 | 999 | [[ 6., 7.], 1000 | [ 8., 9.], 1001 | [10., 11.]], 1002 | 1003 | [[12., 14.], 1004 | [16., 18.], 1005 | [20., 22.]]], 1006 | 1007 | 1008 | [[[18., 21.], 1009 | [24., 27.], 1010 | [30., 33.]], 1011 | 1012 | [[24., 28.], 1013 | [32., 36.], 1014 | [40., 44.]], 1015 | 1016 | [[30., 35.], 1017 | [40., 45.], 1018 | [50., 55.]]]]) 1019 | 1020 | 1021 | 1022 | ## 4. 
Some examples in PyTorch 1023 | 1024 | ### 4.1 Batch multiplication 1025 | 1026 | 1027 | ```python 1028 | m1 = torch.randn(4,5,3) 1029 | m2 = torch.randn(4,3,5) 1030 | torch.bmm(m1,m2) 1031 | ``` 1032 | 1033 | 1034 | 1035 | 1036 | tensor([[[ 8.2181e-01, -5.8035e-01, 2.2078e+00, 1.4295e+00, 1.8635e+00], 1037 | [ 8.4052e-01, -1.0589e-01, 1.4207e+00, 9.9271e-01, 2.3920e+00], 1038 | [-1.5352e-02, -5.3438e-01, 3.2493e+00, 1.4200e+00, 3.1127e+00], 1039 | [-1.3939e+00, 1.3775e+00, -2.2805e+00, -1.9652e+00, 5.8474e-01], 1040 | [ 2.0536e+00, -6.2420e-01, 2.3070e+00, 2.0755e+00, 2.6713e+00]], 1041 | 1042 | [[ 7.0778e-01, 1.4530e-01, 1.9873e+00, 2.1278e+00, -3.3463e-01], 1043 | [ 1.5298e-01, -1.7556e+00, -2.0336e+00, -3.3895e+00, 2.8165e-03], 1044 | [-4.9915e-01, 2.9698e-02, -9.9656e-01, -5.8711e-01, 3.3424e-01], 1045 | [-7.3538e-02, -6.1606e-01, -1.1962e+00, -1.9709e+00, -1.5120e-02], 1046 | [-8.9944e-01, -2.8037e-01, -2.4217e+00, -2.2450e+00, 5.3909e-01]], 1047 | 1048 | [[ 7.2014e-01, 1.4626e+00, -5.0004e-01, 1.5049e+00, -4.8354e-01], 1049 | [-1.4394e+00, 2.9822e-01, 1.2814e+00, 2.6319e+00, 3.4886e+00], 1050 | [ 4.2215e-01, 1.8446e+00, -5.1419e-03, 1.3402e-01, 1.1296e-01], 1051 | [-2.9111e-01, 7.2741e-01, 1.2039e-01, 4.1217e+00, 1.5959e+00], 1052 | [ 4.1003e-01, -1.7096e-01, -2.7092e-01, -2.1486e+00, -1.2508e+00]], 1053 | 1054 | [[ 8.9113e-01, 1.7029e+00, 1.9923e-01, -3.2940e-01, -9.3853e-01], 1055 | [ 7.0222e+00, 1.5919e-01, -2.6844e+00, 1.0894e+00, 1.2812e+00], 1056 | [ 6.5914e-02, -1.0917e+00, -1.0884e-01, 6.8298e-01, 1.0476e+00], 1057 | [ 3.3019e+00, 2.1760e+00, -1.3713e+00, -1.1536e+00, -1.7116e+00], 1058 | [-4.7698e+00, -1.7244e-01, 1.6248e+00, -9.6778e-01, -1.0414e+00]]]) 1059 | 1060 | 1061 | 1062 | This is equivalent to: 1063 | 1064 | 1065 | ```python 1066 | torch.einsum('bij,bjk->bik', m1, m2) 1067 | ``` 1068 | 1069 | 1070 | 1071 | 1072 | tensor([[[ 8.2181e-01, -5.8035e-01, 2.2078e+00, 1.4295e+00, 1.8635e+00], 1073 | [ 8.4052e-01, -1.0589e-01, 1.4207e+00, 9.9271e-01, 2.3920e+00], 1074 | [-1.5352e-02, -5.3438e-01, 3.2493e+00, 1.4200e+00, 3.1127e+00], 1075 | [-1.3939e+00, 1.3775e+00, -2.2805e+00, -1.9652e+00, 5.8474e-01], 1076 | [ 2.0536e+00, -6.2420e-01, 2.3070e+00, 2.0755e+00, 2.6713e+00]], 1077 | 1078 | [[ 7.0778e-01, 1.4530e-01, 1.9873e+00, 2.1278e+00, -3.3463e-01], 1079 | [ 1.5298e-01, -1.7556e+00, -2.0336e+00, -3.3895e+00, 2.8165e-03], 1080 | [-4.9915e-01, 2.9698e-02, -9.9656e-01, -5.8711e-01, 3.3424e-01], 1081 | [-7.3538e-02, -6.1606e-01, -1.1962e+00, -1.9709e+00, -1.5120e-02], 1082 | [-8.9944e-01, -2.8037e-01, -2.4217e+00, -2.2450e+00, 5.3909e-01]], 1083 | 1084 | [[ 7.2014e-01, 1.4626e+00, -5.0004e-01, 1.5049e+00, -4.8354e-01], 1085 | [-1.4394e+00, 2.9822e-01, 1.2814e+00, 2.6319e+00, 3.4886e+00], 1086 | [ 4.2215e-01, 1.8446e+00, -5.1419e-03, 1.3402e-01, 1.1296e-01], 1087 | [-2.9111e-01, 7.2741e-01, 1.2039e-01, 4.1217e+00, 1.5959e+00], 1088 | [ 4.1003e-01, -1.7096e-01, -2.7092e-01, -2.1486e+00, -1.2508e+00]], 1089 | 1090 | [[ 8.9113e-01, 1.7029e+00, 1.9923e-01, -3.2940e-01, -9.3853e-01], 1091 | [ 7.0222e+00, 1.5919e-01, -2.6844e+00, 1.0894e+00, 1.2812e+00], 1092 | [ 6.5914e-02, -1.0917e+00, -1.0884e-01, 6.8298e-01, 1.0476e+00], 1093 | [ 3.3019e+00, 2.1760e+00, -1.3713e+00, -1.1536e+00, -1.7116e+00], 1094 | [-4.7698e+00, -1.7244e-01, 1.6248e+00, -9.6778e-01, -1.0414e+00]]]) 1095 | 1096 | 1097 | 1098 | ### The official implementation of [MoCo paper](https://arxiv.org/pdf/1911.05722.pdf) uses `einsum` for calculating dot products with positive and negative examples [as shown 
here](https://github.com/facebookresearch/moco/blob/master/moco/builder.py#L143-L146): 1099 | 1100 | 1101 | ```python 1102 | # positive logits: Nx1 1103 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 1104 | # negative logits: NxK 1105 | l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) 1106 | ``` 1107 | -------------------------------------------------------------------------------- /markdowns/groups.md: -------------------------------------------------------------------------------- 1 | 2 | ### Groups in convolution 3 | 4 | It is easier to read the explanation and then look at the visuals and the code to understand how it works. There will not be much text to explain what is going on. 5 | 6 | In the code shown below, x has value `(c % 16) + 1` where `c` is the index of second axis. 7 | 8 | 9 | ```python 10 | import torch 11 | import torch.nn as nn 12 | from itertools import product 13 | 14 | x = torch.ones(2,16,10,10) 15 | for n,c,h,w in product( *[range(i) for i in x.size()] ): 16 | val = (c % 16) + 1 17 | x[n,c,h,w] = val 18 | ``` 19 | 20 | ### Official documentation: 21 | 22 | * `groups` controls the connections between inputs and outputs. 23 | * `in_channels` and `out_channels` must both be divisible by `groups`. For example, 24 | 25 | * At groups=1, all inputs are convolved to all outputs. 26 | * At groups=2, the operation becomes equivalent to having two conv 27 | layers side by side, each seeing half the input channels, 28 | and producing half the output channels, and both subsequently 29 | concatenated. 30 | * At groups= `in_channels`, each input channel is convolved with 31 | its own set of filters 32 | 33 | Here is how it can be explained visually: 34 | 35 | ![](../assets/groups_1.png) 36 | 37 | Define convolutional layers with number of groups = 1,2 and 4: 38 | 39 | ```python 40 | conv_g1 = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=1, groups=1) 41 | conv_g2 = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=1, groups=2) 42 | conv_g4 = nn.Conv2d(in_channels=16, out_channels=4, kernel_size=1, groups=4) 43 | conv_nets = [conv_g1, conv_g2, conv_g4] 44 | ``` 45 | 46 | Can you guess what the size of the weights looks like? 
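A hint before looking at the answer: with `groups=g`, each output channel is connected to only `in_channels / g` input channels, so the weight of a `Conv2d` has shape `(out_channels, in_channels // groups, kernel_size, kernel_size)`. A quick sketch of what we should therefore expect for the three layers defined above:

```python
# Expected weight shapes, computed from the rule
# (out_channels, in_channels // groups, kH, kW)
in_channels, out_channels, kernel_size = 16, 4, 1
for groups in (1, 2, 4):
    print(groups, (out_channels, in_channels // groups, kernel_size, kernel_size))
# 1 (4, 16, 1, 1)
# 2 (4, 8, 1, 1)
# 4 (4, 4, 1, 1)
```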
47 | 48 | ```python 49 | _ = [print(conv.weight.size()) for conv in conv_nets] 50 | ``` 51 | 52 | torch.Size([4, 16, 1, 1]) 53 | torch.Size([4, 8, 1, 1]) 54 | torch.Size([4, 4, 1, 1]) 55 | 56 | 57 | 58 | ```python 59 | _ = [print(conv.bias.size()) for conv in conv_nets] 60 | ``` 61 | 62 | torch.Size([4]) 63 | torch.Size([4]) 64 | torch.Size([4]) 65 | 66 | Initialize the parameters to have unit weights and zero biases: 67 | 68 | ```python 69 | for index, conv in enumerate(conv_nets, 1): 70 | nn.init.constant_(conv.weight, index) 71 | nn.init.constant_(conv.bias, 0) 72 | ``` 73 | 74 | 75 | ```python 76 | _ = [print(conv.weight.mean().item()) for conv in conv_nets] 77 | ``` 78 | 79 | 1.0 80 | 2.0 81 | 3.0 82 | 83 | 84 | 85 | ```python 86 | _ = [print(conv.bias.mean().item()) for conv in conv_nets] 87 | ``` 88 | 89 | 0.0 90 | 0.0 91 | 0.0 92 | 93 | Regardless of the number of groups, the output size will always be the same: 94 | 95 | ```python 96 | _ = [print(conv(x).size()) for conv in conv_nets] 97 | ``` 98 | 99 | torch.Size([2, 4, 10, 10]) 100 | torch.Size([2, 4, 10, 10]) 101 | torch.Size([2, 4, 10, 10]) 102 | 103 | But the values will be different: 104 | 105 | ```python 106 | _ = [print(conv(x).mean().item()) for conv in conv_nets] 107 | ``` 108 | 109 | 136.0 110 | 136.0 111 | 102.0 112 | 113 | 114 | 115 | ```python 116 | y = conv_g4(x) 117 | ``` 118 | 119 | ![](../assets/groups_2.png) 120 | 121 | 122 | ```python 123 | for i in range(4): 124 | print(y[:,i].mean(1)) 125 | ``` 126 | 127 | tensor([[ 30., 30., 30., 30., 30., 30., 30., 30., 30., 30.], 128 | [ 30., 30., 30., 30., 30., 30., 30., 30., 30., 30.]]) 129 | tensor([[ 78., 78., 78., 78., 78., 78., 78., 78., 78., 78.], 130 | [ 78., 78., 78., 78., 78., 78., 78., 78., 78., 78.]]) 131 | tensor([[ 126., 126., 126., 126., 126., 126., 126., 126., 126., 132 | 126.], 133 | [ 126., 126., 126., 126., 126., 126., 126., 126., 126., 134 | 126.]]) 135 | tensor([[ 174., 174., 174., 174., 174., 174., 174., 174., 174., 136 | 174.], 137 | [ 174., 174., 174., 174., 174., 174., 174., 174., 174., 138 | 174.]]) 139 | -------------------------------------------------------------------------------- /markdowns/hooks.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ```python 4 | import torch 5 | import torch.nn as nn 6 | ``` 7 | 8 | ### Let's create a series of basic operations and calculate gradient 9 | 10 | 11 | ```python 12 | x = torch.tensor([1.,2.,3.], requires_grad=True) 13 | y = x*x 14 | z = torch.sum(y) 15 | z 16 | ``` 17 | 18 | 19 | 20 | 21 | tensor(14., grad_fn=) 22 | 23 | 24 | 25 | 26 | ```python 27 | y.requires_grad, z.requires_grad 28 | ``` 29 | 30 | 31 | 32 | 33 | (True, True) 34 | 35 | 36 | 37 | 38 | ```python 39 | x.grad, y.grad, z.grad 40 | ``` 41 | 42 | 43 | 44 | 45 | (None, None, None) 46 | 47 | 48 | 49 | 50 | ```python 51 | z.backward() 52 | ``` 53 | 54 | 55 | ```python 56 | x.grad, y.grad, z.grad 57 | ``` 58 | 59 | 60 | 61 | 62 | (tensor([2., 4., 6.]), None, None) 63 | 64 | 65 | 66 | ### You can register a hook to a PyTorch `Module` or `Tensor` 67 | 68 | 69 | ```python 70 | x = torch.tensor([1.,2.,3.], requires_grad=True) 71 | y = x*x 72 | z = torch.sum(y) 73 | z 74 | ``` 75 | 76 | 77 | 78 | 79 | tensor(14., grad_fn=) 80 | 81 | 82 | 83 | #### `register_hook` registers a backward hook 84 | 85 | 86 | ```python 87 | h = y.register_hook(lambda grad: print(grad)) 88 | ``` 89 | 90 | 91 | ```python 92 | z.backward() 93 | ``` 94 | 95 | tensor([1., 1., 1.]) 96 | 97 | 98 | 99 | ```python 100 | x.grad, 
y.grad, z.grad 101 | ``` 102 | 103 | 104 | 105 | 106 | (tensor([2., 4., 6.]), None, None) 107 | 108 | 109 | 110 | #### You can also use a hook to manipulte gradients. The hook should have the following signature: 111 | 112 | `hook(grad) -> Tensor or None` 113 | 114 | 115 | ```python 116 | x = torch.tensor([1.,2.,3.], requires_grad=True) 117 | y = x*x 118 | z = torch.sum(y) 119 | z 120 | ``` 121 | 122 | 123 | 124 | 125 | tensor(14., grad_fn=) 126 | 127 | 128 | 129 | #### Let's add random noise to gradient of `y` 130 | 131 | 132 | ```python 133 | def add_random_noise_and_print(grad): 134 | noise = torch.randn(grad.size()) 135 | print('Noise:', noise) 136 | return grad + noise 137 | 138 | h = y.register_hook(add_random_noise_and_print) 139 | ``` 140 | 141 | 142 | ```python 143 | x.grad, y.grad, z.grad 144 | ``` 145 | 146 | 147 | 148 | 149 | (None, None, None) 150 | 151 | 152 | 153 | 154 | ```python 155 | z.backward() 156 | ``` 157 | 158 | Noise: tensor([1.1173, 1.2854, 0.0611]) 159 | 160 | 161 | 162 | ```python 163 | x.grad, y.grad, z.grad 164 | ``` 165 | 166 | 167 | 168 | 169 | (tensor([4.2346, 9.1415, 6.3665]), None, None) 170 | 171 | 172 | 173 | ### Forward and backward hooks 174 | 175 | There are three hooks listed below. They are available only for `nn.Module`s 176 | 177 | * __`register_forward_pre_hook`__: function is called BEFORE forward call 178 | * Signature: `hook(module, input) -> None 179 | 180 | 181 | * __`register_forward_hook`__: function is called AFTER forward call 182 | * Signature: `hook(module, input, output) -> None` 183 | 184 | 185 | * __`register_backward_hook`__: function is called AFTER gradients wrt module input are computed 186 | * Signature: `hook(module, grad_input, grad_output) -> Tensor or None` 187 | 188 | #### `register_forward_pre_hook` 189 | 190 | 191 | ```python 192 | linear = nn.Linear(10,1) 193 | ``` 194 | 195 | 196 | ```python 197 | [param for param in linear.parameters()] 198 | ``` 199 | 200 | 201 | 202 | 203 | [Parameter containing: 204 | tensor([[-0.3009, 0.0351, -0.2786, 0.1136, -0.2712, 0.0183, -0.2881, -0.1555, 205 | -0.3108, 0.0767]], requires_grad=True), Parameter containing: 206 | tensor([0.0377], requires_grad=True)] 207 | 208 | 209 | 210 | 211 | ```python 212 | h = linear.register_forward_pre_hook(lambda _, inp: print(inp)) 213 | ``` 214 | 215 | 216 | ```python 217 | x = torch.ones([8,10]) 218 | ``` 219 | 220 | 221 | ```python 222 | y = linear(x) 223 | ``` 224 | 225 | (tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 226 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 227 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 228 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 229 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 230 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 231 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 232 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),) 233 | 234 | 235 | #### `register_forward_hook` 236 | 237 | 238 | ```python 239 | linear = nn.Linear(10,1) 240 | [param for param in linear.parameters()] 241 | ``` 242 | 243 | 244 | 245 | 246 | [Parameter containing: 247 | tensor([[ 0.0978, -0.1878, 0.0189, 0.3040, 0.1120, 0.1977, 0.2137, -0.2841, 248 | -0.0718, 0.2079]], requires_grad=True), Parameter containing: 249 | tensor([0.2796], requires_grad=True)] 250 | 251 | 252 | 253 | 254 | ```python 255 | h = linear.register_forward_hook(lambda _, inp, out: print(inp, '\n\n', out)) 256 | ``` 257 | 258 | 259 | ```python 260 | y = linear(x) 261 | ``` 262 | 263 | (tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 264 | [1., 1., 1., 1., 1., 1., 1., 1., 
1., 1.], 265 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 266 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 267 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 268 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 269 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], 270 | [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]),) 271 | 272 | tensor([[0.8878], 273 | [0.8878], 274 | [0.8878], 275 | [0.8878], 276 | [0.8878], 277 | [0.8878], 278 | [0.8878], 279 | [0.8878]], grad_fn=) 280 | 281 | 282 | Just to verify, the result above can also be computed manually like so: 283 | 284 | 285 | ```python 286 | [param for param in linear.parameters()][0].sum() + [param for param in linear.parameters()][1] 287 | ``` 288 | 289 | 290 | 291 | 292 | tensor([0.8878], grad_fn=) 293 | 294 | 295 | 296 | #### `register_backward_hook` 297 | 298 | 299 | ```python 300 | linear = nn.Linear(3,1) 301 | [param for param in linear.parameters()] 302 | ``` 303 | 304 | 305 | 306 | 307 | [Parameter containing: 308 | tensor([[0.5395, 0.2303, 0.5583]], requires_grad=True), Parameter containing: 309 | tensor([-0.3510], requires_grad=True)] 310 | 311 | 312 | 313 | 314 | ```python 315 | def print_sizes(module, grad_inp, grad_out): 316 | print(len(grad_inp), len(grad_out)) 317 | print('===') 318 | print('Grad input sizes:', [i.size() for i in grad_inp if i is not None]) 319 | print('Grad output sizes:', [i.size() for i in grad_out if i is not None]) 320 | print('===') 321 | print('Grad_input 0:', grad_inp[0]) 322 | print('Grad_input 1:', grad_inp[1]) 323 | print('Grad_input 2:', grad_inp[2]) 324 | print('===') 325 | print(grad_out) 326 | 327 | h = linear.register_backward_hook(print_sizes) 328 | ``` 329 | 330 | 331 | ```python 332 | x = torch.ones([8,3]) * 10 333 | y = linear(x) 334 | ``` 335 | 336 | 337 | ```python 338 | y.backward(torch.ones_like(y) * 1.5) 339 | ``` 340 | 341 | 3 1 342 | === 343 | Grad input sizes: [torch.Size([1]), torch.Size([3, 1])] 344 | Grad output sizes: [torch.Size([8, 1])] 345 | === 346 | Grad_input 0: tensor([12.]) 347 | Grad_input 1: None 348 | Grad_input 2: tensor([[120.], 349 | [120.], 350 | [120.]]) 351 | === 352 | (tensor([[1.5000], 353 | [1.5000], 354 | [1.5000], 355 | [1.5000], 356 | [1.5000], 357 | [1.5000], 358 | [1.5000], 359 | [1.5000]]),) 360 | 361 | 362 | `register_backward_hook` can be used to tweak `grad_inp` values after they have been calculated. 
To do so, the hook function should return a new gradient with respect to input that will be used in place of `grad_input` in subsequent computations 363 | 364 | In the example below, we add random noise to the gradient of bias using the hook: 365 | 366 | 367 | ```python 368 | linear = nn.Linear(3,1) 369 | [param for param in linear.parameters()] 370 | ``` 371 | 372 | 373 | 374 | 375 | [Parameter containing: 376 | tensor([[-0.5438, 0.5539, 0.5210]], requires_grad=True), 377 | Parameter containing: 378 | tensor([-0.4839], requires_grad=True)] 379 | 380 | 381 | 382 | 383 | ```python 384 | def add_noise_and_print(module, grad_inp, grad_out): 385 | noise = torch.randn(grad_inp[0].size()) 386 | print('Noise:', noise) 387 | print('===') 388 | print('Grad input sizes:', [i.size() for i in grad_inp if i is not None]) 389 | print('Grad output sizes:', [i.size() for i in grad_out if i is not None]) 390 | print('===') 391 | print('Grad_input 0:', grad_inp[0]) 392 | print('Grad_input 1:', grad_inp[1]) 393 | print('Grad_input 2:', grad_inp[2]) 394 | print('===') 395 | print(grad_out) 396 | return (grad_inp[0] + noise, None, grad_inp[2]) 397 | 398 | h = linear.register_backward_hook(add_noise_and_print) 399 | ``` 400 | 401 | 402 | ```python 403 | x = torch.ones([8,3]) * 10 404 | y = linear(x) 405 | ``` 406 | 407 | 408 | ```python 409 | y.backward(torch.ones_like(y) * 1.5) 410 | ``` 411 | 412 | Noise: tensor([0.7553]) 413 | === 414 | Grad input sizes: [torch.Size([1]), torch.Size([3, 1])] 415 | Grad output sizes: [torch.Size([8, 1])] 416 | === 417 | Grad_input 0: tensor([12.]) 418 | Grad_input 1: None 419 | Grad_input 2: tensor([[120.], 420 | [120.], 421 | [120.]]) 422 | === 423 | (tensor([[1.5000], 424 | [1.5000], 425 | [1.5000], 426 | [1.5000], 427 | [1.5000], 428 | [1.5000], 429 | [1.5000], 430 | [1.5000]]),) 431 | 432 | 433 | `linear.bias.grad` was originally `12.0` but a random noise `Noise: tensor([0.7553])` was added to it: 434 | 435 | ``` 436 | 12 + 0.7553 = 12.7553 437 | ``` 438 | 439 | 440 | ```python 441 | linear.bias.grad 442 | ``` 443 | 444 | 445 | 446 | 447 | tensor([12.7553]) 448 | 449 | 450 | 451 | 452 | ```python 453 | linear.weight.grad 454 | ``` 455 | 456 | 457 | 458 | 459 | tensor([[120., 120., 120.]]) 460 | 461 | 462 | -------------------------------------------------------------------------------- /markdowns/register_buffer.md: -------------------------------------------------------------------------------- 1 | 2 | ## `register_buffer` 3 | 4 | The documentation says: 5 | 6 | > This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the persistent state. 7 | 8 | > Buffers can be accessed as attributes using given names. 9 | 10 | As an example we will just implement a simple running average using `register_buffer`: 11 | 12 | 13 | ```python 14 | import torch 15 | import torch.nn as nn 16 | ``` 17 | 18 | `running_average` and `count` are registered as buffers inside the `Model` definition. When an object is created from the `Model` class, it will have `running_average` and `count` as attributed. 19 | 20 | In the `forward` method, you can even define how these attributes will be updated once the model is called. 
21 | 22 | 23 | ```python 24 | class Model(nn.Module): 25 | def __init__(self): 26 | super(Model, self).__init__() 27 | self.register_buffer('running_average', torch.Tensor([0.0])) 28 | self.register_buffer('count', torch.Tensor([0.0])) 29 | 30 | def forward(self, x): 31 | # self.count keeps a count of how many times the model was called 32 | self.count += 1 33 | 34 | # self.running_average keeps the running average in memory 35 | self.running_average = self.running_average.mul(self.count-1).add(x).div(self.count) 36 | return x 37 | ``` 38 | 39 | ### Note that items registered as buffers are not considered "parameters" of the model. Hence they will not show up under `model.parameters()`: 40 | 41 | 42 | ```python 43 | model = Model() 44 | list(model.parameters()) 45 | ``` 46 | 47 | 48 | 49 | 50 | [] 51 | 52 | 53 | 54 | ### However they do show up in the `state_dict`. What this means is that when you save the `state_dict`, these values will be saved (and later retrieved) as well: 55 | 56 | 57 | ```python 58 | model.state_dict() 59 | ``` 60 | 61 | 62 | 63 | 64 | OrderedDict([('running_average', tensor([0.])), ('count', tensor([0.]))]) 65 | 66 | 67 | 68 | ### Now let's just call the model a couple of times with different values and see how these values change: 69 | 70 | 71 | ```python 72 | model(10) 73 | ``` 74 | 75 | 76 | 77 | 78 | 10 79 | 80 | 81 | 82 | 83 | ```python 84 | model.count, model.running_average 85 | ``` 86 | 87 | 88 | 89 | 90 | (tensor([1.]), tensor([10.])) 91 | 92 | 93 | 94 | 95 | ```python 96 | model(10) 97 | model.count, model.running_average 98 | ``` 99 | 100 | 101 | 102 | 103 | (tensor([2.]), tensor([10.])) 104 | 105 | 106 | 107 | 108 | ```python 109 | model(5) 110 | model.count, model.running_average 111 | ``` 112 | 113 | 114 | 115 | 116 | (tensor([3.]), tensor([8.3333])) 117 | 118 | 119 | 120 | 121 | ```python 122 | model(15) 123 | model.count, model.running_average 124 | ``` 125 | 126 | 127 | 128 | 129 | (tensor([4.]), tensor([10.])) 130 | 131 | 132 | 133 | 134 | ```python 135 | model(1) 136 | model.count, model.running_average 137 | ``` 138 | 139 | 140 | 141 | 142 | (tensor([5.]), tensor([8.2000])) 143 | 144 | 145 | -------------------------------------------------------------------------------- /markdowns/weightnorm.md: -------------------------------------------------------------------------------- 1 | 2 | Weight normalization was introduced by OpenAI in their paper https://arxiv.org/abs/1602.07868 3 | 4 | It is now a part of PyTorch as built-in functionality. It helps in speeding up gradient descent and has been used in Wavenet. 5 | 6 | The basic idea is pretty straight-forward: for each neuron, the weight vector `w` is broken down into two components: its magnitude (`g`) and direction (unit vector in direction of `v`: `v / ||v||`). This we have: 7 | 8 | > `w = g * v / ||v||` where `w` and `v` are vectors and `g` is a scalar 9 | 10 | 11 | Instead of optimizing `w` directly, optimizing on `g` and `v` separately has been found to be faster and more accurate. 
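As a quick illustration of the decomposition (a minimal sketch, independent of PyTorch's own implementation, which we explore below): for any vector `v` and scalar `g`, the reparameterized weight `w = g * v / ||v||` has magnitude exactly `g`, while `v` controls only its direction.

```python
import torch

v = torch.randn(5)      # direction parameter (only its direction matters)
g = torch.tensor(3.0)   # magnitude parameter

w = g * v / v.norm()    # w = g * v / ||v||
print(w.norm())         # tensor(3.), the norm of w is g regardless of v
print(torch.allclose(w / w.norm(), v / v.norm()))  # True, same direction as v
```

Gradient descent can then adjust the scale (`g`) and the direction (`v`) independently, which is the decoupling that makes the optimization faster.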
12 | 13 | 14 | ```python 15 | import torch 16 | import torch.nn as nn 17 | ``` 18 | 19 | ### There are two things to observe about weight normalization: 20 | 21 | 1) It increases the number of parameters (by the amount of neurons) 22 | 23 | 2) It adds a forward pre-hook to the module 24 | 25 | 26 | ```python 27 | get_num_of_params = lambda model: sum([p.numel() for p in model.parameters()]) 28 | get_param_name_and_size = lambda model: [(name, param.size()) for (name, param) in model.named_parameters()] 29 | ``` 30 | 31 | 32 | ```python 33 | linear = nn.Linear(5,3) 34 | ``` 35 | 36 | 37 | ```python 38 | linear._forward_pre_hooks #no hooks present 39 | ``` 40 | 41 | 42 | 43 | 44 | OrderedDict() 45 | 46 | 47 | 48 | 49 | ```python 50 | get_num_of_params(linear), get_param_name_and_size(linear) 51 | ``` 52 | 53 | 54 | 55 | 56 | (18, [('weight', torch.Size([3, 5])), ('bias', torch.Size([3]))]) 57 | 58 | 59 | 60 | ### Weight normalization in PyTorch can be done by calling the `nn.utils.weight_norm` function. 61 | 62 | By default, it normalizes the `weight` of a module: 63 | 64 | 65 | ```python 66 | _ = nn.utils.weight_norm(linear) 67 | ``` 68 | 69 | The number of parameters increased by 3 (we have 3 neurons here). Also the parameter `name` is replaced by two parameters `name_g` and `name_v` respectively: 70 | 71 | 72 | ```python 73 | get_num_of_params(linear), get_param_name_and_size(linear) 74 | ``` 75 | 76 | 77 | 78 | 79 | (21, 80 | [('bias', torch.Size([3])), 81 | ('weight_g', torch.Size([3, 1])), 82 | ('weight_v', torch.Size([3, 5]))]) 83 | 84 | 85 | 86 | 87 | ```python 88 | linear.weight_g.data 89 | ``` 90 | 91 | 92 | 93 | 94 | tensor([[0.5075], 95 | [0.4952], 96 | [0.6064]]) 97 | 98 | 99 | 100 | 101 | ```python 102 | linear.weight_v.data 103 | ``` 104 | 105 | 106 | 107 | 108 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 109 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 110 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]]) 111 | 112 | 113 | 114 | Also note that `linear` module now has a forward pre-hook added to it: 115 | 116 | 117 | ```python 118 | linear._forward_pre_hooks 119 | ``` 120 | 121 | 122 | 123 | 124 | OrderedDict([(0, )]) 125 | 126 | 127 | 128 | The original `weight` is also present but is not a part of `parameters` attribute of `linear` module. Also note that `weight_v.data` is the same as `weight.data`: 129 | 130 | 131 | ```python 132 | linear.weight.data 133 | ``` 134 | 135 | 136 | 137 | 138 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 139 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 140 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]]) 141 | 142 | 143 | 144 | 145 | ```python 146 | linear.weight_v.data 147 | ``` 148 | 149 | 150 | 151 | 152 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 153 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 154 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]]) 155 | 156 | 157 | 158 | ### How is `weight_g` calculated? We will look at it in a moment 159 | 160 | 161 | ```python 162 | linear.weight_g.data 163 | ``` 164 | 165 | 166 | 167 | 168 | tensor([[0.5075], 169 | [0.4952], 170 | [0.6064]]) 171 | 172 | 173 | 174 | ### We can get `name` from `name_g` and `name_v` using `torch._weight_norm` function. 
We will also look at it in a moment 175 | 176 | 177 | ```python 178 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data) 179 | ``` 180 | 181 | 182 | 183 | 184 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 185 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 186 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]]) 187 | 188 | 189 | 190 | ### Let's look at `norm_except_dim` function 191 | 192 | Pretty self-explanatory: it calculates the norm of a tensor except at dimsension provided to the function. We can calculate any Lp norm and omit any dimension we want. 193 | 194 | A few examples should make it clear. 195 | 196 | 197 | ```python 198 | ones = torch.ones(5,5,5) 199 | ones.size() 200 | ``` 201 | 202 | 203 | 204 | 205 | torch.Size([5, 5, 5]) 206 | 207 | 208 | 209 | It we omit dimension `0`, we have 25 elements each of value `1`. Thus their L2 norm is `5`: 210 | 211 | 212 | ```python 213 | norm = 2 214 | dim = 0 215 | y = torch.norm_except_dim(ones, norm, dim) 216 | print('y.size():', y.size()) 217 | print('y:', y) 218 | ``` 219 | 220 | y.size(): torch.Size([5, 1, 1]) 221 | y: tensor([[[5.]], 222 | 223 | [[5.]], 224 | 225 | [[5.]], 226 | 227 | [[5.]], 228 | 229 | [[5.]]]) 230 | 231 | 232 | Similar, omitting dim = `0` and calculating L3 norm gives 25 ** (1/3) = 2.9240: 233 | 234 | 235 | ```python 236 | norm = 3 237 | dim = 0 238 | y = torch.norm_except_dim(ones, norm, dim) 239 | print('y.size():', y.size()) 240 | print('y:', y) 241 | ``` 242 | 243 | y.size(): torch.Size([5, 1, 1]) 244 | y: tensor([[[2.9240]], 245 | 246 | [[2.9240]], 247 | 248 | [[2.9240]], 249 | 250 | [[2.9240]], 251 | 252 | [[2.9240]]]) 253 | 254 | 255 | Omitting dim = `2` changes the shape of the output: 256 | 257 | 258 | ```python 259 | norm = 2 260 | dim = 2 261 | y = torch.norm_except_dim(ones, norm, dim) 262 | print('y.size():', y.size()) 263 | print('y:', y) 264 | ``` 265 | 266 | y.size(): torch.Size([1, 1, 5]) 267 | y: tensor([[[5., 5., 5., 5., 5.]]]) 268 | 269 | 270 | Omitting dim = `-1` is the same as not omitting anything at all: 271 | 272 | 273 | ```python 274 | torch.norm_except_dim(ones, 2, -1) 275 | ``` 276 | 277 | 278 | 279 | 280 | tensor(11.1803) 281 | 282 | 283 | 284 | 285 | ```python 286 | ones.norm() 287 | ``` 288 | 289 | 290 | 291 | 292 | tensor(11.1803) 293 | 294 | 295 | 296 | ### By default `nn.utils.weight_norm` calls `torch.norm_except_dim` with `dim=0`. This is how we get `weight_g`: 297 | 298 | 299 | ```python 300 | torch.norm_except_dim(linear.weight.data, 2, 0) 301 | ``` 302 | 303 | 304 | 305 | 306 | tensor([[0.5075], 307 | [0.4952], 308 | [0.6064]]) 309 | 310 | 311 | 312 | It is the same as doing the below operation: 313 | 314 | 315 | ```python 316 | linear.weight.data.norm(dim=1) 317 | ``` 318 | 319 | 320 | 321 | 322 | tensor([0.5075, 0.4952, 0.6064]) 323 | 324 | 325 | 326 | ### Let's look at `_weight_norm` function 327 | 328 | This function is used to calculate `name` from `name_v` and `name_g`. 
Let's see how it works: 329 | 330 | 331 | ```python 332 | v = torch.randn(5,3)*10 + 4 333 | g = torch.randn(5,1) 334 | ``` 335 | 336 | For the given values of `g` and `v`, `torch._weight_norm(v,g,0)` is basically the same as `g * v/v.norm(dim=1,keepdim=True)`: 337 | 338 | 339 | ```python 340 | torch._weight_norm(v,g,0) 341 | ``` 342 | 343 | 344 | 345 | 346 | tensor([[-0.1452, -0.2595, -0.2552], 347 | [ 0.3853, 0.1218, 0.6896], 348 | [-0.0612, -0.0946, -0.0837], 349 | [-0.1722, 0.1708, 0.4013], 350 | [-0.0354, 0.0114, -0.0197]]) 351 | 352 | 353 | 354 | 355 | ```python 356 | g * v/v.norm(dim=1,keepdim=True) 357 | ``` 358 | 359 | 360 | 361 | 362 | tensor([[-0.1452, -0.2595, -0.2552], 363 | [ 0.3853, 0.1218, 0.6896], 364 | [-0.0612, -0.0946, -0.0837], 365 | [-0.1722, 0.1708, 0.4013], 366 | [-0.0354, 0.0114, -0.0197]]) 367 | 368 | 369 | 370 | Similarly, `torch._weight_norm(v,g,1)` is basically the same as `g * v/v.norm(dim=0,keepdim=True)` 371 | 372 | 373 | ```python 374 | torch._weight_norm(v,g,1) 375 | ``` 376 | 377 | 378 | 379 | 380 | tensor([[-0.2282, -0.2974, -0.2645], 381 | [ 0.3056, 0.0704, 0.3607], 382 | [-0.0779, -0.0878, -0.0702], 383 | [-0.0785, 0.0568, 0.1207], 384 | [-0.0178, 0.0042, -0.0065]]) 385 | 386 | 387 | 388 | 389 | ```python 390 | g * v/v.norm(dim=0,keepdim=True) 391 | ``` 392 | 393 | 394 | 395 | 396 | tensor([[-0.2282, -0.2974, -0.2645], 397 | [ 0.3056, 0.0704, 0.3607], 398 | [-0.0779, -0.0878, -0.0702], 399 | [-0.0785, 0.0568, 0.1207], 400 | [-0.0178, 0.0042, -0.0065]]) 401 | 402 | 403 | 404 | And `torch._weight_norm(v,g,-1)` is basically the same as `g * v/v.norm()` 405 | 406 | 407 | ```python 408 | torch._weight_norm(v,g,-1) 409 | ``` 410 | 411 | 412 | 413 | 414 | tensor([[-0.1003, -0.1792, -0.1762], 415 | [ 0.1343, 0.0424, 0.2403], 416 | [-0.0342, -0.0529, -0.0468], 417 | [-0.0345, 0.0342, 0.0804], 418 | [-0.0078, 0.0025, -0.0043]]) 419 | 420 | 421 | 422 | 423 | ```python 424 | g * v/v.norm() 425 | ``` 426 | 427 | 428 | 429 | 430 | tensor([[-0.1003, -0.1792, -0.1762], 431 | [ 0.1343, 0.0424, 0.2403], 432 | [-0.0342, -0.0529, -0.0468], 433 | [-0.0345, 0.0342, 0.0804], 434 | [-0.0078, 0.0025, -0.0043]]) 435 | 436 | 437 | 438 | For `linear` module, this is how we get `weight` from `weight_v` and `weight_g` (notice `dim=0`): 439 | 440 | 441 | ```python 442 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data, 0) 443 | ``` 444 | 445 | 446 | 447 | 448 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 449 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 450 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]]) 451 | 452 | 453 | 454 | 455 | ```python 456 | linear.weight.data 457 | ``` 458 | 459 | 460 | 461 | 462 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 463 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 464 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]]) 465 | 466 | 467 | 468 | ### But what is the point of the forward pre-hook? 469 | 470 | 471 | ```python 472 | linear._forward_pre_hooks 473 | ``` 474 | 475 | 476 | 477 | 478 | OrderedDict([(0, )]) 479 | 480 | 481 | 482 | 483 | ```python 484 | hook = linear._forward_pre_hooks[0] 485 | ``` 486 | 487 | #### Let's first see what a hook does: it basically returns the value of `weight.data`. 
488 | 489 | 490 | ```python 491 | hook.compute_weight(linear) 492 | ``` 493 | 494 | 495 | 496 | 497 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 498 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 499 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]], grad_fn=) 500 | 501 | 502 | 503 | Let's say you are training `linear` on a dataset with `batch_size` = 8. After back-propagation and weight update, the values of `weight_g` and `weight_v` will be different: 504 | 505 | 506 | ```python 507 | batch_size = 8 508 | x = torch.randn(batch_size, 5) 509 | ``` 510 | 511 | 512 | ```python 513 | y = linear(x) 514 | ``` 515 | 516 | 517 | ```python 518 | loss = (y-1).sum() 519 | ``` 520 | 521 | 522 | ```python 523 | loss.backward() 524 | ``` 525 | 526 | 527 | ```python 528 | linear.weight 529 | ``` 530 | 531 | 532 | 533 | 534 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 535 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 536 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]], grad_fn=) 537 | 538 | 539 | 540 | 541 | ```python 542 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data, 0) 543 | ``` 544 | 545 | 546 | 547 | 548 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 549 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 550 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]]) 551 | 552 | 553 | 554 | 555 | ```python 556 | for param in linear.parameters(): 557 | param.data = param.data - (param.grad.data*0.01) 558 | ``` 559 | 560 | Weights `weight_v` and `weight_g` changed. Hence `weight` should now be equal to: 561 | 562 | 563 | ```python 564 | torch._weight_norm(linear.weight_v.data, linear.weight_g.data, 0) 565 | ``` 566 | 567 | 568 | 569 | 570 | tensor([[ 0.2946, 0.1498, -0.0491, -0.0748, 0.3667], 571 | [ 0.0258, -0.2643, 0.1122, -0.3771, -0.0488], 572 | [ 0.2303, 0.4506, 0.0095, 0.2294, -0.3068]]) 573 | 574 | 575 | 576 | But it's not: 577 | 578 | 579 | ```python 580 | linear.weight 581 | ``` 582 | 583 | 584 | 585 | 586 | tensor([[ 0.2782, 0.1382, -0.0764, -0.0963, 0.3820], 587 | [ 0.0089, -0.2772, 0.0862, -0.3995, -0.0354], 588 | [ 0.2126, 0.4394, -0.0193, 0.2077, -0.2931]], grad_fn=) 589 | 590 | 591 | 592 | This is why we need a hook. The hook will basically update `linear.weight` by calling `torch._weight_norm(linear.weight_v.data, linear.weight_g.data, 0)` during the next forward propagation: 593 | 594 | 595 | ```python 596 | _ = linear(x) 597 | ``` 598 | 599 | `linear.weight` is updated thanks to the hook: 600 | 601 | 602 | ```python 603 | linear.weight 604 | ``` 605 | 606 | 607 | 608 | 609 | tensor([[ 0.2946, 0.1498, -0.0491, -0.0748, 0.3667], 610 | [ 0.0258, -0.2643, 0.1122, -0.3771, -0.0488], 611 | [ 0.2303, 0.4506, 0.0095, 0.2294, -0.3068]], grad_fn=) 612 | 613 | 614 | --------------------------------------------------------------------------------