├── .gitignore
├── 1_mlp.ipynb
├── 2_lenet.ipynb
├── 3_alexnet.ipynb
├── 4_vgg.ipynb
├── 5_resnet.ipynb
├── LICENSE
├── README.md
├── assets
│   ├── alexnet.png
│   ├── batch-norm.png
│   ├── cifar10.png
│   ├── filter-mnist.png
│   ├── filter-mnist.xml
│   ├── filtered-image.png
│   ├── lenet5.png
│   ├── lr-scheduler.png
│   ├── mlp-mnist.png
│   ├── mlp-mnist.xml
│   ├── multiple-channel-mnist.png
│   ├── multiple-channel-mnist.xml
│   ├── multiple-filter-mnist.png
│   ├── multiple-filter-mnist.xml
│   ├── relu.png
│   ├── resnet-blocks.png
│   ├── resnet-pad.png
│   ├── resnet-pad.xml
│   ├── resnet-skip.png
│   ├── resnet-table.png
│   ├── single-filter.png
│   ├── single-filter.xml
│   ├── single-pool.png
│   ├── single-pool.xml
│   ├── subsample-mnist.png
│   ├── subsample-mnist.xml
│   ├── subsampled-image.png
│   ├── vgg-resnet.png
│   ├── vgg-table.png
│   └── vgg.png
└── misc
    ├── 4 - VGG.ipynb
    ├── 5 - ResNet.ipynb
    ├── 6 - ResNet - Dogs vs Cats.ipynb
    ├── conv order.ipynb
    ├── download_dogs-vs-cats.sh
    └── process_dogs-vs-cats.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data/ 107 | .data/ 108 | models/ 109 | *.pt 110 | .vscode/ 111 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ben Trevett 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies 
or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Image Classification 2 | 3 | This repo contains tutorials covering image classification using [PyTorch](https://github.com/pytorch/pytorch) 1.7, [torchvision](https://github.com/pytorch/vision) 0.8, [matplotlib](https://matplotlib.org/) 3.3 and [scikit-learn](https://scikit-learn.org/stable/index.html) 0.24, with Python 3.8. 4 | 5 | We'll start by implementing a multilayer perceptron (MLP) and then move on to architectures using convolutional neural networks (CNNs). Specifically, we'll implement [LeNet](http://yann.lecun.com/exdb/lenet/), [AlexNet](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf), [VGG](https://arxiv.org/abs/1409.1556) and [ResNet](https://arxiv.org/abs/1512.03385). 6 | 7 | **If you find any mistakes or disagree with any of the explanations, please do not hesitate to [submit an issue](https://github.com/bentrevett/pytorch-image-classification/issues/new). I welcome any feedback, positive or negative!** 8 | 9 | ## Getting Started 10 | 11 | To install PyTorch, see installation instructions on the [PyTorch website](https://pytorch.org/). 
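The tutorials were written against specific versions (PyTorch 1.7, torchvision 0.8, matplotlib 3.3, scikit-learn 0.24). Once the packages are installed, it can help to check which versions you actually have before running the notebooks. The following is a minimal sketch that is not part of this repository. It uses the standard library's `importlib.metadata` (Python 3.8+). The helper names, and the choice to compare only the major.minor part of each version, are assumptions made for illustration. Newer releases will often still work, possibly with small changes.

```python
import re
from importlib import metadata  # standard library in Python 3.8+

# Versions the tutorials were written against, as listed in the README.
# Keys are pip distribution names (assumption: scikit-learn is installed as "scikit-learn").
TESTED_VERSIONS = {
    "torch": "1.7",
    "torchvision": "0.8",
    "matplotlib": "3.3",
    "scikit-learn": "0.24",
}


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def major_minor(version):
    """Extract (major, minor) from strings such as '1.7.1+cu110' or '1.10rc1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    return (int(match.group(1)), int(match.group(2))) if match else None


def report():
    """Print one line per package comparing the installed and tested versions."""
    for name, tested in TESTED_VERSIONS.items():
        found = installed_version(name)
        if found is None:
            print(f"{name}: not installed (tutorials used {tested})")
        elif major_minor(found) != major_minor(tested):
            print(f"{name}: {found} installed, tutorials used {tested}")
        else:
            print(f"{name}: {found} matches the tested version")


if __name__ == "__main__":
    report()
```

The check only compares the major and minor numbers. Patch releases and local build tags such as `+cu110` rarely affect tutorial code.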
12 | 13 | The instructions for installing PyTorch should also explain how to install torchvision, but it can also be installed via: 14 | 15 | ``` bash 16 | pip install torchvision 17 | ``` 18 | 19 | ## Tutorials 20 | 21 | * 1 - [Multilayer Perceptron](https://github.com/bentrevett/pytorch-image-classification/blob/master/1_mlp.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/1_mlp.ipynb) 22 | 23 | This tutorial provides an introduction to PyTorch and TorchVision. We'll learn how to load datasets, augment data, define a multilayer perceptron (MLP), train a model, view the outputs of our model, visualize the model's representations, and view the weights of the model. The experiments will be carried out on the MNIST dataset - a set of 28x28 grayscale images of handwritten digits. 24 | 25 | * 2 - [LeNet](https://github.com/bentrevett/pytorch-image-classification/blob/master/2_lenet.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/2_lenet.ipynb) 26 | 27 | In this tutorial we'll implement the classic [LeNet](http://yann.lecun.com/exdb/lenet/) architecture. We'll look into convolutional neural networks and how convolutional layers and subsampling (aka pooling) layers work. 28 | 29 | * 3 - [AlexNet](https://github.com/bentrevett/pytorch-image-classification/blob/master/3_alexnet.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/3_alexnet.ipynb) 30 | 31 | In this tutorial we will implement [AlexNet](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf), the convolutional neural network architecture that helped start the current interest in deep learning.
We will move on to the CIFAR10 dataset - 32x32 color images in ten classes. We show how to define architectures using `nn.Sequential`, how to initialize the parameters of your neural network, and how to use the learning rate finder to determine a good initial learning rate. 32 | 33 | * 4 - [VGG](https://github.com/bentrevett/pytorch-image-classification/blob/master/4_vgg.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/4_vgg.ipynb) 34 | 35 | This tutorial will cover implementing the [VGG](https://arxiv.org/abs/1409.1556) model. However, instead of training the model from scratch, we will load a VGG model pre-trained on the [ImageNet](http://www.image-net.org/challenges/LSVRC/) dataset and show how to perform transfer learning to adapt its weights to the CIFAR10 dataset using a technique called discriminative fine-tuning. We'll also explain how adaptive pooling layers and batch normalization work. 36 | 37 | * 5 - [ResNet](https://github.com/bentrevett/pytorch-image-classification/blob/master/5_resnet.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-image-classification/blob/master/5_resnet.ipynb) 38 | 39 | In this tutorial we will be implementing the [ResNet](https://arxiv.org/abs/1512.03385) model. We'll show how to load your own dataset, using the [CUB200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset as an example. We'll also show how to use learning rate schedulers, which dynamically alter the learning rate of your model whilst training. Specifically, we'll use the one cycle policy introduced in [this](https://arxiv.org/abs/1803.09820) paper, which is becoming increasingly common for training computer vision models. 40 | 41 | ## References 42 | 43 | Here are some things I looked at while making these tutorials.
Some of it may be out of date. 44 | 45 | - https://github.com/pytorch/tutorials 46 | - https://github.com/pytorch/examples 47 | - https://colah.github.io/posts/2014-10-Visualizing-MNIST/ 48 | - https://distill.pub/2016/misread-tsne/ 49 | - https://towardsdatascience.com/visualising-high-dimensional-datasets-using-pca-and-t-sne-in-python-8ef87e7915b 50 | - https://github.com/activatedgeek/LeNet-5 51 | - https://github.com/ChawDoe/LeNet5-MNIST-PyTorch 52 | - https://github.com/kuangliu/pytorch-cifar 53 | - https://github.com/akamaster/pytorch_resnet_cifar10 54 | - https://sgugger.github.io/the-1cycle-policy.html 55 | -------------------------------------------------------------------------------- /assets/alexnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/alexnet.png -------------------------------------------------------------------------------- /assets/batch-norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/batch-norm.png -------------------------------------------------------------------------------- /assets/cifar10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/cifar10.png -------------------------------------------------------------------------------- /assets/filter-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/filter-mnist.png -------------------------------------------------------------------------------- 
/assets/filter-mnist.xml: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /assets/filtered-image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/filtered-image.png -------------------------------------------------------------------------------- /assets/lenet5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/lenet5.png -------------------------------------------------------------------------------- /assets/lr-scheduler.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/lr-scheduler.png -------------------------------------------------------------------------------- /assets/mlp-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/mlp-mnist.png -------------------------------------------------------------------------------- /assets/mlp-mnist.xml: -------------------------------------------------------------------------------- 1 | 
7VvbcqM4EP0aP8bFVciPuTmzWzO1U5WHzTxtySAbbbDEgIjt/fqVQLLNJQ7Z2EbObFViUINAnO7TanXDyL1drh8ylMbfWISTkWNF65F7N3KcADriVwo2lcDzQCVYZCSqRPZO8Ej+wUpoKWlBIpzXTuSMJZykdWHIKMUhr8lQlrFV/bQ5S+p3TdECtwSPIUra0j9JxONKCp1gJ/+CySLWd7bBpDqyRPpk9SR5jCK22hO59yP3NmOMV3vL9S1OJHYal6rf9JWj24FlmPI+HZ7I5GH62xf+u/vj2/V89YczvQdX0FWD4xv9xDgSAKgmZVRsbjJW0AjL61iixTIeswWjKPnKWCqEthD+jTnfKPWhgjMhivkyUUfxmvCnvf0f8lJjX7Xu1urKZWOjG5Rnm6f9xl4v2dx1K1u7ftG1VLxohgnKcxJWwilJ9HDa0Ck0c1ZkIT6Al6uUyVG2wPzQieqKEs29WyjVPGC2xGLQ4oQMJ4iTl7q5IWW1i+15O8WKHaXbd+hZj/sFJYW6VUvxdS2vYsLxY4pKOFaC23WNzgWatyxhWdnXjRCGcwl0zjP2jPeOgBDi2fwQ6C8443h9ECR11NVcUs7EsVR7taOmp0TxHis9fd7RcfXbMF46fz5ADu2g3ySHYxY5tCGeihw+hpHXRQ7ozFwAjkMOz2uQw+tHji2pjk8OcyeXDxi509fIXbOM3PkMRu6DhpHDfka+FR4fV/e0uM5hiMPOmXUGfc8/aMz9cQWwjuvWKbyBqw1OhasHB3Me9kHX8V/n1QjlcTlSux6kqifZRagf9E9eX//km+WfvLd5JC4j1n34bQ6hPK0Wg3OylpgfJfh0G67HblPE6XI9J/M8vuGISe9UQ8wfGDFvYmxAYrpXARfqVYDhHGl5lY6o/awc0Rz9nyPv5khwoRwJLo0jwdDziN8C6DSEuGQ6eL2zQUeng+r6nREx6F04YjXWOG7DPqqRql4NE9kO4wNW0yPNZBTP/KEjXF0rGmIuumjq9c5RmTUT6XFfDkOGjta8wFCGnMvSj55kfWXuaC5lm3mvU88dPbKMZjFj6BjNb6/8AujJXjQteN5CTzwnr8NUT7gqw93PzioRSsiCimYo0MNCfiNRIyFKrtWBJYmi5DW91Lk5Z5Rr8nZU+D9cMQVuWy2d+fKTqaW92CjTQhbFRcbor6KXZrFueL3All5s69fTS7O+NLxeJh16EW1W8F/IjTXLU4OrRZuJsfPxdv7VActk4PkYmL76bSLWVQI9L2Kmr4ZaiA1d3wGmR8lNxLwOP3ZexEyvurYQA0MjZnrVtYXY4J6/HcKYhRiYNPzY0KwMTI8uWogNzcrAdM/fRMwbOroITPf8LcSGji4C0z1/M0npDZ2DC0x/+6KJmN/xQcF5ETO9Ft9CbPC50vToolU66Hhl+ayIQdOjiyZi7tB1aPj53hs86VdysPebUEpjhtSfYQ/vm8colbtkWX45elNurzULrC5K6ETqVzTDyXeWE06YTKjOGOds2ZFp5dJgbtQd7mLO5ceu1/LZnClec5wJq7oKGeVCM+OoCJ/l/4KNQ3E1Z0qKEoVpse0rvLTo7cv+4sf9iX+mmQNC9ExsENE4sif5C5htwMvVKsV0QSgeU8wjiq7yPKmuKvqtUn3PslmkCUNRXu47lj0pdyyn3ORomSb4r/IJhJKttTOB45QuRg4QR2SCuLQx3aJsPnvL6N6xAG5mPicd3sMd61dD6x7EOZVpddRX1g4Uf7JnqemmpX3OVL7t2mMAaxo6YTZfNHdfWFfF+N1n6u79vw== -------------------------------------------------------------------------------- /assets/multiple-channel-mnist.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/multiple-channel-mnist.png -------------------------------------------------------------------------------- /assets/multiple-channel-mnist.xml: -------------------------------------------------------------------------------- 1 | 5Z1db1vXEUV/jR8D3G+e+xi7ifOQAEadoshTQYu0xIQWBYqO7f76XkW8+piRgxaJufdKn2xRFCkt8nBmkWfOfta+ePfx5X55dfHDbrXePmuq1cdn7d+eNU3dVu30z80ln24vWXTHC873m9XxSvcXvN78e328sDpe+n6zWl8/uuJht9seNlePLzzbXV6uzw6PLlvu97sPj6/2drd9fK9Xy/N1uuD12XKbL/3nZnW4uL20NIv7y79bb84v5nuuh/H2O++W85WPf8n1xXK1+/DgovabZ+2L/W53uP3fu48v1tsbeDOX25/79jPfvfvF9uvLw3/zA69+Gq/G/t2nl3///sfN19Xw9sXPP31VV+3xhn5dbt8f/+b+hufFcgL6289fXr0/3NzT5uN6e338Ww6fZkCH9cfp7p9fHN5tpwvq6b/Xh/3ul/WL3Xa3ny653F1O13z+drPdhouW28355fTl2fQHrKfLn/+63h82E/qvj994t1mtbu7m+YeLzWH9+mp5dnOfH6Yn2nTZfvf+crW++duq6avV8vrity9ufoG3u8vD8ZlUl+nrTGr+m6c7XH98cNGR3Mv17t36sP80XeX+iXz7I/PT+Pjlh/vnRD0/ZS8ePB/mB395fBqe393y/SM1/ef4YD39wP3r1e7Hf7xbvbzevP72+fqb8+bV9XdftfO9PXjc0mPzmNFTFB88bL+L8OHD96xpV/26rLr0WE/fKc2bdhjuoCfCTzwOn4XeLh5Dr/tMvTkt9LxYviD0L4HwiSfuiRE2dISjHGELR3hX3XUIOzrCVo6wpyPUl5OBjlBfThZ0hPpyUmgIR7u+cKQjlC/k+Za5COULucHZSUCo7wtnYlyE8r6wwdlJRCgvJw3OTiJCfTnB2UlEqC8nNDvp6lCR9QuZZicJoX4h0+wkIdQvZJqdRIT6vnDuqrgI5X3hEx80whDKy0lLs5OEUF5OWpqdJITyctLS7KQLH97r+8KWZicJoX4h4+wkItQvZJydBIQGfSHOTiJCfV+Is5OIUF5OOpydRITyctLh7CQilJeTDmcnfajI+oWMs5OIUL+QcXYSEeoXMs5OAkJ9X9jh7CQilPeFHc5OIkJ9OcHZSUSoLyc4O4kI5eVk/oU4CO3mTnqcndjNnfQ4O7GbO+lxdmI3d9Lj7MRu7qTH2Ynd3EmPsxO7uZMeZyd2cyc9zk7s5k56nJ3YzZ0MODuxmzsZcHZiN3cy4OzEbu5kwNmJ3dzJgLMTu7mTAWcndnMnA81Oeru5k4FmJwmhfiHT7CQh1C9kmp1EhPq+cF4GXITyvnBBs5OEUF5OFjQ7SQjl5WRBs5OEUF5OFrm1fr25PJ/oNdXbzfbmDMmm2r0//HZcZUCLOaXyD86QD48esyee9WN30ofspH1oOKTybTlbn52lh3r6zpvS32yT+FOYl9B1NfJlctLG1YJ5J2dO63QjwUFNsNAa3UiwyAnS+txAsKnkBGltbiQorz2F1uVGgvJKUmhvwUeC+kpy0nfgJf1PXYXio1/4J33PXnIafYKuf62gvcufEOpfLGjv8ieE+r6T5j4Rob7xnN9a4yKUF6CRZj8JobycjDT9SQjl5WSk+c9d52jTSY40AUoI9QuZtgcpIdQvZNoepIRQ3heOODsJCA36QpydR
IT6coKzk4hQXU66eRlwEarLSVfh7KQL5US9kLsKZycRoX4h4+wkItQvZJydRITqvrCrcHYSEMr7wq7C2UlEqC8nODuJCPXlBGcnEaG+nODsZHDrC0+bg/4lEMoXMi7VPCGUL2RcqnlCKO8LcanmEaG+L8SlmieE+nKCs5OIUF9OcHYSEerLCc5O3MZgOlyqeUKoX8g4O3Eba+nmW+YilPeFDc5O3AZbupkYF6G8nOBSzRNCeTnBpZonhPpyQrOTxm1SpcOlmieE+oVMs5OEUL+QaXaSEOr7QpqdRIT6vhCXap4QyssJLtU8IZSXE1yqeUIoLye4VPM7FjZ9IS7VPCHUL2ScnbjNnXS4VPOEUN4X4lLNI0KDvhBnJ25zJx0u1TwhlJcTXKp5QigvJ7hU88Zu7gSXap4Q6hcyzk7s5k5wqeYJobwvxKWaR4T6vhCXap4Q6ssJzk7s5k5wqeYJob6c0OwkfhIvrya4UPN0WJIeIU1O0rkCeoQ0OUkjeHqENDlJu9X1CGlykjZ26RHi5MSvnODkxK+c4OVEj/Cvn74S86fVCW/daVPQLZirI+E6XGy6W/B8h0tNd8ud73Ch6W6x8x0uM90tdb7DRaa7hc53uMR0t8z57rSB6ZrEmxhurV/4ymjLEyXeuOXad7hQ9oRQ/2JB++gnIZT3nbhQ9ohQ33jiQtkTQnkBwoWyJ4TycoILZU8I5eVkQfOfu87RppN8ItcehlC/kGmf/SSE+oVM++wnIdT3hTg7ae36QpydRITycjJ/BMpFKC8nBWcnEaG8nBScnfShnOgXMs5OIkL9QsbZSUSoX8g4O4kI5X1hwdlJQKjvC3lR9xGhvpzg7CQi1JcTnJ1EhPJywou6X4RyIl/IvKj7iFC+kHlR9xGhfiHj7CQilPeFvKj7gFDfF/Ki7iNCfTnB2UlEqC8nODuJCPXlBGcndmMwvKh7t6mWnhd17zbW0vOi7t3mWnpe1L3bYEvPi7p3m2zpeVH3bqMtPS/q3m22pcdF3Tdukyo9Luo+IdQvZJqdJIT6hUyzk4RQ3hfiou4jQn1fiIu6Twjl5QQXdZ8QyssJLuo+IZSXE1zUfeM2d9Ljou4TQv1CxtmJ29xJj4u6Twj1fSHOTtzmTnpc1H1CKC8n8y1zEcrLSYOzE7e5k34mxkHoNnfS46LuE0L9QsbZidvcSc+LunebO+l5Ufducyc9L+rebe6k50Xdu82d9Lyoe7e5kx4XdR8/iZdXE1zSfTosSY+QJifpXAE9QpqcpBE8PUKanKTd6nqENDlJG7v0CHFy4ldOcHLiV07wcqJHeFI50Zw+HsKu5JHa/fwL/D9BV6em9h3OgCJCdTxW3+EMKObcqUO4+w5nQBFhI0eIM6CIUF+AcAYUEerLCc2A7CKk+04ZwnKiNqgJNVy/9JURlydKvonQ9S8WtI+AEkL5i8VszliE+t6zxxlQRCjvPXucAUWE8gLU4wwoIpSXkx5nQBGhvpzgDCiEpes7yR5nQBGhfiHTPgNKCPULmfYZUERo0Bfi7CQilPeF8ysJF6G8nAw4O4kI5eVkwNlJRCgvJwPOToZQkfULGWcnEaF+IePsJCLUL2ScnQSE+r5wwNlJRKjvC3F2EhHqywnOTiJCeTnhRd5HhPJywou8L6EiyxcyL/I+ItQvZJydRIT6hYyzk4BQ3xfyIu8jQnlfyIu8jwj15QRnJxGhvpzg7CQi1JcTmp00dtMwuMj7hFC+kHGR9wmhfCHjIu8jQn1fiIu8TwjlfSEu8j4h1JcTmp0khPpyQrOThFBfTmh2csfCpy+k2UlCqF/IODuxmzvBRd5HhPq+EBd5nxDK+0Jc5H1CKC8nuMj7hFBeTnCR9wmhvpzg7MRu7gQXeZ8Q6hcyzk7s5k5wkfcRoUFfiLMTt7mTARd5nxCqy8mAi7xPCNXlZMBF3ieE6nIy4CLvG7e5kwEXeZ8Q6hcyzk7c5k4GXuS929zJwIu8d5s7GXiR925zJ0OFsxO3u
ZOBF3nvNncy4CLv4zGTrbyc4CLv42lJBghpdhIPFjBASLOTOINngJBmJ3G7ugFCnJ34lROcnfiVE5yd+JUTup3oEc63/Fc+fjykXsmztYfmpD7jAV0dnzrMjDEvFgmhOidraHAGFAPv1GncQ4MzoIhQncY9NDgDigj1BYhmQAmhvpzQDMguS3polCksJ2qDWrve86TOpIm+idDlLxbzL8R5sYgI5S8WLe4zoJgULO89W5wBRYTy3rPFGVBEKC9ALc6AIkJ9OcEZkFtq+tDiDCikpus7yRZnQBGhfiHTPgNKCPULmfYZUESo7wvnz/W4COV9YYezk4hQXk46nJ1EhPJy0uHsJCKUl5MOZyeLUJH1CxlnJxGhfiHj7CQi1C9knJ0EhAZ9Ic5OIkJ9X4izk4hQXk54mfcRobyc8DLvI0J5OeFl3o+hIusXMs5OIkL9QsbZSUSoX8g4OwkI9X0hL/M+IpT3hbzM+4hQX05wdhIR6ssJzk4iQnk5wWXeN3bTMLjM+4RQvpBxmfcJoX4h0+wkItT3hbjM+4RQ3hfiMu8TQn05odlJQqgvJzQ7SQj15YRmJ43dpAou8z4hlC9kXOZ9QihfyLjM+4hQ3xfiMu8TQnlfiMu8Twjl5QSXeZ8Q6ssJzk7s5k5wmfeN3dwJLvM+IdQvZJyd2M2d8DLv7eZOeJn3dnMnvMx7u7kTXua93dwJL/Pebu6El3lvN3fCy7y3mzvhZd7bzZ3wMu/t5k54mfd2cye8zHu7uRNe5r3d3Akv895u7gSXeR+PmWzl5QSXeR9PSzJASLOTeLCAAUKancQZPAOENDuJ29UNEOLsxK+c4OzEr5zg7MStnCx4mfdu5aTg9lrbnWtdcHut7c61Lri91nbnWhfcXmu7c60Lbq+13bnWBbfX2u5c64Lba213onDB7bW2O1G44PZa250oXHB7re1OFC64vdZ2JwoX3F5ruxOFC26vtd1ZrgW319ruLNeC22ttd5Zrwe21tjvLteD2Wtud5Vpwe63tznItuL3WdqdoFtxea7tTNAtur7XdKZoFt9fa7hTNgttrbXeKZsHttbY7RbPw9lq7nV9YeHut3c4vLLy91m7nFxbcXmu78wsLbq+13fmFBbfX2u78woLba213clzB7bW2Ozmu8PZa2+3swu+1NqjIODtx29k18vZau+3sGiucnbjt7BornJ247ewaK5yduO3sGiucnbjt7BornJ247ewaK5yduO3sGiuanaTt6vpyQrOTtL9Qj5BmJ3bJtOPcEnAQuoW4jTXOTtzyTsYaZyd25aTG24ke4Unt5O1mu32x2+72v91wu+rXZdVNl18f9rtf1g++U5o37TD8Sa+ezWPo9SCHTvOZhLDIEdJ8JiJsKjlCnM9EhI0cIc5nIsJOjhDnMxGhvJzMt8xFKC8nDc1n6s6tqZmJcRHqn4U0n4kI9U1NQ/u0JSGUNzUNzk4iQnlT0+DsJCLUlxOcnUSE+nKCs5PBrqnB2UlEKH8Wzu8aYxHqm5oWZycRobypaXF2EhHKm5oWZycRobyctDg7iQj15QRnJ8WtqWlxdhIR6p+FODspdk0Nzk4iQn1Tg7OTiFDe1My/ABehvJx0ODuJCOXlpKPZyV35sGlqOpqdJIT6ZyHNTiJCfVPT0ewkIZQ3NR3NThJCfVNDs5OEUF9OaHaSEOrLCc1O7ljYNDXzdnEuQvmzsKfZSUSob2p6nJ3YbRTucXZit1G4x9mJ3UbhHmcndhuFe5yd2G0U7nF2YrdRuMfZid1G4R5nJ3YbhedXEi5CeVMz4OzEbqPwgLMTu43CA85O7DYKDzg7sdsojEuXjwj1TQ0uXT4h1Dc1ODux2yiMS5dPCPXlBGcndhuFeeny4c2uVl5OeOnynR1Cmp3EHf8GCGl2EvcXGiDE2UllhxBnJ37lBGcnfuWEbicGCE9qJ5oTD93CkUZeHr1bONLIy6N3C0caeXn0buFIIy+P3i4ciZdHbxeOxMujtzs7m5dHb
5cvxcujt8uX4uXR2+VL8fLo7fKlcHn0CaG8qcHl0SeE8nKCy6P3C0fC5dHHcCR9U4PLo08I9c9CnJ3EiC55U4PLo08I5U0NLo8+IdQ3NTg7iQj15QRnJ3Zxhbg8+hhXKG9q6goXSJ8Yqp+GE0Ocn8TYTHVbMzHECUpkqO5rJoY4Q4kM1Y3NxBCnKJGhQU3BOYpbjPDEkCYpMUfYobehWUpiaPA8pGlKyrM26G1onpIY6nsbXDR9YqjvbfjZ9Pqawgunjwz1NYWXTu+2bXhiSPOUxNDgeUjzlMjQoLfBxc0nhga9Dc5T3LYOTwxxnuK2d3hiiPMUt83DdTXfNIeh2+7hiSHOU9y2D08McZ7itn94YojzFLcNxBNDnKe47SCeGOI8xW0L8cQQ5ylue4gnhjhPcdtEPDHEeYrbLuKJIc5T3LYR1xUufD4x1Pc2uPT5xFDf2+Di5xNDfU3B5c8nhvqawgugD+99tQY1heYpcU7PgSHNU+JYgANDmqfE/YcODHGeUvsxxHmKX03BxdDH920cGNI9xYHhST0lHI74tpytz86ePXE44pvSd331J1Wh5jF1eXbIRP2kZuNBXX0+/EQd50KBoTxvZGKIc6HIsNEzxLlQZKhOHJkY4lwoMjSoQjQXSgwNagrNhexCR+rqtKH3mm7ILZR3on5Se/Kgrn+96E/qW1+AoUEP2tM+F0oM9T1oj3MhtyjfiSHOhdyyfCeGOBdyC/OdGOJcyC1mZ2KIcyG3ROSJIc6F3CKR62r+jbAMDfrDgfa5UGKo7w8HnKe4pSJPDHGe4haLPDHEeYpbLvLEEOcpbhmCE0Ocp4RMWoP+cMB5SmRosJZxnlL8+kOcp0SG+v5w3snEZajvDxc4T4kM9TVlgfOUyFBfUxY4TwkMDfrDBc1T7uqwT3+4oHlKYmiwlmmeEhka9IcLmqckhgb9Ic1TEkOD/pDmKYmhvqbMr8hchvqaUmieEhka9IeF5il3MHz6w0LzlMTQYC3jPMVvBqbgPMVvBqbgPMVvBqbgPMVvBqbgPMVvBqbgPMVvBmY+KYPD0G+iZcR5it98yojzFL/5lBHnKX7zKSPOU/zmU0acp/jNp4w4T/GbTxlxnuI3nzLiPMVvPmXEeYrdfEo9rwQsQ31/WFc4T7GbT6krnKfYzafUFc5T7OZT6grnKXbzKXWF8xS7+ZS6wnmK3XxKXeE8xW4+pa5wnmI3n1LP+WwPkK1X5+vXxy93+8PF7nx3udx+c3/p88dQ76/z/W53daT38/pw+HTEt3x/2D0GfXufN3f0v5Lcr7fLw+bXxz/3FJfjj77abaZbvHsE2pjk3Qa017v3+7P18afu6eYbiqUt3tBhuT9fH9INfb3fLz89uNrVzRWuf+cX/kwc3Wd/r9+//vSf29/g/jlzB/cPLUWc7tqNOdXzU5PLUK8ZNU537cacpppAZ6hvTWqc7tqNOdU1TXfTcacGNYWmu+m4LgOGNN1NR1oYMKTpbhr7NGCI0904GmHAEOcpfjVlvmkOQ7+a0uA8xa+mzMg4DE9XU6Yv97vd4eHbPNOfe/HDbrW+ucZ/AA== -------------------------------------------------------------------------------- /assets/multiple-filter-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/multiple-filter-mnist.png -------------------------------------------------------------------------------- 
/assets/multiple-filter-mnist.xml: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /assets/relu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/relu.png -------------------------------------------------------------------------------- /assets/resnet-blocks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/resnet-blocks.png -------------------------------------------------------------------------------- /assets/resnet-pad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/resnet-pad.png -------------------------------------------------------------------------------- /assets/resnet-pad.xml: -------------------------------------------------------------------------------- 1 | 
3Z1Rb5swEMc/TR4nBWwDfVyzrtO0SZsyadrTRMEFNhJHjtMk+/Qjw4RgmmmdhM76PxUOY+CX6/n+Pa6ZscXqcK/TTflR5bKehfP8MGNvZmEYBOFN8+NkObaWaM5bQ6Gr3A7qDcvql7TGubXuqlxuBwONUrWpNkNjptZrmZmBLdVa7YfDHlU9vOomLeTIsMzSemz9WuWmbK1JGPf2d7Iqyu7KQWQfeJV2g+2TbMs0V/sLE7ubsYVWyrRbq8NC1id4HZf2vLdXjp5vTMu1+ZcTPqvoexHmO1EfVl+C9w/3evftlZ3lKa139oHtzZpjR0Cr3TqXp0nmM3a7Lysjl5s0Ox3dN595YyvNqm72gmbzsarrhaqV/nMuy4VMct7Yt0arn/LiSBI+sChqjtgbkNrIw9UnC868GkeTaiWNPjZD7AncEj4Od/f959WZyouPqrOl1kOK87w9xGbDcnwB03Bapo9JJrPsOaYPieBiPgnThJgpA/TTICSGygEdNYiIoQpATw3nxFAjQE8lX6diRE+lXqgSQE9l1AvVDQLUxK8stRNYUFCpf/uDifUUCVTyPDWAUFQuVepENZhYU5FQJc9UAwhR5VIlX6wmVlU0VMlXKwhZ5VAlT1YDCF11XvR9SVchhJVLlTwCTKysaHyVPGHtro/lrOQZazixuqLxVvKUtaOI5a3kK1YIUbMaYaVeskIIgeViJc9aQwiFdV6iPMlaQwiF5VIlDwETKywaqvRZK4bEcrGSZ60Q1SsXK3nWyjA0louVesViEBWsEVbqJYthaCwHK3nWyiA01jmWepK1MgiJ5VIlDwEQbwa6VMmzVgYhsUZYqbNWBlHFcrHSZ60QGmuElXzFgqhjjbBSL1ndxFjeSp61cgiN5VvbFYeQWL41XnGItwT9a73CkFi+NV9xiCqWd+1XHENj+daAxSHqWN61YHEMjeVbExbH0FietWEJDInlWR+WgHhT0LtGLIEhsXzrxBIYVSzfWrEEhsbyrRdLYNSxfGvGEhgay7duLAGhsZhn3VgCQmK5VMlDAMSbgi5V8qw1CUYUZV7Ipd1V2pSqUOu0vuutt0PO/ZgPSm0s3R/SmKP9j4Xpzqgh+/aapwu9lKOWdWqqp+F5z1Gxp35SVTNjz/9aHtZNsVU7nUl7Vs92NBG/9spnN5FJdSHNaKLXWqfHi2Gb04Dt9RvmVxqirt7X38c3G+0d9B5zhvv/TtSxAPvdpJY+EUQp1MVKLn0iCKE+wkqd9kQQxdARVuq8J4IQ6i5WcukTQQh17tsfQaOJhToRVvJMAOKFUxcrfSYwsVInwkqeCUAUQ4VvUj2eWGURYaWOrTFEOdTFSh5b44lVFhFW6tgaQ5RDhW/t/fHEKosIK3lshVBZLlb62AqhskZYyWMrhMqKfGtCjSFU1ggreWyFUFkuVvLY2olpMKzTxdZmt/+etrai2H/bHbv7DQ== -------------------------------------------------------------------------------- /assets/resnet-skip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/resnet-skip.png -------------------------------------------------------------------------------- /assets/resnet-table.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/resnet-table.png -------------------------------------------------------------------------------- /assets/single-filter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/single-filter.png -------------------------------------------------------------------------------- /assets/single-filter.xml: -------------------------------------------------------------------------------- 1 | 3VjBcpswFPwaXzqTGSRhjI+2k7SH9NJkUufUkZEMagWiQsS4X18JhEHYTpNOWju52Hr7JKG3uwwPRmiRVh8lzpPPglA+gh6pRuhyBOEUevrXANsG8INJA8SSkQYCHXDLflEL2nVxyQgtnIlKCK5Y7oKRyDIaKQfDUoqNO20tuHvVHMd0D7iNMN9HvzKikgYN4aTDP1EWJ+2VQTBtMiluJ9tKigQTselB6GqEFlII1YzSakG54a7lpVl3fSS7O5ikmXrOgvJ+eXd1f4dnwbK8Cb/cr/MKXNhdHjEvbcGbuppvNlGobcuEFGVGqNnMG6H5JmGK3uY4MtmNll5jiUq5joAeElwk9VwTrBnnC8GFrDdC6zCiUaTxQknxg/Yyq3Dsj83ua5Ep6wUQ6NiekkpFq6Plgx2p2oxUpFTJrZ5iF/hWBuvDcSvLpqeqhZKeoC2GrY/i3cYd1Xpg2X4B8/AY8/C9MQ/gmVGPjlGP3hv1A9MH3omZ948x77835oemPzn14z2GKdFPOhtmItN/c5d0IVUiYpFhfiNEbtn9TpXaWrJwqYQrBK2YWtrlZvzQG19W/WDbBpmubtkPHvpBt6iOulVkZp7u3dE1cs0MJXW+KdZU+LR+mhBRyoj++RGpsIypemLe9LAfJOVYsUf3HK+ubvBf1QUvVxecqbrwLag7Odm9C/7i3gXnoy56C+qGJ7t3n6kuOFN1/beg7nSv5ZG0KLn6dw0PGdOQ+IcanhCuUBAMGp7wdRqeXefSdvnBiRue9gvCAea9OtT5QEdovpKOEsHP0ryT1yRdFDVLMz0BhHlVU9Xm9Sg2/zmrKDcvzN4HU+7u/VnvPzcU1mnopuEgjdw0GqR9N63DAKfGEtmqyOujePvQwF9aWuWayPWHvVv7ZrIQ5izOdBhpK1CNz41RWIT5zCZSRoi5zEHXur52jfv6NvS9gQ2n+zYEkwM+hC/3oQ67zzd1rvcNDF39Bg== -------------------------------------------------------------------------------- /assets/single-pool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/single-pool.png 
-------------------------------------------------------------------------------- /assets/single-pool.xml: -------------------------------------------------------------------------------- 1 | 3VjRkpowFP0aZ9qHzghBxEd13e3D9qXuWPepE8kV0gbChLBiv74JBAHR7dqxdfVFc89NQnLOyeRCD02j/EHgJPzCCbCe3Sd5D931bHvo2epXA9sSQKNhCQSCkhKyamBOf4EB+wbNKIG01VFyziRN2qDP4xh82cKwEHzT7rbmrP3UBAfQAeY+Zl30GyUyLFHPHtb4Z6BBWD3ZckdlJsJVZ7OTNMSEbxoQmvXQVHAuy1aUT4Fp7ipeynH3R7K7hQmI5VsGZIvl02zxhMfuMnv0vi7WSW59MrO8YJaZDSc0B/bd4KncVkQInsUE9Fz9HppsQiphnmBfZzdKeYWFMmIqslST4DQs+upgTRmbcsZFMRFaez74vsJTKfhPaGRW3sAZ6NnXPJbGCpar4u5Wq3WDkJA3ILP1B+ARSLFVXUzWMSoYGw4qVTYNUQ0UNvSsMGxsFOwmrplWDUP2CcTbR4i3b414y35nzKMjzKNbY37P8m7/wsQ7R4h3bo34fctfnPlBh2Eg6pYzYcxj9Tdpk86FDHnAY8weOU8Muz9Ayq0hC2eSt4WAnMqlGa7bz432Xd4MtlUQq90tm8FzM6gHFVE9ioz1zV4vXSH3VFNS5I8qmPJM+PDn21BiEYB8pZ+54jWJr/pBAMOSvrRLibOr6/5Xda3T1bXeh7r2Vao7vNjZtf7i7FoXUxddpbrexc7uG9W13oe6zlWqO+pUPALSjMl/V/CQAXjEOVTwePYKue5eweOdp+DZVS5Vje9euOCpvh4cYL5fhCrvqghNVkK1Alnw4OJIUx2vUv0X4fzD7mV4up8tpzgbtHv5m3YXtns9OZpzPnYcpZSTbdu0HWHOZ9M+BsKMBrEKfeUHUPhE+4D6mI1NIqKE6Mcc9GnbyW2rnt94Tn/PeKOu8azhAefZpztPhfXHmiLX+OKFZr8B -------------------------------------------------------------------------------- /assets/subsample-mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/subsample-mnist.png -------------------------------------------------------------------------------- /assets/subsample-mnist.xml: -------------------------------------------------------------------------------- 1 |  -------------------------------------------------------------------------------- /assets/subsampled-image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/subsampled-image.png 
-------------------------------------------------------------------------------- /assets/vgg-resnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/vgg-resnet.png -------------------------------------------------------------------------------- /assets/vgg-table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/vgg-table.png -------------------------------------------------------------------------------- /assets/vgg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bentrevett/pytorch-image-classification/33c9cfdb9097bc65b5800c1996f199f0ba520888/assets/vgg.png -------------------------------------------------------------------------------- /misc/4 - VGG.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.optim as optim\n", 12 | "\n", 13 | "import torchvision.transforms as transforms\n", 14 | "import torchvision.datasets as datasets\n", 15 | "\n", 16 | "import random\n", 17 | "import time\n", 18 | "\n", 19 | "import numpy as np" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "SEED = 1234\n", 29 | "\n", 30 | "random.seed(SEED)\n", 31 | "np.random.seed(SEED)\n", 32 | "torch.manual_seed(SEED)\n", 33 | "torch.cuda.manual_seed(SEED)\n", 34 | "torch.backends.cudnn.deterministic = True" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | 
"metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "Files already downloaded and verified\n", 47 | "Calculated means: [0.49139968 0.48215841 0.44653091]\n", 48 | "Calculated stds: [0.24703223 0.24348513 0.26158784]\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "train_data = datasets.CIFAR10(root = 'data', \n", 54 | " train = True, \n", 55 | " download = True)\n", 56 | "\n", 57 | "means = train_data.data.mean(axis = (0,1,2)) / 255\n", 58 | "stds = train_data.data.std(axis = (0,1,2)) / 255\n", 59 | "\n", 60 | "print(f'Calculated means: {means}')\n", 61 | "print(f'Calculated stds: {stds}')" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "train_transforms = transforms.Compose([\n", 71 | " transforms.RandomHorizontalFlip(),\n", 72 | " transforms.RandomRotation(10),\n", 73 | " transforms.RandomCrop(32, padding = 3),\n", 74 | " transforms.ToTensor(),\n", 75 | " transforms.Normalize(mean = means, \n", 76 | " std = stds)\n", 77 | " ])\n", 78 | "\n", 79 | "test_transforms = transforms.Compose([\n", 80 | " transforms.ToTensor(),\n", 81 | " transforms.Normalize(mean = means, \n", 82 | " std = stds)\n", 83 | " ])" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "Files already downloaded and verified\n", 96 | "Files already downloaded and verified\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "train_data = datasets.CIFAR10('data', \n", 102 | " train = True, \n", 103 | " download = True, \n", 104 | " transform = train_transforms)\n", 105 | "\n", 106 | "test_data = datasets.CIFAR10('data', \n", 107 | " train = False, \n", 108 | " download = True, \n", 109 | " transform = test_transforms)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 6, 115 | 
"metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "n_train_examples = int(len(train_data)*0.9)\n", 119 | "n_valid_examples = len(train_data) - n_train_examples\n", 120 | "\n", 121 | "train_data, valid_data = torch.utils.data.random_split(train_data, \n", 122 | " [n_train_examples, n_valid_examples])" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 7, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "Number of training examples: 45000\n", 135 | "Number of validation examples: 5000\n", 136 | "Number of testing examples: 10000\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "print(f'Number of training examples: {len(train_data)}')\n", 142 | "print(f'Number of validation examples: {len(valid_data)}')\n", 143 | "print(f'Number of testing examples: {len(test_data)}')" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 8, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "BATCH_SIZE = 64\n", 153 | "\n", 154 | "train_iterator = torch.utils.data.DataLoader(train_data, \n", 155 | " shuffle = True, \n", 156 | " batch_size = BATCH_SIZE)\n", 157 | "\n", 158 | "valid_iterator = torch.utils.data.DataLoader(valid_data, \n", 159 | " batch_size = BATCH_SIZE)\n", 160 | "\n", 161 | "test_iterator = torch.utils.data.DataLoader(test_data, \n", 162 | " batch_size = BATCH_SIZE)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 9, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "class VGGBlock(nn.Module):\n", 172 | " def __init__(self, in_channels, out_channels, batch_norm):\n", 173 | " super().__init__()\n", 174 | " \n", 175 | " modules = []\n", 176 | " modules.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))\n", 177 | " if batch_norm:\n", 178 | " modules.append(nn.BatchNorm2d(out_channels))\n", 179 | " modules.append(nn.ReLU(inplace=True))\n", 180 | " \n", 
181 | " self.block = nn.Sequential(*modules)\n", 182 | " \n", 183 | " def forward(self, x):\n", 184 | " return self.block(x)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 10, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "class VGG11(nn.Module):\n", 194 | " def __init__(self, output_dim, block, pool, batch_norm):\n", 195 | " super().__init__()\n", 196 | " \n", 197 | " self.features = nn.Sequential(\n", 198 | " block(3, 64, batch_norm), #in_channels, out_channels\n", 199 | " pool(2, 2), #kernel_size, stride\n", 200 | " block(64, 128, batch_norm),\n", 201 | " pool(2, 2),\n", 202 | " block(128, 256, batch_norm),\n", 203 | " block(256, 256, batch_norm),\n", 204 | " pool(2, 2),\n", 205 | " block(256, 512, batch_norm),\n", 206 | " block(512, 512, batch_norm),\n", 207 | " pool(2, 2),\n", 208 | " block(512, 512, batch_norm),\n", 209 | " block(512, 512, batch_norm),\n", 210 | " pool(2, 2),\n", 211 | " )\n", 212 | " \n", 213 | " self.classifier = nn.Linear(512, output_dim)\n", 214 | "\n", 215 | " def forward(self, x):\n", 216 | " x = self.features(x)\n", 217 | " x = x.view(x.shape[0], -1)\n", 218 | " x = self.classifier(x)\n", 219 | " return x" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "extra 64, 128, 256, 512 and 512" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 11, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "class VGG16(nn.Module):\n", 236 | " def __init__(self, output_dim, block, pool, batch_norm):\n", 237 | " super().__init__()\n", 238 | " \n", 239 | " self.features = nn.Sequential(\n", 240 | " block(3, 64, batch_norm),\n", 241 | " block(64, 64, batch_norm),\n", 242 | " pool(2, 2),\n", 243 | " block(64, 128, batch_norm),\n", 244 | " block(128, 128, batch_norm),\n", 245 | " pool(2, 2),\n", 246 | " block(128, 256, batch_norm),\n", 247 | " block(256, 256, batch_norm),\n", 248 | " block(256, 256, 
batch_norm),\n", 249 | " pool(2, 2),\n", 250 | " block(256, 512, batch_norm),\n", 251 | " block(512, 512, batch_norm),\n", 252 | " block(512, 512, batch_norm),\n", 253 | " pool(2, 2),\n", 254 | " block(512, 512, batch_norm),\n", 255 | " block(512, 512, batch_norm),\n", 256 | " block(512, 512, batch_norm),\n", 257 | " pool(2, 2),\n", 258 | " )\n", 259 | " \n", 260 | " self.classifier = nn.Linear(512, output_dim)\n", 261 | "\n", 262 | " def forward(self, x):\n", 263 | " x = self.features(x)\n", 264 | " x = x.view(x.shape[0], -1)\n", 265 | " x = self.classifier(x)\n", 266 | " return x" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "extra 256, 512 and 512" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 12, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "class VGG19(nn.Module):\n", 283 | " def __init__(self, output_dim, block, pool, batch_norm):\n", 284 | " super().__init__()\n", 285 | " \n", 286 | " self.features = nn.Sequential(\n", 287 | " block(3, 64, batch_norm),\n", 288 | " block(64, 64, batch_norm),\n", 289 | " pool(2, 2),\n", 290 | " block(64, 128, batch_norm),\n", 291 | " block(128, 128, batch_norm),\n", 292 | " pool(2, 2),\n", 293 | " block(128, 256, batch_norm),\n", 294 | " block(256, 256, batch_norm),\n", 295 | " block(256, 256, batch_norm),\n", 296 | " block(256, 256, batch_norm),\n", 297 | " pool(2, 2),\n", 298 | " block(256, 512, batch_norm),\n", 299 | " block(512, 512, batch_norm),\n", 300 | " block(512, 512, batch_norm),\n", 301 | " block(512, 512, batch_norm),\n", 302 | " pool(2, 2),\n", 303 | " block(512, 512, batch_norm),\n", 304 | " block(512, 512, batch_norm),\n", 305 | " block(512, 512, batch_norm),\n", 306 | " block(512, 512, batch_norm),\n", 307 | " pool(2, 2),\n", 308 | " )\n", 309 | " \n", 310 | " self.classifier = nn.Linear(512, output_dim)\n", 311 | "\n", 312 | " def forward(self, x):\n", 313 | " x = self.features(x)\n", 314 | " x = 
x.view(x.shape[0], -1)\n", 315 | " x = self.classifier(x)\n", 316 | " return x" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 13, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "OUTPUT_DIM = 10\n", 326 | "BATCH_NORM = True\n", 327 | "\n", 328 | "vgg11_model = VGG11(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) \n", 329 | "vgg16_model = VGG16(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) \n", 330 | "vgg19_model = VGG19(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) " 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 14, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "VGG11 has 9,231,114 trainable parameters\n", 343 | "VGG16 has 14,728,266 trainable parameters\n", 344 | "VGG19 has 20,040,522 trainable parameters\n" 345 | ] 346 | } 347 | ], 348 | "source": [ 349 | "def count_parameters(model):\n", 350 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 351 | "\n", 352 | "print(f'VGG11 has {count_parameters(vgg11_model):,} trainable parameters')\n", 353 | "print(f'VGG16 has {count_parameters(vgg16_model):,} trainable parameters')\n", 354 | "print(f'VGG19 has {count_parameters(vgg19_model):,} trainable parameters')" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": 15, 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "model = VGG11(OUTPUT_DIM, VGGBlock, nn.MaxPool2d, BATCH_NORM) " 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": 16, 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "data": { 373 | "text/plain": [ 374 | "VGG11(\n", 375 | " (features): Sequential(\n", 376 | " (0): VGGBlock(\n", 377 | " (block): Sequential(\n", 378 | " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 379 | " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 380 | " 
(2): ReLU(inplace=True)\n", 381 | " )\n", 382 | " )\n", 383 | " (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 384 | " (2): VGGBlock(\n", 385 | " (block): Sequential(\n", 386 | " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 387 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 388 | " (2): ReLU(inplace=True)\n", 389 | " )\n", 390 | " )\n", 391 | " (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 392 | " (4): VGGBlock(\n", 393 | " (block): Sequential(\n", 394 | " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 395 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 396 | " (2): ReLU(inplace=True)\n", 397 | " )\n", 398 | " )\n", 399 | " (5): VGGBlock(\n", 400 | " (block): Sequential(\n", 401 | " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 402 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 403 | " (2): ReLU(inplace=True)\n", 404 | " )\n", 405 | " )\n", 406 | " (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 407 | " (7): VGGBlock(\n", 408 | " (block): Sequential(\n", 409 | " (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 410 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 411 | " (2): ReLU(inplace=True)\n", 412 | " )\n", 413 | " )\n", 414 | " (8): VGGBlock(\n", 415 | " (block): Sequential(\n", 416 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 417 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 418 | " (2): ReLU(inplace=True)\n", 419 | " )\n", 420 | " )\n", 421 | " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 422 | " (10): VGGBlock(\n", 423 | " 
(block): Sequential(\n", 424 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 425 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 426 | " (2): ReLU(inplace=True)\n", 427 | " )\n", 428 | " )\n", 429 | " (11): VGGBlock(\n", 430 | " (block): Sequential(\n", 431 | " (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 432 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 433 | " (2): ReLU(inplace=True)\n", 434 | " )\n", 435 | " )\n", 436 | " (12): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 437 | " )\n", 438 | " (classifier): Linear(in_features=512, out_features=10, bias=True)\n", 439 | ")" 440 | ] 441 | }, 442 | "execution_count": 16, 443 | "metadata": {}, 444 | "output_type": "execute_result" 445 | } 446 | ], 447 | "source": [ 448 | "model" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 17, 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "optimizer = optim.Adam(model.parameters())" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 18, 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [ 466 | "criterion = nn.CrossEntropyLoss()" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 19, 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": 20, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "model = model.to(device)\n", 485 | "criterion = criterion.to(device)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 21, 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "def calculate_accuracy(fx, y):\n", 495 | " preds = fx.argmax(1, keepdim=True)\n", 496 | " correct = 
preds.eq(y.view_as(preds)).sum()\n", 497 | " acc = correct.float()/preds.shape[0]\n", 498 | " return acc" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": 22, 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "def train(model, iterator, optimizer, criterion, device):\n", 508 | " \n", 509 | " epoch_loss = 0\n", 510 | " epoch_acc = 0\n", 511 | " \n", 512 | " model.train()\n", 513 | " \n", 514 | " for (x, y) in iterator:\n", 515 | " \n", 516 | " x = x.to(device)\n", 517 | " y = y.to(device)\n", 518 | " \n", 519 | " optimizer.zero_grad()\n", 520 | " \n", 521 | " fx = model(x)\n", 522 | " \n", 523 | " loss = criterion(fx, y)\n", 524 | " \n", 525 | " acc = calculate_accuracy(fx, y)\n", 526 | " \n", 527 | " loss.backward()\n", 528 | " \n", 529 | " optimizer.step()\n", 530 | " \n", 531 | " epoch_loss += loss.item()\n", 532 | " epoch_acc += acc.item()\n", 533 | " \n", 534 | " return epoch_loss / len(iterator), epoch_acc / len(iterator)" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 23, 540 | "metadata": {}, 541 | "outputs": [], 542 | "source": [ 543 | "def evaluate(model, iterator, criterion, device):\n", 544 | " \n", 545 | " epoch_loss = 0\n", 546 | " epoch_acc = 0\n", 547 | " \n", 548 | " model.eval()\n", 549 | " \n", 550 | " with torch.no_grad():\n", 551 | " for (x, y) in iterator:\n", 552 | "\n", 553 | " x = x.to(device)\n", 554 | " y = y.to(device)\n", 555 | "\n", 556 | " fx = model(x)\n", 557 | "\n", 558 | " loss = criterion(fx, y)\n", 559 | "\n", 560 | " acc = calculate_accuracy(fx, y)\n", 561 | "\n", 562 | " epoch_loss += loss.item()\n", 563 | " epoch_acc += acc.item()\n", 564 | " \n", 565 | " return epoch_loss / len(iterator), epoch_acc / len(iterator)" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 24, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "def epoch_time(start_time, end_time):\n", 575 | " elapsed_time = end_time - start_time\n", 
576 | " elapsed_mins = int(elapsed_time / 60)\n", 577 | " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", 578 | " return elapsed_mins, elapsed_secs" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 25, 584 | "metadata": {}, 585 | "outputs": [ 586 | { 587 | "name": "stdout", 588 | "output_type": "stream", 589 | "text": [ 590 | "Epoch: 01 | Epoch Time: 0m 19s\n", 591 | "\tTrain Loss: 1.469 | Train Acc: 46.14%\n", 592 | "\t Val. Loss: 1.287 | Val. Acc: 54.75%\n", 593 | "Epoch: 02 | Epoch Time: 0m 19s\n", 594 | "\tTrain Loss: 1.035 | Train Acc: 63.19%\n", 595 | "\t Val. Loss: 0.973 | Val. Acc: 66.83%\n", 596 | "Epoch: 03 | Epoch Time: 0m 19s\n", 597 | "\tTrain Loss: 0.856 | Train Acc: 69.85%\n", 598 | "\t Val. Loss: 0.833 | Val. Acc: 70.41%\n", 599 | "Epoch: 04 | Epoch Time: 0m 19s\n", 600 | "\tTrain Loss: 0.761 | Train Acc: 73.57%\n", 601 | "\t Val. Loss: 0.780 | Val. Acc: 73.08%\n", 602 | "Epoch: 05 | Epoch Time: 0m 19s\n", 603 | "\tTrain Loss: 0.671 | Train Acc: 76.72%\n", 604 | "\t Val. Loss: 0.693 | Val. Acc: 75.95%\n", 605 | "Epoch: 06 | Epoch Time: 0m 19s\n", 606 | "\tTrain Loss: 0.619 | Train Acc: 78.62%\n", 607 | "\t Val. Loss: 0.648 | Val. Acc: 77.97%\n", 608 | "Epoch: 07 | Epoch Time: 0m 19s\n", 609 | "\tTrain Loss: 0.570 | Train Acc: 80.38%\n", 610 | "\t Val. Loss: 0.631 | Val. Acc: 78.34%\n", 611 | "Epoch: 08 | Epoch Time: 0m 19s\n", 612 | "\tTrain Loss: 0.523 | Train Acc: 82.01%\n", 613 | "\t Val. Loss: 0.606 | Val. Acc: 79.89%\n", 614 | "Epoch: 09 | Epoch Time: 0m 19s\n", 615 | "\tTrain Loss: 0.489 | Train Acc: 83.04%\n", 616 | "\t Val. Loss: 0.606 | Val. Acc: 79.96%\n", 617 | "Epoch: 10 | Epoch Time: 0m 19s\n", 618 | "\tTrain Loss: 0.454 | Train Acc: 84.42%\n", 619 | "\t Val. Loss: 0.575 | Val. 
Acc: 80.70%\n" 620 | ] 621 | } 622 | ], 623 | "source": [ 624 | "EPOCHS = 10\n", 625 | "\n", 626 | "best_valid_loss = float('inf')\n", 627 | "\n", 628 | "for epoch in range(EPOCHS):\n", 629 | " \n", 630 | " start_time = time.time()\n", 631 | " \n", 632 | " train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)\n", 633 | " valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)\n", 634 | " \n", 635 | " if valid_loss < best_valid_loss:\n", 636 | " best_valid_loss = valid_loss\n", 637 | " torch.save(model.state_dict(), 'tut5-model.pt')\n", 638 | " \n", 639 | " end_time = time.time()\n", 640 | "\n", 641 | " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", 642 | " \n", 643 | " print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n", 644 | " print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n", 645 | " print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": 26, 651 | "metadata": {}, 652 | "outputs": [ 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "Test Loss: 0.487 | Test Acc: 83.41%\n" 658 | ] 659 | } 660 | ], 661 | "source": [ 662 | "model.load_state_dict(torch.load('tut5-model.pt'))\n", 663 | "\n", 664 | "test_loss, test_acc = evaluate(model, test_iterator, criterion, device)\n", 665 | "\n", 666 | "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')" 667 | ] 668 | } 669 | ], 670 | "metadata": { 671 | "kernelspec": { 672 | "display_name": "Python 3", 673 | "language": "python", 674 | "name": "python3" 675 | }, 676 | "language_info": { 677 | "codemirror_mode": { 678 | "name": "ipython", 679 | "version": 3 680 | }, 681 | "file_extension": ".py", 682 | "mimetype": "text/x-python", 683 | "name": "python", 684 | "nbconvert_exporter": "python", 685 | "pygments_lexer": "ipython3", 686 | "version": "3.7.6" 687 | } 688 | }, 
689 | "nbformat": 4, 690 | "nbformat_minor": 2 691 | } 692 | -------------------------------------------------------------------------------- /misc/6 - ResNet - Dogs vs Cats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "import torch.optim as optim\n", 13 | "import torchvision\n", 14 | "import torchvision.transforms as transforms\n", 15 | "import torchvision.datasets as datasets\n", 16 | "\n", 17 | "import os\n", 18 | "import random\n", 19 | "import numpy as np" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "SEED = 1234\n", 29 | "\n", 30 | "random.seed(SEED)\n", 31 | "np.random.seed(SEED)\n", 32 | "torch.manual_seed(SEED)\n", 33 | "torch.cuda.manual_seed(SEED)\n", 34 | "torch.backends.cudnn.deterministic = True" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "train_transforms = transforms.Compose([\n", 44 | " transforms.RandomHorizontalFlip(),\n", 45 | " transforms.RandomRotation(10),\n", 46 | " transforms.RandomCrop((224, 224), pad_if_needed=True),\n", 47 | " transforms.ToTensor(),\n", 48 | " transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))\n", 49 | " ])\n", 50 | "\n", 51 | "test_transforms = transforms.Compose([\n", 52 | " transforms.CenterCrop((224, 224)),\n", 53 | " transforms.ToTensor(),\n", 54 | " transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))\n", 55 | " ])" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "train_data = datasets.ImageFolder('data/dogs-vs-cats/train', 
train_transforms)\n", 65 | "valid_data = datasets.ImageFolder('data/dogs-vs-cats/valid', test_transforms)\n", 66 | "test_data = datasets.ImageFolder('data/dogs-vs-cats/test', test_transforms)\n", 67 | "\n", 68 | "#import os\n", 69 | "\n", 70 | "#print(len(os.listdir('data/dogs-vs-cats/train')))\n", 71 | "\n", 72 | "#n_train_examples = int(len(train_data)*0.9)\n", 73 | "#n_valid_examples = n_test_examples = len(train_data) - n_train_examples\n", 74 | "\n", 75 | "#train_data, valid_data = torch.utils.data.random_split(train_data, [n_train_examples, n_valid_examples])\n", 76 | "#train_data, test_data = torch.utils.data.random_split(train_data, [n_train_examples-n_valid_examples, n_test_examples])" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "https://github.com/facebook/fb.resnet.torch/issues/180\n", 84 | "https://github.com/bamos/densenet.pytorch/blob/master/compute-cifar10-mean.py" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Number of training examples: 20000\n", 97 | "Number of validation examples: 2500\n", 98 | "Number of testing examples: 2500\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "print(f'Number of training examples: {len(train_data)}')\n", 104 | "print(f'Number of validation examples: {len(valid_data)}')\n", 105 | "print(f'Number of testing examples: {len(test_data)}')" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "BATCH_SIZE = 64\n", 115 | "\n", 116 | "train_iterator = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)\n", 117 | "valid_iterator = torch.utils.data.DataLoader(valid_data, batch_size=BATCH_SIZE)\n", 118 | "test_iterator = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)" 119 | ] 120 | }, 121 | { 122 
| "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "https://discuss.pytorch.org/t/why-does-the-resnet-model-given-by-pytorch-omit-biases-from-the-convolutional-layer/10990/4\n", 126 | "https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 8, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "import torchvision.models as models\n", 145 | "\n", 146 | "model = models.resnet18(pretrained=True).to(device)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 9, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "ResNet(\n", 159 | " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", 160 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 161 | " (relu): ReLU(inplace)\n", 162 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", 163 | " (layer1): Sequential(\n", 164 | " (0): BasicBlock(\n", 165 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 166 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 167 | " (relu): ReLU(inplace)\n", 168 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 169 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 170 | " )\n", 171 | " (1): BasicBlock(\n", 172 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 173 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 
174 | " (relu): ReLU(inplace)\n", 175 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 176 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 177 | " )\n", 178 | " )\n", 179 | " (layer2): Sequential(\n", 180 | " (0): BasicBlock(\n", 181 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 182 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 183 | " (relu): ReLU(inplace)\n", 184 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 185 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 186 | " (downsample): Sequential(\n", 187 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 188 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 189 | " )\n", 190 | " )\n", 191 | " (1): BasicBlock(\n", 192 | " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 193 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 194 | " (relu): ReLU(inplace)\n", 195 | " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 196 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 197 | " )\n", 198 | " )\n", 199 | " (layer3): Sequential(\n", 200 | " (0): BasicBlock(\n", 201 | " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 202 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 203 | " (relu): ReLU(inplace)\n", 204 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 205 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 206 | 
" (downsample): Sequential(\n", 207 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 208 | " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 209 | " )\n", 210 | " )\n", 211 | " (1): BasicBlock(\n", 212 | " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 213 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 214 | " (relu): ReLU(inplace)\n", 215 | " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 216 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 217 | " )\n", 218 | " )\n", 219 | " (layer4): Sequential(\n", 220 | " (0): BasicBlock(\n", 221 | " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 222 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 223 | " (relu): ReLU(inplace)\n", 224 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 225 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 226 | " (downsample): Sequential(\n", 227 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 228 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 229 | " )\n", 230 | " )\n", 231 | " (1): BasicBlock(\n", 232 | " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 233 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 234 | " (relu): ReLU(inplace)\n", 235 | " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 236 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 237 | " )\n", 238 | " )\n", 239 | " (avgpool): 
AvgPool2d(kernel_size=7, stride=1, padding=0)\n", 240 | " (fc): Linear(in_features=512, out_features=1000, bias=True)\n", 241 | ")\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "print(model)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 10, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "for param in model.parameters():\n", 256 | " param.requires_grad = False" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 11, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "name": "stdout", 266 | "output_type": "stream", 267 | "text": [ 268 | "Linear(in_features=512, out_features=1000, bias=True)\n" 269 | ] 270 | } 271 | ], 272 | "source": [ 273 | "print(model.fc)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 12, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "model.fc = nn.Linear(in_features=512, out_features=2).to(device)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 13, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "optimizer = optim.Adam(model.parameters())" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 14, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "criterion = nn.CrossEntropyLoss()" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 15, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "def calculate_accuracy(fx, y):\n", 310 | " preds = fx.max(1, keepdim=True)[1]\n", 311 | " correct = preds.eq(y.view_as(preds)).sum()\n", 312 | " acc = correct.float()/preds.shape[0]\n", 313 | " return acc" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 16, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "def train(model, device, iterator, optimizer, criterion):\n", 323 | " \n", 324 | " epoch_loss = 0\n", 325 | " epoch_acc = 0\n", 326 | " \n", 327 | " 
model.train()\n", 328 | " \n", 329 | " for (x, y) in iterator:\n", 330 | " \n", 331 | " x = x.to(device)\n", 332 | " y = y.to(device)\n", 333 | " \n", 334 | " optimizer.zero_grad()\n", 335 | " \n", 336 | " fx = model(x)\n", 337 | " \n", 338 | " loss = criterion(fx, y)\n", 339 | " \n", 340 | " acc = calculate_accuracy(fx, y)\n", 341 | " \n", 342 | " loss.backward()\n", 343 | " \n", 344 | " optimizer.step()\n", 345 | " \n", 346 | " epoch_loss += loss.item()\n", 347 | " epoch_acc += acc.item()\n", 348 | " \n", 349 | " return epoch_loss / len(iterator), epoch_acc / len(iterator)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 17, 355 | "metadata": {}, 356 | "outputs": [], 357 | "source": [ 358 | "def evaluate(model, device, iterator, criterion):\n", 359 | " \n", 360 | " epoch_loss = 0\n", 361 | " epoch_acc = 0\n", 362 | " \n", 363 | " model.eval()\n", 364 | " \n", 365 | " with torch.no_grad():\n", 366 | " for (x, y) in iterator:\n", 367 | "\n", 368 | " x = x.to(device)\n", 369 | " y = y.to(device)\n", 370 | "\n", 371 | " fx = model(x)\n", 372 | "\n", 373 | " loss = criterion(fx, y)\n", 374 | "\n", 375 | " acc = calculate_accuracy(fx, y)\n", 376 | "\n", 377 | " epoch_loss += loss.item()\n", 378 | " epoch_acc += acc.item()\n", 379 | " \n", 380 | " return epoch_loss / len(iterator), epoch_acc / len(iterator)" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 18, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | "| Epoch: 01 | Train Loss: 0.198 | Train Acc: 91.95% | Val. Loss: 0.089 | Val. Acc: 96.80% |\n", 393 | "| Epoch: 02 | Train Loss: 0.136 | Train Acc: 94.40% | Val. Loss: 0.069 | Val. Acc: 97.50% |\n", 394 | "| Epoch: 03 | Train Loss: 0.128 | Train Acc: 94.68% | Val. Loss: 0.059 | Val. Acc: 97.70% |\n", 395 | "| Epoch: 04 | Train Loss: 0.119 | Train Acc: 95.03% | Val. Loss: 0.070 | Val. 
Acc: 97.30% |\n", 396 | "| Epoch: 05 | Train Loss: 0.118 | Train Acc: 94.95% | Val. Loss: 0.057 | Val. Acc: 97.73% |\n", 397 | "| Epoch: 06 | Train Loss: 0.121 | Train Acc: 94.95% | Val. Loss: 0.056 | Val. Acc: 97.70% |\n", 398 | "| Epoch: 07 | Train Loss: 0.117 | Train Acc: 95.11% | Val. Loss: 0.063 | Val. Acc: 97.46% |\n", 399 | "| Epoch: 08 | Train Loss: 0.110 | Train Acc: 95.44% | Val. Loss: 0.052 | Val. Acc: 97.93% |\n", 400 | "| Epoch: 09 | Train Loss: 0.116 | Train Acc: 95.14% | Val. Loss: 0.056 | Val. Acc: 97.77% |\n", 401 | "| Epoch: 10 | Train Loss: 0.114 | Train Acc: 95.36% | Val. Loss: 0.063 | Val. Acc: 97.46% |\n" 402 | ] 403 | } 404 | ], 405 | "source": [ 406 | "EPOCHS = 10\n", 407 | "SAVE_DIR = 'models'\n", 408 | "MODEL_SAVE_PATH = os.path.join(SAVE_DIR, 'resnet18-dogs-vs-cats.pt')\n", 409 | "\n", 410 | "best_valid_loss = float('inf')\n", 411 | "\n", 412 | "if not os.path.isdir(SAVE_DIR):\n", 413 | " os.makedirs(SAVE_DIR)\n", 414 | "\n", 415 | "for epoch in range(EPOCHS):\n", 416 | " train_loss, train_acc = train(model, device, train_iterator, optimizer, criterion)\n", 417 | " valid_loss, valid_acc = evaluate(model, device, valid_iterator, criterion)\n", 418 | " \n", 419 | " if valid_loss < best_valid_loss:\n", 420 | " best_valid_loss = valid_loss\n", 421 | " torch.save(model.state_dict(), MODEL_SAVE_PATH)\n", 422 | " \n", 423 | " print(f'| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:05.2f}% | Val. Loss: {valid_loss:.3f} | Val.
Acc: {valid_acc*100:05.2f}% |')" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 19, 429 | "metadata": {}, 430 | "outputs": [ 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "| Test Loss: 0.052 | Test Acc: 97.93% |\n" 436 | ] 437 | } 438 | ], 439 | "source": [ 440 | "model.load_state_dict(torch.load(MODEL_SAVE_PATH))\n", 441 | "\n", 442 | "test_loss, test_acc = evaluate(model, device, valid_iterator, criterion)\n", 443 | "\n", 444 | "print(f'| Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:05.2f}% |')" 445 | ] 446 | } 447 | ], 448 | "metadata": { 449 | "kernelspec": { 450 | "display_name": "Python 3", 451 | "language": "python", 452 | "name": "python3" 453 | }, 454 | "language_info": { 455 | "codemirror_mode": { 456 | "name": "ipython", 457 | "version": 3 458 | }, 459 | "file_extension": ".py", 460 | "mimetype": "text/x-python", 461 | "name": "python", 462 | "nbconvert_exporter": "python", 463 | "pygments_lexer": "ipython3", 464 | "version": "3.7.0" 465 | } 466 | }, 467 | "nbformat": 4, 468 | "nbformat_minor": 2 469 | } 470 | -------------------------------------------------------------------------------- /misc/conv order.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "- https://blog.paperspace.com/busting-the-myths-about-batch-normalization/\n", 8 | "- http://forums.fast.ai/t/questions-about-batch-normalization/230/2\n", 9 | "- https://www.quora.com/In-most-papers-I-read-the-CNN-order-is-convolution-relu-max-pooling-So-can-I-change-the-order-to-become-convolution-max-pooling-relu\n", 10 | "\n", 11 | "Therefore, order is: `conv -> pool -> relu -> BN -> drop -> conv`" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [] 20 | } 21 | ], 22 | "metadata": { 23 | "kernelspec": { 24 | 
"display_name": "Python 3", 25 | "language": "python", 26 | "name": "python3" 27 | }, 28 | "language_info": { 29 | "codemirror_mode": { 30 | "name": "ipython", 31 | "version": 3 32 | }, 33 | "file_extension": ".py", 34 | "mimetype": "text/x-python", 35 | "name": "python", 36 | "nbconvert_exporter": "python", 37 | "pygments_lexer": "ipython3", 38 | "version": "3.6.5" 39 | } 40 | }, 41 | "nbformat": 4, 42 | "nbformat_minor": 2 43 | } 44 | -------------------------------------------------------------------------------- /misc/download_dogs-vs-cats.sh: -------------------------------------------------------------------------------- 1 | mkdir data 2 | mkdir data/dogs-vs-cats 3 | kaggle competitions download -c dogs-vs-cats 4 | rm sampleSubmission.csv 5 | rm test1.zip 6 | unzip train.zip 7 | mv train data/dogs-vs-cats 8 | rm train.zip 9 | mkdir data/dogs-vs-cats/train/dog 10 | mkdir data/dogs-vs-cats/train/cat 11 | mkdir data/dogs-vs-cats/valid 12 | mkdir data/dogs-vs-cats/valid/dog 13 | mkdir data/dogs-vs-cats/valid/cat 14 | mkdir data/dogs-vs-cats/test 15 | mkdir data/dogs-vs-cats/test/dog 16 | mkdir data/dogs-vs-cats/test/cat 17 | python process_dogs-vs-cats.py 18 | -------------------------------------------------------------------------------- /misc/process_dogs-vs-cats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import random 4 | 5 | SEED = 1234 6 | 7 | random.seed(SEED) 8 | 9 | TRAIN_DIR = 'data/dogs-vs-cats/train/' 10 | VALID_DIR = 'data/dogs-vs-cats/valid/' 11 | TEST_DIR = 'data/dogs-vs-cats/test/' 12 | 13 | assert os.path.exists(TRAIN_DIR) 14 | assert os.path.exists(VALID_DIR) 15 | assert os.path.exists(TEST_DIR) 16 | assert os.path.exists(os.path.join(TRAIN_DIR, 'cat')), 'Run download_dogs-vs-cats.sh first!' 17 | assert os.path.exists(os.path.join(TRAIN_DIR, 'dog')), 'Run download_dogs-vs-cats.sh first!'
18 | assert os.path.exists(os.path.join(VALID_DIR, 'cat')), 'Run download_dogs-vs-cats.sh first!' 19 | assert os.path.exists(os.path.join(VALID_DIR, 'dog')), 'Run download_dogs-vs-cats.sh first!' 20 | assert os.path.exists(os.path.join(TEST_DIR, 'cat')), 'Run download_dogs-vs-cats.sh first!' 21 | assert os.path.exists(os.path.join(TEST_DIR, 'dog')), 'Run download_dogs-vs-cats.sh first!' 22 | 23 | all_images = os.listdir(TRAIN_DIR) 24 | 25 | cats = [t for t in all_images if 'cat' in t and t.endswith('.jpg')] 26 | dogs = [t for t in all_images if 'dog' in t and t.endswith('.jpg')] 27 | 28 | random.shuffle(cats) 29 | random.shuffle(dogs) 30 | 31 | n_train_examples = int(len(all_images) * 0.8) // 2 32 | n_valid_examples = int(len(all_images) * 0.1) // 2 33 | 34 | train_cats = cats[:n_train_examples] 35 | valid_cats = cats[n_train_examples:n_train_examples+n_valid_examples] 36 | test_cats = cats[n_train_examples+n_valid_examples:] 37 | 38 | train_dogs = dogs[:n_train_examples] 39 | valid_dogs = dogs[n_train_examples:n_train_examples+n_valid_examples] 40 | test_dogs = dogs[n_train_examples+n_valid_examples:] 41 | 42 | for cat in train_cats: 43 | shutil.move(os.path.join(TRAIN_DIR, cat), os.path.join(TRAIN_DIR, 'cat', cat)) 44 | 45 | for cat in valid_cats: 46 | shutil.move(os.path.join(TRAIN_DIR, cat), os.path.join(VALID_DIR, 'cat', cat)) 47 | 48 | for cat in test_cats: 49 | shutil.move(os.path.join(TRAIN_DIR, cat), os.path.join(TEST_DIR, 'cat', cat)) 50 | 51 | for dog in train_dogs: 52 | shutil.move(os.path.join(TRAIN_DIR, dog), os.path.join(TRAIN_DIR, 'dog', dog)) 53 | 54 | for dog in valid_dogs: 55 | shutil.move(os.path.join(TRAIN_DIR, dog), os.path.join(VALID_DIR, 'dog', dog)) 56 | 57 | for dog in test_dogs: 58 | shutil.move(os.path.join(TRAIN_DIR, dog), os.path.join(TEST_DIR, 'dog', dog)) --------------------------------------------------------------------------------
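The per-class 80/10/10 split performed by `process_dogs-vs-cats.py` can be dry-run without moving any files. The following is a minimal sketch, not part of the repository: `split_filenames` and its parameters are hypothetical. It assumes Kaggle-style names such as `cat.0.jpg` and reuses the script's filter (label substring plus `.jpg` suffix, which skips the `cat`/`dog` class directories). It also sorts names before shuffling, so the result does not depend on `os.listdir` order, and it uses integer percentages instead of the script's `0.8`/`0.1` floats.

```python
import random
from typing import Dict, List, Tuple

# (train, valid, test) filename lists for one class
Split = Tuple[List[str], List[str], List[str]]


def split_filenames(filenames: List[str], train_pct: int = 80, valid_pct: int = 10,
                    seed: int = 1234) -> Dict[str, Split]:
    """Shuffle and slice 'cat'/'dog' .jpg names into train/valid/test lists.

    Hypothetical helper: only names are returned; nothing on disk is moved.
    """
    rng = random.Random(seed)  # private RNG, so the global random state is untouched
    splits: Dict[str, Split] = {}
    for label in ('cat', 'dog'):
        # Same filter as the script: label substring and .jpg suffix (skips class dirs)
        names = sorted(f for f in filenames if label in f and f.endswith('.jpg'))
        rng.shuffle(names)
        # Integer arithmetic avoids float-flooring surprises when computing sizes
        n_train = len(names) * train_pct // 100
        n_valid = len(names) * valid_pct // 100
        splits[label] = (names[:n_train],
                         names[n_train:n_train + n_valid],
                         names[n_train + n_valid:])  # remainder goes to test
    return splits


if __name__ == '__main__':
    # Hypothetical listing: 10 images per class plus the two class directories
    listing = [f'{c}.{i}.jpg' for c in ('cat', 'dog') for i in range(10)] + ['cat', 'dog']
    for label, (train, valid, test) in split_filenames(listing).items():
        print(label, len(train), len(valid), len(test))
```

Running it prints `cat 8 1 1` and `dog 8 1 1`. The returned lists could then drive the same `shutil.move` loops the script uses.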