├── README.md
├── RF1.png
├── Receptive_Field.ipynb
└── compute_RF.py

/README.md:
--------------------------------------------------------------------------------
# Receptive-Field-in-Pytorch
## Numerical Computation of Receptive Field in Pytorch

I present simple PyTorch code that numerically computes the receptive field (RF) of a convolutional network. It works even with very complicated networks: 2D, 3D, dilation, skip/residual connections, etc.

* The Jupyter notebook explains how we can compute the RF both analytically and numerically, and shows code for both.

* The Python file contains a simple standalone function that computes it.

## Requirements
* First, replace the max-pooling layers of your network with average pooling, and turn off any batchnorm and dropout that you might have. This avoids sparse gradients (a more detailed explanation is in the Jupyter notebook).

* You must also provide a NumPy array, filled with ones, with the appropriate shape for your specific network, as in the sketch below.
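For example, a minimal usage sketch (the toy model and input size here are illustrative placeholders for your own network; `compute_RF_numerical` is the function defined in `compute_RF.py`):

```python
import numpy as np
import torch.nn as nn
from compute_RF import compute_RF_numerical

net = nn.Sequential(nn.Conv2d(1, 8, 5), nn.ReLU(), nn.Conv2d(8, 1, 3))  # toy model
img = np.ones((1, 1, 64, 64))  # (batch, channels, H, W), filled with ones
print(compute_RF_numerical(net, img))  # expected RF of the toy model: [7, 7]
```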
--------------------------------------------------------------------------------
/RF1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rogertrullo/Receptive-Field-in-Pytorch/815df8ab5924735fc09f4e777019cb633830845e/RF1.png
--------------------------------------------------------------------------------
/Receptive_Field.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Receptive Field\n",
    "If you have worked with convolutional networks, you have probably heard the term receptive field (RF).\n",
    "It is defined as the window of input voxels that affects one particular output voxel. This hyperparameter is important since it indicates how much context the network uses in order to compute one particular output voxel.\n",
    "Some posts explain it in more detail and show how to compute it analytically for simple architectures like AlexNet; look [here](https://medium.com/@nikasa1889/a-guide-to-receptive-field-arithmetic-for-convolutional-neural-networks-e0f514068807) for example.\n",
    "\n",
    "To make it clearer, I think it is better to use a 1D example:\n",
    "\n",
    "![Two stacked 1D convolutions and the resulting receptive field](RF1.png)\n",
    "\n",
    "In this image, I show an input with, say, 12 positions, and I apply two convolutional layers (1D in this case).\n",
    "For each layer we need to define a kernel size and a stride; here I use a kernel size $F=3$ and a stride $s=1$.\n",
    "We can see that every position in the first output depends on 3 input positions.\n",
    "Now if we apply the second convolution, each position in the final output depends on 5 input positions!\n",
    "You can imagine that adding more layers keeps increasing the RF.\n",
    "How can we compute it?\n",
    "Well, as explained in the linked post, it is easy to compute by finding the input size that produces an output of exactly one voxel. This can be done by a recursive program which starts at the last layer, sets the output size to one, and finds the corresponding input size. That size is then used as the desired output of the previous layer, and we keep going until the first layer.\n",
    "For a given convolutional (or pooling) layer we have\n",
    "\n",
    "\\begin{equation}\n",
    "O_{sz}=\\frac{I_{sz}-F}{s}+1\n",
    "\\end{equation}\n",
    "\n",
    "where $O_{sz}$ and $I_{sz}$ refer to the output and input sizes, $F$ is the filter (or kernel) size, and $s$ is the stride.\n",
    "Solving for the input size gives $I_{sz}=s\\,(O_{sz}-1)+F$, which is what the code below applies recursively.\n",
    "If we want to compute the RF, we first set $O_{sz}=1$ and find the corresponding $I_{sz}$; in this case we find $I_{sz}=3$, which is the RF of the last layer. Now we keep going through the remaining first layer, setting $O_{sz}=3$ (the value we just found), and get $I_{sz}=5$.\n",
    "This is the RF of the network in the figure!\n",
    "\n",
    "We can build a simple script to compute this value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_N(out, f, s):\n",
    "    # For s >= 1 this inverts O = (I - F)/s + 1; the second branch handles\n",
    "    # fractional strides (s < 1, i.e. upsampling layers).\n",
    "    return s * (out - 1) + f if s > 0.5 else ((out + (f - 2)) // 2) + 1\n",
    "\n",
    "def compute_RF(layers):\n",
    "    # layers is a list of (kernel_size, stride) tuples, from first to last layer\n",
    "    out = 1\n",
    "    for f, s in reversed(layers):\n",
    "        out = compute_N(out, f, s)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we just pass a list of ($F$, $s$) tuples.\n",
    "For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "35"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layers = [(9,1),(3,1),(3,1),(3,1),(9,1),(3,1),(3,1),(7,1),(3,1)]\n",
    "compute_RF(layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For that network we find that the RF is 35.\n",
    "\n",
    "But what if the network is very complicated and does not have a structured architecture? Doing this analytically can be really tedious, and sometimes it is just not possible. It turns out there is another way to compute this value numerically.\n",
    "In particular, we can use a single output channel in the last layer, which we call $f=[f_1,\\dots,f_N]$. Now we define a dummy loss function $l$ and set its gradient with respect to $f$, $\\nabla_f l$, to be zero everywhere except at one particular location $j$, where for convenience we set it to 1:\n",
    "\\begin{equation}\n",
    "\\frac{\\partial l}{\\partial f_i}=\n",
    "\\begin{cases}\n",
    "    1, & i = j\\\\\\\\\n",
    "    0, & \\text{otherwise}\n",
    "\\end{cases}\n",
    "\\end{equation}\n",
    "\n",
    "If we perform backpropagation down to the input $x=[x_1,\\dots,x_N]$, which is equivalent to computing $\\nabla_x l$ by the chain rule, we find that $\\frac{\\partial l}{\\partial x_i}\\neq 0$ only if $x_i$ has some effect on $f_j$; locating those positions is therefore the same as finding the RF. To be more precise, we choose the position $j$ at the center of the image, we set the weights of the network to a positive constant (one in our case), and the biases to zero. This is because we use ReLUs as activation functions, which would set any negative value to zero. In addition, the max-pooling layers are changed to average pooling in order to avoid sparsity in the gradients."
   ]
  },
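  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before writing the general function, here is a minimal sketch of this trick on the two-layer 1D network from the figure (two convolutions with $F=3$, $s=1$, and 12 input positions). This cell is an illustrative addition; the non-zero input gradients should span exactly 5 positions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "\n",
    "# Two 1D convolutions with F=3, s=1, as in the figure above.\n",
    "net1d = nn.Sequential(nn.Conv1d(1, 1, kernel_size=3),\n",
    "                      nn.Conv1d(1, 1, kernel_size=3))\n",
    "for m in net1d:\n",
    "    m.weight.data.fill_(1)  # positive constant weights\n",
    "    m.bias.data.fill_(0)    # zero biases\n",
    "\n",
    "x = Variable(torch.ones(1, 1, 12), requires_grad=True)  # 12 input positions\n",
    "out = net1d(x)                                          # length 12 -> 10 -> 8\n",
    "\n",
    "grad = torch.zeros(out.size())\n",
    "grad[0, 0, out.size(2) // 2] = 1  # dummy gradient: 1 at the central output position\n",
    "out.backward(gradient=grad)\n",
    "\n",
    "idx = np.where(x.grad[0, 0].data.numpy() != 0)[0]\n",
    "print('RF:', idx.max() - idx.min() + 1)  # should print 5"
   ]
  },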
\n", 111 | "\n", 112 | "I will show an implementation using PyTorch.\n", 113 | "Fisrt I implement the same CNN that I defined through the layers list. The code is fairly simple:\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 7, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "analytical RF: 35\n", 126 | "numerical RF [35, 35]\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "import numpy as np\n", 132 | "import torch \n", 133 | "import torch.nn as nn\n", 134 | "from torch.autograd import Variable\n", 135 | "import torch.nn.init as init\n", 136 | "import torch.nn.functional as F\n", 137 | "\n", 138 | "def compute_RF_numerical(net,img_np):\n", 139 | " '''\n", 140 | " @param net: Pytorch network\n", 141 | " @param img_np: numpy array to use as input to the networks, it must be full of ones and with the correct\n", 142 | " shape.\n", 143 | " '''\n", 144 | " def weights_init(m):\n", 145 | " classname = m.__class__.__name__\n", 146 | " if classname.find('Conv') != -1:\n", 147 | " m.weight.data.fill_(1)\n", 148 | " m.bias.data.fill_(0)\n", 149 | " net.apply(weights_init)\n", 150 | " img_ = Variable(torch.from_numpy(img_np).float(),requires_grad=True)\n", 151 | " out_cnn=net(img_)\n", 152 | " out_shape=out_cnn.size()\n", 153 | " ndims=len(out_cnn.size())\n", 154 | " grad=torch.zeros(out_cnn.size())\n", 155 | " l_tmp=[]\n", 156 | " for i in xrange(ndims):\n", 157 | " if i==0 or i ==1:#batch or channel\n", 158 | " l_tmp.append(0)\n", 159 | " else:\n", 160 | " l_tmp.append(out_shape[i]/2)\n", 161 | " \n", 162 | " grad[tuple(l_tmp)]=1\n", 163 | " out_cnn.backward(gradient=grad)\n", 164 | " grad_np=img_.grad[0,0].data.numpy()\n", 165 | " idx_nonzeros=np.where(grad_np!=0)\n", 166 | " RF=[np.max(idx)-np.min(idx)+1 for idx in idx_nonzeros]\n", 167 | " \n", 168 | " return RF\n", 169 | "\n", 170 | "class CNN(nn.Module):\n", 171 | " def __init__(self,layer_list):\n", 172 | " #layers is a list of tuples [(f,s)]\n", 173 | " super(CNN, self).__init__()\n", 174 | " f_ini,s_ini=layer_list[0]\n", 175 | " f_end,s_end=layer_list[-1]\n", 176 | " self.layers=[]\n", 177 | " self.layers.append(nn.Conv2d(1, 16, kernel_size=f_ini, padding=1,stride=s_ini,dilation=1))\n", 178 | " for f,s in layer_list[1:-1]:\n", 179 | " self.layers.append(nn.Conv2d(16, 16, kernel_size=f, padding=1,stride=s,dilation=1))\n", 180 | " self.layers.append(nn.ReLU(inplace=True))\n", 181 | " self.layers.append(nn.Conv2d(16, 1, kernel_size=f_end, padding=1,stride=s_end,dilation=1))\n", 182 | " self.all_layers=nn.Sequential(*self.layers)\n", 183 | " \n", 184 | " \n", 185 | " def forward(self, x):\n", 186 | " out = self.all_layers(x)\n", 187 | " return out\n", 188 | "\n", 189 | "###########################################################\n", 190 | "print 'analytical RF:',compute_RF(layers)\n", 191 | "\n", 192 | "mycnn=CNN(layers)\n", 193 | "\n", 194 | "\n", 195 | "img_np=np.ones((1,1,100,100))\n", 196 | "print 'numerical RF',compute_RF_numerical(mycnn,img_np)" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "We can see that both methods find the same RF value.\n", 204 | "You just need to be careful when computing the RF by initializing the parameters, changing max pool layers by average pooling and switch off batchnorm and dropout. This method is general and will work even for very complicated networks. 
" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "collapsed": true 212 | }, 213 | "outputs": [], 214 | "source": [] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "Python 2", 220 | "language": "python", 221 | "name": "python2" 222 | }, 223 | "language_info": { 224 | "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 2 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython2", 233 | "version": "2.7.10" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 2 238 | } 239 | -------------------------------------------------------------------------------- /compute_RF.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | 8 | def compute_RF_numerical(net,img_np): 9 | ''' 10 | @param net: Pytorch network 11 | @param img_np: numpy array to use as input to the networks, it must be full of ones and with the correct 12 | shape. 13 | ''' 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv') != -1: 17 | m.weight.data.fill_(1) 18 | if m.bias: 19 | m.bias.data.fill_(0) 20 | net.apply(weights_init) 21 | img_ = Variable(torch.from_numpy(img_np).float(),requires_grad=True) 22 | out_cnn=net(img_) 23 | out_shape=out_cnn.size() 24 | ndims=len(out_cnn.size()) 25 | grad=torch.zeros(out_cnn.size()) 26 | l_tmp=[] 27 | for i in xrange(ndims): 28 | if i==0 or i ==1:#batch or channel 29 | l_tmp.append(0) 30 | else: 31 | l_tmp.append(out_shape[i]/2) 32 | print tuple(l_tmp) 33 | grad[tuple(l_tmp)]=1 34 | out_cnn.backward(gradient=grad) 35 | grad_np=img_.grad[0,0].data.numpy() 36 | idx_nonzeros=np.where(grad_np!=0) 37 | RF=[np.max(idx)-np.min(idx)+1 for idx in idx_nonzeros] 38 | 39 | return RF 40 | --------------------------------------------------------------------------------