├── .gitignore
├── LICENSE
├── README.md
├── colab
│   └── start_on_colab.ipynb
├── hubconf.py
├── pytorch_cifar_models
│   ├── __init__.py
│   ├── mobilenetv2.py
│   ├── repvgg.py
│   ├── resnet.py
│   ├── shufflenetv2.py
│   ├── vgg.py
│   └── vit.py
├── setup.py
└── tests
    └── pytorch_cifar_models
        ├── test_mobilenetv2.py
        ├── test_repvgg.py
        ├── test_resnet.py
        ├── test_shufflenetv2.py
        ├── test_vgg.py
        └── test_vit.py

/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.gitignore.io/api/code,python
3 | # Edit at https://www.gitignore.io/?templates=code,python
4 |
5 | ### Code ###
6 | .vscode/*
7 | # !.vscode/settings.json
8 | # !.vscode/tasks.json
9 | # !.vscode/launch.json
10 | # !.vscode/extensions.json
11 |
12 | ### Python ###
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # pipenv
82 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
83 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
84 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
85 | # install all needed dependencies.
86 | #Pipfile.lock
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # Mr Developer
102 | .mr.developer.cfg
103 | .project
104 | .pydevproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 | .dmypy.json
112 | dmypy.json
113 |
114 | # Pyre type checker
115 | .pyre/
116 |
117 | # End of https://www.gitignore.io/api/code,python
118 |
119 | *.pt
120 | *.pth
121 | *.log
122 | *.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, chenyaofo
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch CIFAR Models
2 |
3 | ## Introduction
4 |
5 | The goal of this project is to provide some neural network examples and a simple training codebase for beginners.
6 |
7 | ## Get Started with Google Colab
8 |
9 | **Train Models**: Open the notebook to train the models from scratch on CIFAR10/100.
10 | It will take several hours depending on the complexity of the model and the allocated GPU type.
11 |
12 | **Test Models**: Open the notebook to measure the validation accuracy on CIFAR10/100 with pretrained models.
13 | It will only take a few seconds.
14 |
15 | ## Use Models with PyTorch Hub
16 |
17 | You can simply use the pretrained models in your project with the `torch.hub` API.
18 | It will automatically load the code and the pretrained weights from GitHub
19 | (if you cannot directly access GitHub, please check [this issue](https://github.com/chenyaofo/pytorch-cifar-models/issues/14) for a solution).
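If GitHub is unreachable, `torch.hub` can also load the model code from a local clone on a sufficiently recent PyTorch (the `source` argument was added around torch 1.7). Below is a minimal sketch, assuming a hypothetical clone at `./pytorch-cifar-models` obtained, e.g., via a mirror; note that `pretrained=True` still fetches the weights from the GitHub release URLs unless they are already cached:

```python
import torch

# Load the model definitions from a local clone rather than from GitHub;
# the clone path below is a hypothetical example.
model = torch.hub.load("./pytorch-cifar-models", "cifar10_resnet20",
                       source="local", pretrained=True)
```

The standard remote usage is: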
20 |
21 | ``` python
22 | import torch
23 | model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
24 | ```
25 |
26 | To list all available model entries, you can run:
27 |
28 | ```python
29 | import torch
30 | from pprint import pprint
31 | pprint(torch.hub.list("chenyaofo/pytorch-cifar-models", force_reload=True))
32 | ```
33 |
34 |
35 | ## Model Zoo
36 |
37 | ### CIFAR-10
38 |
39 | | Model | Top-1 Acc.(%) | Top-5 Acc.(%) | #Params.(M) | #MAdds(M) | Download |
40 | |----------|----------------|---------------|-------------|-----------|--------------------|
41 | | resnet20 | 92.60 | 99.81 | 0.27 | 40.81 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet20-4118986f.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet20/default.log)
42 | | resnet32 | 93.53 | 99.77 | 0.47 | 69.12 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet32-ef93fc4d.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet32/default.log)
43 | | resnet44 | 94.01 | 99.77 | 0.66 | 97.44 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet44-2a3cabcb.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet44/default.log)
44 | | resnet56 | 94.37 | 99.83 | 0.86 | 125.75 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet56-187c023a.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet56/default.log)
45 | | vgg11_bn | 92.79 | 99.72 | 9.76 | 153.29 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg11_bn-eaeebf42.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg11_bn/default.log)
46 | | vgg13_bn | 94.00 | 99.77 | 9.94 | 228.79 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg13_bn-c01e4a43.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg13_bn/default.log)
47 | | vgg16_bn | 94.16 | 99.71 | 15.25 | 313.73 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg16_bn-6ee7ea24.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg16_bn/default.log)
48 | | vgg19_bn | 93.91 | 99.64 | 20.57 | 398.66 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg19_bn-57191229.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg19_bn/default.log)
49 | | mobilenetv2_x0_5 | 92.88 | 99.86 | 0.70 | 27.97 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_5-ca14ced9.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x0_5/default.log)
50 | | mobilenetv2_x0_75 | 93.72 | 99.79 | 1.37 | 59.31 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_75-a53c314e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x0_75/default.log)
51 | | mobilenetv2_x1_0 | 93.79 | 99.73 | 2.24 | 87.98 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_0-fe6a5b48.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x1_0/default.log)
52 | | mobilenetv2_x1_4 | 94.22 | 99.80 | 4.33 | 170.07 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_4-3bbbd6e2.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x1_4/default.log)
53 | | shufflenetv2_x0_5 | 90.13 | 99.70 | 0.35 | 10.90 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x0_5-1308b4e9.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x0_5/default.log)
54 | | shufflenetv2_x1_0 | 92.98 | 99.73 | 1.26 | 45.00 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_0-98807be3.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x1_0/default.log)
55 | | shufflenetv2_x1_5 | 93.55 | 99.77 | 2.49 | 94.26 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_5-296694dd.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x1_5/default.log)
56 | | shufflenetv2_x2_0 | 93.81 | 99.79 | 5.37 | 187.81 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x2_0-ec31611c.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x2_0/default.log)
57 | | repvgg_a0 | 94.39 | 99.82 | 7.84 | 489.08 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a0-ef08a50e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/repvgg_a0/default.log)
58 | | repvgg_a1 | 94.89 | 99.83 | 12.82 | 851.33 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a1-38d2431b.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/repvgg_a1/default.log)
59 | | repvgg_a2 | 94.98 | 99.82 | 26.82 | 1850.10 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a2-09488915.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/repvgg_a2/default.log)
60 |
61 | ### CIFAR-100
62 |
63 | | Model | Top-1 Acc.(%) | Top-5 Acc.(%) | #Params.(M) | #MAdds(M) | Download |
64 | |----------|----------------|---------------|-------------|-----------|--------------------|
65 | | resnet20 | 68.83 | 91.01 | 0.28 | 40.82 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet20-23dac2f1.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet20/default.log)
66 | | resnet32 | 70.16 | 90.89 | 0.47 | 69.13 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet32-84213ce6.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet32/default.log)
67 | | resnet44 | 71.63 | 91.58 | 0.67 | 97.44 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet44-ffe32858.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet44/default.log)
68 | | resnet56 | 72.63 | 91.94 | 0.86 | 125.75 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet56-f2eff4c8.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet56/default.log)
69 | | vgg11_bn | 70.78 | 88.87 | 9.80 | 153.34 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg11_bn-57d0759e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg11_bn/default.log)
70 | | vgg13_bn | 74.63 | 91.09 | 9.99 | 228.84 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg13_bn-5ebe5778.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg13_bn/default.log)
71 | | vgg16_bn | 74.00 | 90.56 | 15.30 | 313.77 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg16_bn-7d8c4031.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg16_bn/default.log)
72 | | vgg19_bn | 73.87 | 90.13 | 20.61 | 398.71 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg19_bn-b98f7bd7.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg19_bn/default.log)
73 | | mobilenetv2_x0_5 | 70.88 | 91.72 | 0.82 | 28.08 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_5-9f915757.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x0_5/default.log)
74 | | mobilenetv2_x0_75 | 73.61 | 92.61 | 1.48 | 59.43 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_75-d7891e60.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x0_75/default.log)
75 | | mobilenetv2_x1_0 | 74.20 | 92.82 | 2.35 | 88.09 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_0-1311f9ff.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x1_0/default.log)
76 | | mobilenetv2_x1_4 | 75.98 | 93.44 | 4.50 | 170.23 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_4-8a269f5e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x1_4/default.log)
77 | | shufflenetv2_x0_5 | 67.82 | 89.93 | 0.44 | 10.99 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x0_5-1977720f.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x0_5/default.log)
78 | | shufflenetv2_x1_0 | 72.39 | 91.46 | 1.36 | 45.09 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_0-9ae22beb.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x1_0/default.log)
79 | | shufflenetv2_x1_5 | 73.91 | 92.13 | 2.58 | 94.35 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_5-e2c85ad8.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x1_5/default.log)
80 | | shufflenetv2_x2_0 | 75.35 | 92.62 | 5.55 | 188.00 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x2_0-e7e584cd.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x2_0/default.log)
81 | | repvgg_a0 | 75.22 | 92.93 | 7.96 | 489.19 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a0-2df1edd0.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/repvgg_a0/default.log)
82 | | repvgg_a1 | 76.12 | 92.71 | 12.94 | 851.44 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a1-c06b21a7.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/repvgg_a1/default.log)
83 | | repvgg_a2 | 77.18 | 93.51 | 26.94 | 1850.22 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a2-8e71b1f8.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/repvgg_a2/default.log)
84 |
85 | If you want to cite this repo:
86 |
87 | ```
88 | @misc{chenyaofo_pytorch_cifar_models,
89 |   author = {Yaofo Chen},
90 |   title = {PyTorch CIFAR Models},
91 |   howpublished = {\url{https://github.com/chenyaofo/pytorch-cifar-models}},
92 |   note = {Accessed: 2025-05-17},
93 |   abstract = {This repository provides pretrained neural network models trained on CIFAR-10 and CIFAR-100 datasets, including ResNet, VGG, MobileNetV2, ShuffleNetV2, and RepVGG variants.}
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/colab/start_on_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "nbformat": 4,
3 |   "nbformat_minor": 0,
4 |   "metadata": {
5 |     "colab": {
6 |       "name": "start_on_colab.ipynb",
7 |       "provenance": [],
8 |       "collapsed_sections": [],
9 |       "authorship_tag": "ABX9TyPIpbiKdUC+5q3/Vlm8dk4L"
10 |     },
11 |     "kernelspec": {
12 |       "name": "python3",
13 |       "display_name": "Python 3"
14 |     },
15 |     "language_info": {
16 |       "name": "python"
17 |     },
18 |     "accelerator": "GPU"
19 |   },
20 |   "cells": [
21 |     {
22 |       "cell_type": "code",
23 |       "metadata": {
24 |         "id": "PwHOLr9y7S3Q"
25 |       },
26 |       "source": [
27 |         "from IPython.display import HTML, display\n",
28 |         "\n",
29 |         "def set_css():\n",
30 |         "  display(HTML('''\n",
31 |         "  \n",
36 |         "  '''))\n",
37 |         "get_ipython().events.register('pre_run_cell', set_css)"
38 |       ],
39 |       "execution_count": null,
40 |       "outputs": []
41 |     },
42 |     {
43 |       "cell_type": "code",
44 |       "metadata": {
45 |         "id": "zXkQlZUJ8Lnm"
46 |       },
47 |       "source": [
48 |         "!git clone --recursive https://github.com/chenyaofo/image-classification-codebase\n",
49 |         "%cd image-classification-codebase\n",
50 |         "\n",
51 |         "%pip install -qr requirements.txt\n",
52 |         "\n",
53 |         "import torch\n",
54 |         "from IPython.display import clear_output\n",
55 |         "\n",
56 |         "clear_output()\n",
57 |         "print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))"
58 |       ],
59 |       "execution_count": null,
60 |       "outputs": []
61 |     },
62 |     {
63 |       "cell_type": "markdown",
64 |       "metadata": {
65 |         "id": "Ijy9oBmN6ovz"
66 |       },
67 |       "source": [
68 |         "The following command lists all the available models."
69 |       ]
70 |     },
71 |     {
72 |       "cell_type": "code",
73 |       "metadata": {
74 |         "colab": {
75 |           "base_uri": "https://localhost:8080/",
76 |           "height": 69
77 |         },
78 |         "id": "y-AG3Ne76dOv",
79 |         "outputId": "5607e9cb-31aa-48d6-a465-65662eebb653"
80 |       },
81 |       "source": [
82 |         "import torch\n",
83 |         "print(torch.hub.list(\"chenyaofo/pytorch-cifar-models\", force_reload=True))"
84 |       ],
85 |       "execution_count": null,
86 |       "outputs": [
87 |         {
88 |           "output_type": "display_data",
89 |           "data": {
90 |             "text/html": [
91 |               "\n",
92 |               "  \n",
97 |               "  "
98 |             ],
99 |             "text/plain": [
100 |               ""
101 |             ]
102 |           },
103 |           "metadata": {
104 |             "tags": []
105 |           }
106 |         },
107 |         {
108 |           "output_type": "stream",
109 |           "text": [
110 |             "Downloading: \"https://github.com/chenyaofo/pytorch-cifar-models/archive/master.zip\" to /root/.cache/torch/hub/master.zip\n"
111 |           ],
112 |           "name": "stderr"
113 |         },
114 |         {
115 |           "output_type": "stream",
116 |           "text": [
117 |             "['cifar100_resnet20', 'cifar100_resnet32', 'cifar100_resnet44', 'cifar100_resnet56', 'cifar100_vgg11_bn', 'cifar100_vgg13_bn', 'cifar100_vgg16_bn', 'cifar100_vgg19_bn', 'cifar10_resnet20', 'cifar10_resnet32', 'cifar10_resnet44', 'cifar10_resnet56', 'cifar10_vgg11_bn', 'cifar10_vgg13_bn', 'cifar10_vgg16_bn', 'cifar10_vgg19_bn']\n"
118 |           ],
119 |           "name": "stdout"
120 |         }
121 |       ]
122 |     },
123 |     {
124 |       "cell_type": "markdown",
125 |       "metadata": {
126 |         "id": "4jofqHNwAgvS"
127 |       },
128 |       "source": [
129 |         "To train/evaluate different models, please set `model.name` to one of the available model names listed in the previous block."
130 |       ]
131 |     },
132 |     {
133 |       "cell_type": "markdown",
134 |       "metadata": {
135 |         "id": "YA5-PFL8BPUq"
136 |       },
137 |       "source": [
138 |         "**Train on CIFAR-10:**"
139 |       ]
140 |     },
141 |     {
142 |       "cell_type": "code",
143 |       "metadata": {
144 |         "id": "BM_9IlZ72D9l"
145 |       },
146 |       "source": [
147 |         "!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20"
148 |       ],
149 |       "execution_count": null,
150 |       "outputs": []
151 |     },
152 |     {
153 |       "cell_type": "markdown",
154 |       "metadata": {
155 |         "id": "QwrIVjWQNV82"
156 |       },
157 |       "source": [
158 |         "**Evaluate on CIFAR-10:**"
159 |       ]
160 |     },
161 |     {
162 |       "cell_type": "code",
163 |       "metadata": {
164 |         "id": "DMd_USWH81Ov"
165 |       },
166 |       "source": [
167 |         "!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20 model.pretrained=true only_evaluate=true"
168 |       ],
169 |       "execution_count": null,
170 |       "outputs": []
171 |     },
172 |     {
173 |       "cell_type": "markdown",
174 |       "metadata": {
175 |         "id": "aKYW0yNINoSk"
176 |       },
177 |       "source": [
178 |         "**Train on CIFAR-100:**"
179 |       ]
180 |     },
181 |     {
182 |       "cell_type": "code",
183 |       "metadata": {
184 |         "id": "m2bVfrEjNoHY"
185 |       },
186 |       "source": [
187 |         "!python -m entry.run --conf conf/cifar100.conf -o output/cifar100/resnet20 -M model.name=cifar100_resnet20"
188 |       ],
189 |       "execution_count": null,
190 |       "outputs": []
191 |     },
192 |     {
193 |       "cell_type": "markdown",
194 |       "metadata": {
195 |         "id": "PgcF8E8NNxgp"
196 |       },
197 |       "source": [
198 |         "**Evaluate on CIFAR-100:**"
199 |       ]
200 |     },
201 |     {
202 |       "cell_type": "code",
203 |       "metadata": {
204 |         "id": "5k-sCfhqNxDc"
205 |       },
206 |       "source": [
207 |         "!python -m entry.run --conf conf/cifar100.conf -o output/cifar100/resnet20 -M model.name=cifar100_resnet20 model.pretrained=true only_evaluate=true"
208 |       ],
209 |       "execution_count": null,
210 |       "outputs": []
211 |     }
212 |   ]
213 | }
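`hubconf.py` below exposes every `cifar*` factory function in `pytorch_cifar_models` as a `torch.hub` entry point. As a complement to the notebook above, here is a minimal evaluation sketch against the CIFAR-10 test set; it assumes `torchvision` is installed and uses commonly quoted CIFAR-10 normalization statistics, which may differ slightly from the values in this codebase's training configs:

```python
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# Assumed normalization constants; not taken from this repository's configs.
normalize = T.Normalize(mean=(0.4914, 0.4822, 0.4465),
                        std=(0.2470, 0.2435, 0.2616))

model = torch.hub.load("chenyaofo/pytorch-cifar-models",
                       "cifar10_resnet20", pretrained=True)
model.eval()

testset = CIFAR10(root="data", train=False, download=True,
                  transform=T.Compose([T.ToTensor(), normalize]))
loader = DataLoader(testset, batch_size=256)

# Plain top-1 accuracy over the 10k test images; the Model Zoo above
# reports 92.60% top-1 for cifar10_resnet20.
correct = total = 0
with torch.no_grad():
    for images, labels in loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"top-1 accuracy: {correct / total:.2%}")
```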
-------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import pytorch_cifar_models 2 | 3 | dependencies = ['torch'] 4 | 5 | models = filter(lambda name: name.startswith("cifar"), dir(pytorch_cifar_models)) 6 | globals().update({model: getattr(pytorch_cifar_models, model) for model in models}) 7 | -------------------------------------------------------------------------------- /pytorch_cifar_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import cifar10_resnet20 2 | from .resnet import cifar10_resnet32 3 | from .resnet import cifar10_resnet44 4 | from .resnet import cifar10_resnet56 5 | 6 | from .resnet import cifar100_resnet20 7 | from .resnet import cifar100_resnet32 8 | from .resnet import cifar100_resnet44 9 | from .resnet import cifar100_resnet56 10 | 11 | from .vgg import cifar10_vgg11_bn 12 | from .vgg import cifar10_vgg13_bn 13 | from .vgg import cifar10_vgg16_bn 14 | from .vgg import cifar10_vgg19_bn 15 | 16 | from .vgg import cifar100_vgg11_bn 17 | from .vgg import cifar100_vgg13_bn 18 | from .vgg import cifar100_vgg16_bn 19 | from .vgg import cifar100_vgg19_bn 20 | 21 | from .mobilenetv2 import cifar10_mobilenetv2_x0_5 22 | from .mobilenetv2 import cifar10_mobilenetv2_x0_75 23 | from .mobilenetv2 import cifar10_mobilenetv2_x1_0 24 | from .mobilenetv2 import cifar10_mobilenetv2_x1_4 25 | 26 | from .mobilenetv2 import cifar100_mobilenetv2_x0_5 27 | from .mobilenetv2 import cifar100_mobilenetv2_x0_75 28 | from .mobilenetv2 import cifar100_mobilenetv2_x1_0 29 | from .mobilenetv2 import cifar100_mobilenetv2_x1_4 30 | 31 | from .shufflenetv2 import cifar10_shufflenetv2_x0_5 32 | from .shufflenetv2 import cifar10_shufflenetv2_x1_0 33 | from .shufflenetv2 import cifar10_shufflenetv2_x1_5 34 | from .shufflenetv2 import cifar10_shufflenetv2_x2_0 35 | 36 | from .shufflenetv2 import cifar100_shufflenetv2_x0_5 37 | from .shufflenetv2 import cifar100_shufflenetv2_x1_0 38 | from .shufflenetv2 import cifar100_shufflenetv2_x1_5 39 | from .shufflenetv2 import cifar100_shufflenetv2_x2_0 40 | 41 | from .repvgg import cifar10_repvgg_a0 42 | from .repvgg import cifar10_repvgg_a1 43 | from .repvgg import cifar10_repvgg_a2 44 | 45 | from .repvgg import cifar100_repvgg_a0 46 | from .repvgg import cifar100_repvgg_a1 47 | from .repvgg import cifar100_repvgg_a2 48 | 49 | from .vit import cifar10_vit_b16 50 | from .vit import cifar10_vit_b32 51 | from .vit import cifar10_vit_l16 52 | from .vit import cifar10_vit_l32 53 | from .vit import cifar10_vit_h14 54 | 55 | from .vit import cifar100_vit_b16 56 | from .vit import cifar100_vit_b32 57 | from .vit import cifar100_vit_l16 58 | from .vit import cifar100_vit_l32 59 | from .vit import cifar100_vit_h14 60 | 61 | __version__ = "0.1.0-alpha" 62 | -------------------------------------------------------------------------------- /pytorch_cifar_models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/mobilenetv2.py 3 | 4 | BSD 3-Clause License 5 | 6 | Copyright (c) Soumith Chintala 2016, 7 | All rights reserved. 
8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | ''' 34 | import sys 35 | import torch 36 | from torch import nn 37 | from torch import Tensor 38 | try: 39 | from torch.hub import load_state_dict_from_url 40 | except ImportError: 41 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 42 | 43 | from functools import partial 44 | from typing import Dict, Type, Any, Callable, Union, List, Optional 45 | 46 | 47 | cifar10_pretrained_weight_urls = { 48 | 'mobilenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_5-ca14ced9.pt', 49 | 'mobilenetv2_x0_75': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_75-a53c314e.pt', 50 | 'mobilenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_0-fe6a5b48.pt', 51 | 'mobilenetv2_x1_4': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_4-3bbbd6e2.pt', 52 | } 53 | 54 | cifar100_pretrained_weight_urls = { 55 | 'mobilenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_5-9f915757.pt', 56 | 'mobilenetv2_x0_75': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_75-d7891e60.pt', 57 | 'mobilenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_0-1311f9ff.pt', 58 | 'mobilenetv2_x1_4': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_4-8a269f5e.pt', 59 | } 60 | 61 | 62 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 63 | """ 64 | This function is taken from the original tf repo. 
65 | It ensures that all layers have a channel number that is divisible by 8 66 | It can be seen here: 67 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 68 | """ 69 | if min_value is None: 70 | min_value = divisor 71 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 72 | # Make sure that round down does not go down by more than 10%. 73 | if new_v < 0.9 * v: 74 | new_v += divisor 75 | return new_v 76 | 77 | 78 | class ConvBNActivation(nn.Sequential): 79 | def __init__( 80 | self, 81 | in_planes: int, 82 | out_planes: int, 83 | kernel_size: int = 3, 84 | stride: int = 1, 85 | groups: int = 1, 86 | norm_layer: Optional[Callable[..., nn.Module]] = None, 87 | activation_layer: Optional[Callable[..., nn.Module]] = None, 88 | dilation: int = 1, 89 | ) -> None: 90 | padding = (kernel_size - 1) // 2 * dilation 91 | if norm_layer is None: 92 | norm_layer = nn.BatchNorm2d 93 | if activation_layer is None: 94 | activation_layer = nn.ReLU6 95 | super(ConvBNReLU, self).__init__( 96 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups, 97 | bias=False), 98 | norm_layer(out_planes), 99 | activation_layer(inplace=True) 100 | ) 101 | self.out_channels = out_planes 102 | 103 | 104 | # necessary for backwards compatibility 105 | ConvBNReLU = ConvBNActivation 106 | 107 | 108 | class InvertedResidual(nn.Module): 109 | def __init__( 110 | self, 111 | inp: int, 112 | oup: int, 113 | stride: int, 114 | expand_ratio: int, 115 | norm_layer: Optional[Callable[..., nn.Module]] = None 116 | ) -> None: 117 | super(InvertedResidual, self).__init__() 118 | self.stride = stride 119 | assert stride in [1, 2] 120 | 121 | if norm_layer is None: 122 | norm_layer = nn.BatchNorm2d 123 | 124 | hidden_dim = int(round(inp * expand_ratio)) 125 | self.use_res_connect = self.stride == 1 and inp == oup 126 | 127 | layers: List[nn.Module] = [] 128 | if expand_ratio != 1: 129 | # pw 130 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 131 | layers.extend([ 132 | # dw 133 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 134 | # pw-linear 135 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 136 | norm_layer(oup), 137 | ]) 138 | self.conv = nn.Sequential(*layers) 139 | self.out_channels = oup 140 | self._is_cn = stride > 1 141 | 142 | def forward(self, x: Tensor) -> Tensor: 143 | if self.use_res_connect: 144 | return x + self.conv(x) 145 | else: 146 | return self.conv(x) 147 | 148 | 149 | class MobileNetV2(nn.Module): 150 | def __init__( 151 | self, 152 | num_classes: int = 10, 153 | width_mult: float = 1.0, 154 | inverted_residual_setting: Optional[List[List[int]]] = None, 155 | round_nearest: int = 8, 156 | block: Optional[Callable[..., nn.Module]] = None, 157 | norm_layer: Optional[Callable[..., nn.Module]] = None 158 | ) -> None: 159 | """ 160 | MobileNet V2 main class 161 | 162 | Args: 163 | num_classes (int): Number of classes 164 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 165 | inverted_residual_setting: Network structure 166 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 167 | Set to 1 to turn off rounding 168 | block: Module specifying inverted residual building block for mobilenet 169 | norm_layer: Module specifying the normalization layer to use 170 | 171 | """ 172 | super(MobileNetV2, self).__init__() 173 | 174 | if block is None: 175 | 
block = InvertedResidual
176 |
177 |         if norm_layer is None:
178 |             norm_layer = nn.BatchNorm2d
179 |
180 |         input_channel = 32
181 |         last_channel = 1280
182 |
183 |         if inverted_residual_setting is None:
184 |             inverted_residual_setting = [
185 |                 # t, c, n, s
186 |                 [1, 16, 1, 1],
187 |                 [6, 24, 2, 1],  # NOTE: change stride 2 -> 1 for CIFAR10/100
188 |                 [6, 32, 3, 2],
189 |                 [6, 64, 4, 2],
190 |                 [6, 96, 3, 1],
191 |                 [6, 160, 3, 2],
192 |                 [6, 320, 1, 1],
193 |             ]
194 |
195 |         # only check the first element, assuming user knows t,c,n,s are required
196 |         if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
197 |             raise ValueError("inverted_residual_setting should be a non-empty "
198 |                              "list of 4-element lists, got {}".format(inverted_residual_setting))
199 |
200 |         # building first layer
201 |         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
202 |         self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
203 |         features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=1, norm_layer=norm_layer)]  # NOTE: change stride 2 -> 1 for CIFAR10/100
204 |         # building inverted residual blocks
205 |         for t, c, n, s in inverted_residual_setting:
206 |             output_channel = _make_divisible(c * width_mult, round_nearest)
207 |             for i in range(n):
208 |                 stride = s if i == 0 else 1
209 |                 features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
210 |                 input_channel = output_channel
211 |         # building last several layers
212 |         features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
213 |         # make it nn.Sequential
214 |         self.features = nn.Sequential(*features)
215 |
216 |         # building classifier
217 |         self.classifier = nn.Sequential(
218 |             nn.Dropout(0.2),
219 |             nn.Linear(self.last_channel, num_classes),
220 |         )
221 |
222 |         # weight initialization
223 |         for m in self.modules():
224 |             if isinstance(m, nn.Conv2d):
225 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
226 |                 if m.bias is not None:
227 |                     nn.init.zeros_(m.bias)
228 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
229 |                 nn.init.ones_(m.weight)
230 |                 nn.init.zeros_(m.bias)
231 |             elif isinstance(m, nn.Linear):
232 |                 nn.init.normal_(m.weight, 0, 0.01)
233 |                 nn.init.zeros_(m.bias)
234 |
235 |     def _forward_impl(self, x: Tensor) -> Tensor:
236 |         # This exists since TorchScript doesn't support inheritance, so the superclass method
237 |         # (this one) needs to have a name other than `forward` that can be accessed in a subclass
238 |         x = self.features(x)
239 |         # Cannot use "squeeze" as batch-size can be 1
240 |         x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
241 |         x = torch.flatten(x, 1)
242 |         x = self.classifier(x)
243 |         return x
244 |
245 |     def forward(self, x: Tensor) -> Tensor:
246 |         return self._forward_impl(x)
247 |
248 |
249 | def _mobilenet_v2(
250 |     arch: str,
251 |     width_mult: float,
252 |     model_urls: Dict[str, str],
253 |     progress: bool = True,
254 |     pretrained: bool = False,
255 |     **kwargs: Any
256 | ) -> MobileNetV2:
257 |     model = MobileNetV2(width_mult=width_mult, **kwargs)
258 |     if pretrained:
259 |         state_dict = load_state_dict_from_url(model_urls[arch],
260 |                                               progress=progress)
261 |         model.load_state_dict(state_dict)
262 |     return model
263 |
264 |
265 | def cifar10_mobilenetv2_x0_5(*args, **kwargs) -> MobileNetV2: pass
266 | def cifar10_mobilenetv2_x0_75(*args, **kwargs) -> MobileNetV2: pass
267 | def cifar10_mobilenetv2_x1_0(*args, **kwargs) -> MobileNetV2: pass
268 | def
cifar10_mobilenetv2_x1_4(*args, **kwargs) -> MobileNetV2: pass 269 | 270 | 271 | def cifar100_mobilenetv2_x0_5(*args, **kwargs) -> MobileNetV2: pass 272 | def cifar100_mobilenetv2_x0_75(*args, **kwargs) -> MobileNetV2: pass 273 | def cifar100_mobilenetv2_x1_0(*args, **kwargs) -> MobileNetV2: pass 274 | def cifar100_mobilenetv2_x1_4(*args, **kwargs) -> MobileNetV2: pass 275 | 276 | 277 | thismodule = sys.modules[__name__] 278 | for dataset in ["cifar10", "cifar100"]: 279 | for width_mult, model_name in zip([0.5, 0.75, 1.0, 1.4], 280 | ["mobilenetv2_x0_5", "mobilenetv2_x0_75", "mobilenetv2_x1_0", "mobilenetv2_x1_4"]): 281 | method_name = f"{dataset}_{model_name}" 282 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 283 | num_classes = 10 if dataset == "cifar10" else 100 284 | setattr( 285 | thismodule, 286 | method_name, 287 | partial(_mobilenet_v2, 288 | arch=model_name, 289 | width_mult=width_mult, 290 | model_urls=model_urls, 291 | num_classes=num_classes) 292 | ) 293 | -------------------------------------------------------------------------------- /pytorch_cifar_models/repvgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://raw.githubusercontent.com/DingXiaoH/RepVGG/main/repvgg.py 3 | 4 | MIT License 5 | 6 | Copyright (c) 2020 DingXiaoH 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 
25 | ''' 26 | import sys 27 | import copy 28 | import torch 29 | 30 | import torch.nn as nn 31 | import numpy as np 32 | 33 | try: 34 | from torch.hub import load_state_dict_from_url 35 | except ImportError: 36 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 37 | from functools import partial 38 | from typing import Union, List, Dict, Any, cast 39 | 40 | cifar10_pretrained_weight_urls = { 41 | 'repvgg_a0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a0-ef08a50e.pt', 42 | 'repvgg_a1': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a1-38d2431b.pt', 43 | 'repvgg_a2': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a2-09488915.pt', 44 | } 45 | 46 | cifar100_pretrained_weight_urls = { 47 | 'repvgg_a0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a0-2df1edd0.pt', 48 | 'repvgg_a1': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a1-c06b21a7.pt', 49 | 'repvgg_a2': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a2-8e71b1f8.pt', 50 | } 51 | 52 | 53 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): 54 | result = nn.Sequential() 55 | result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 56 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False)) 57 | result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) 58 | return result 59 | 60 | 61 | class RepVGGBlock(nn.Module): 62 | 63 | def __init__(self, in_channels, out_channels, kernel_size, 64 | stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False): 65 | super(RepVGGBlock, self).__init__() 66 | self.deploy = deploy 67 | self.groups = groups 68 | self.in_channels = in_channels 69 | 70 | assert kernel_size == 3 71 | assert padding == 1 72 | 73 | padding_11 = padding - kernel_size // 2 74 | 75 | self.nonlinearity = nn.ReLU() 76 | 77 | if deploy: 78 | self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 79 | padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode) 80 | 81 | else: 82 | self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None 83 | self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups) 84 | self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups) 85 | # print('RepVGG Block, identity = ', self.rbr_identity) 86 | 87 | def forward(self, inputs): 88 | if hasattr(self, 'rbr_reparam'): 89 | return self.nonlinearity(self.rbr_reparam(inputs)) 90 | 91 | if self.rbr_identity is None: 92 | id_out = 0 93 | else: 94 | id_out = self.rbr_identity(inputs) 95 | 96 | return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out) 97 | 98 | 99 | # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way. 100 | # You can get the equivalent kernel and bias at any time and do whatever you want, 101 | # for example, apply some penalties or constraints during training, just like you do to the other models. 102 | # May be useful for quantization or pruning. 
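    # In symbols (a summarizing note added for clarity, not from the upstream
    # source): for a conv kernel W followed by BN with running statistics
    # (mu, var) and affine parameters (gamma, beta), _fuse_bn_tensor returns
    #     W' = W * gamma / sqrt(var + eps),  b' = beta - mu * gamma / sqrt(var + eps).
    # The 1x1 branch is then zero-padded to 3x3 and the identity branch is
    # rewritten as a 3x3 conv whose center tap is 1, so the three branches
    # fold into a single 3x3 conv by summing kernels and biases.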
103 | 104 | 105 | def get_equivalent_kernel_bias(self): 106 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) 107 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) 108 | kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) 109 | return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid 110 | 111 | def _pad_1x1_to_3x3_tensor(self, kernel1x1): 112 | if kernel1x1 is None: 113 | return 0 114 | else: 115 | return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) 116 | 117 | def _fuse_bn_tensor(self, branch): 118 | if branch is None: 119 | return 0, 0 120 | if isinstance(branch, nn.Sequential): 121 | kernel = branch.conv.weight 122 | running_mean = branch.bn.running_mean 123 | running_var = branch.bn.running_var 124 | gamma = branch.bn.weight 125 | beta = branch.bn.bias 126 | eps = branch.bn.eps 127 | else: 128 | assert isinstance(branch, nn.BatchNorm2d) 129 | if not hasattr(self, 'id_tensor'): 130 | input_dim = self.in_channels // self.groups 131 | kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) 132 | for i in range(self.in_channels): 133 | kernel_value[i, i % input_dim, 1, 1] = 1 134 | self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) 135 | kernel = self.id_tensor 136 | running_mean = branch.running_mean 137 | running_var = branch.running_var 138 | gamma = branch.weight 139 | beta = branch.bias 140 | eps = branch.eps 141 | std = (running_var + eps).sqrt() 142 | t = (gamma / std).reshape(-1, 1, 1, 1) 143 | return kernel * t, beta - running_mean * gamma / std 144 | 145 | def switch_to_deploy(self): 146 | if hasattr(self, 'rbr_reparam'): 147 | return 148 | kernel, bias = self.get_equivalent_kernel_bias() 149 | self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels, 150 | kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride, 151 | padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True) 152 | self.rbr_reparam.weight.data = kernel 153 | self.rbr_reparam.bias.data = bias 154 | for para in self.parameters(): 155 | para.detach_() 156 | self.__delattr__('rbr_dense') 157 | self.__delattr__('rbr_1x1') 158 | if hasattr(self, 'rbr_identity'): 159 | self.__delattr__('rbr_identity') 160 | 161 | 162 | class RepVGG(nn.Module): 163 | 164 | def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False): 165 | super(RepVGG, self).__init__() 166 | 167 | assert len(width_multiplier) == 4 168 | 169 | self.deploy = deploy 170 | self.override_groups_map = override_groups_map or dict() 171 | 172 | assert 0 not in self.override_groups_map 173 | 174 | self.in_planes = min(64, int(64 * width_multiplier[0])) 175 | 176 | self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, 177 | kernel_size=3, stride=1, padding=1, deploy=self.deploy) # NOTE: change stride 2 -> 1 for CIFAR10/100 178 | self.cur_layer_idx = 1 179 | self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1) # NOTE: change stride 2 -> 1 for CIFAR10/100 180 | self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2) 181 | self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2) 182 | self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2) 183 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 184 | self.linear = 
nn.Linear(int(512 * width_multiplier[3]), num_classes)
185 |
186 |     def _make_stage(self, planes, num_blocks, stride):
187 |         strides = [stride] + [1]*(num_blocks-1)
188 |         blocks = []
189 |         for stride in strides:
190 |             cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
191 |             blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
192 |                                       stride=stride, padding=1, groups=cur_groups, deploy=self.deploy))
193 |             self.in_planes = planes
194 |             self.cur_layer_idx += 1
195 |         return nn.Sequential(*blocks)
196 |
197 |     def forward(self, x):
198 |         out = self.stage0(x)
199 |         out = self.stage1(out)
200 |         out = self.stage2(out)
201 |         out = self.stage3(out)
202 |         out = self.stage4(out)
203 |         out = self.gap(out)
204 |         out = out.view(out.size(0), -1)
205 |         out = self.linear(out)
206 |         return out
207 |
208 |     def convert_to_inference_model(self, do_copy=False):
209 |         model = copy.deepcopy(self) if do_copy else self
210 |         for module in model.modules():
211 |             if hasattr(module, 'switch_to_deploy'):
212 |                 module.switch_to_deploy()
213 |         return model  # return the converted model so the do_copy=True result is not discarded
214 |
215 | optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
216 | g2_map = {l: 2 for l in optional_groupwise_layers}
217 | g4_map = {l: 4 for l in optional_groupwise_layers}
218 |
219 |
220 | def _repvgg(arch: str, num_blocks: List[int], width_multiplier: List[float],
221 |             model_urls: Dict[str, str],
222 |             pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
223 |     model = RepVGG(num_blocks=num_blocks, width_multiplier=width_multiplier,
224 |                    override_groups_map=None, **kwargs)
225 |     if pretrained:
226 |         state_dict = load_state_dict_from_url(model_urls[arch],
227 |                                               progress=progress)
228 |         model.load_state_dict(state_dict)
229 |     return model
230 |
231 |
232 | def cifar10_repvgg_a0(*args, **kwargs) -> RepVGG: pass
233 | def cifar10_repvgg_a1(*args, **kwargs) -> RepVGG: pass
234 | def cifar10_repvgg_a2(*args, **kwargs) -> RepVGG: pass
235 |
236 |
237 | def cifar100_repvgg_a0(*args, **kwargs) -> RepVGG: pass
238 | def cifar100_repvgg_a1(*args, **kwargs) -> RepVGG: pass
239 | def cifar100_repvgg_a2(*args, **kwargs) -> RepVGG: pass
240 |
241 |
242 | thismodule = sys.modules[__name__]
243 | for dataset in ["cifar10", "cifar100"]:
244 |     for num_blocks, width_multiplier, model_name in\
245 |             zip([[2, 4, 14, 1], [2, 4, 14, 1], [2, 4, 14, 1]],
246 |                 [[0.75, 0.75, 0.75, 2.5], [1, 1, 1, 2.5], [1.5, 1.5, 1.5, 2.75]],
247 |                 ["repvgg_a0", "repvgg_a1", "repvgg_a2"]):
248 |         method_name = f"{dataset}_{model_name}"
249 |         model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
250 |         num_classes = 10 if dataset == "cifar10" else 100
251 |         setattr(
252 |             thismodule,
253 |             method_name,
254 |             partial(_repvgg,
255 |                     arch=model_name,
256 |                     num_blocks=num_blocks,
257 |                     width_multiplier=width_multiplier,
258 |                     model_urls=model_urls,
259 |                     num_classes=num_classes)
260 |         )
261 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/resnet.py
3 |
4 | BSD 3-Clause License
5 |
6 | Copyright (c) Soumith Chintala 2016,
7 | All rights reserved.
8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | ''' 34 | import sys 35 | import torch.nn as nn 36 | try: 37 | from torch.hub import load_state_dict_from_url 38 | except ImportError: 39 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 40 | 41 | from functools import partial 42 | from typing import Dict, Type, Any, Callable, Union, List, Optional 43 | 44 | 45 | cifar10_pretrained_weight_urls = { 46 | 'resnet20': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet20-4118986f.pt', 47 | 'resnet32': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet32-ef93fc4d.pt', 48 | 'resnet44': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet44-2a3cabcb.pt', 49 | 'resnet56': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet56-187c023a.pt', 50 | } 51 | 52 | cifar100_pretrained_weight_urls = { 53 | 'resnet20': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet20-23dac2f1.pt', 54 | 'resnet32': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet32-84213ce6.pt', 55 | 'resnet44': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet44-ffe32858.pt', 56 | 'resnet56': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet56-f2eff4c8.pt', 57 | } 58 | 59 | 60 | def conv3x3(in_planes, out_planes, stride=1): 61 | """3x3 convolution with padding""" 62 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 63 | 64 | 65 | def conv1x1(in_planes, out_planes, stride=1): 66 | """1x1 convolution""" 67 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 68 | 69 | 70 | class BasicBlock(nn.Module): 71 | expansion = 1 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None): 74 | super(BasicBlock, self).__init__() 75 | self.conv1 = 
conv3x3(inplanes, planes, stride) 76 | self.bn1 = nn.BatchNorm2d(planes) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.conv2 = conv3x3(planes, planes) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | identity = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | 93 | if self.downsample is not None: 94 | identity = self.downsample(x) 95 | 96 | out += identity 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class CifarResNet(nn.Module): 103 | 104 | def __init__(self, block, layers, num_classes=10): 105 | super(CifarResNet, self).__init__() 106 | self.inplanes = 16 107 | self.conv1 = conv3x3(3, 16) 108 | self.bn1 = nn.BatchNorm2d(16) 109 | self.relu = nn.ReLU(inplace=True) 110 | 111 | self.layer1 = self._make_layer(block, 16, layers[0]) 112 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 114 | 115 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 116 | self.fc = nn.Linear(64 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1) 123 | nn.init.constant_(m.bias, 0) 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | conv1x1(self.inplanes, planes * block.expansion, stride), 130 | nn.BatchNorm2d(planes * block.expansion), 131 | ) 132 | 133 | layers = [] 134 | layers.append(block(self.inplanes, planes, stride, downsample)) 135 | self.inplanes = planes * block.expansion 136 | for _ in range(1, blocks): 137 | layers.append(block(self.inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) 145 | 146 | x = self.layer1(x) 147 | x = self.layer2(x) 148 | x = self.layer3(x) 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def _resnet( 158 | arch: str, 159 | layers: List[int], 160 | model_urls: Dict[str, str], 161 | progress: bool = True, 162 | pretrained: bool = False, 163 | **kwargs: Any 164 | ) -> CifarResNet: 165 | model = CifarResNet(BasicBlock, layers, **kwargs) 166 | if pretrained: 167 | state_dict = load_state_dict_from_url(model_urls[arch], 168 | progress=progress) 169 | model.load_state_dict(state_dict) 170 | return model 171 | 172 | 173 | def cifar10_resnet20(*args, **kwargs) -> CifarResNet: pass 174 | def cifar10_resnet32(*args, **kwargs) -> CifarResNet: pass 175 | def cifar10_resnet44(*args, **kwargs) -> CifarResNet: pass 176 | def cifar10_resnet56(*args, **kwargs) -> CifarResNet: pass 177 | 178 | 179 | def cifar100_resnet20(*args, **kwargs) -> CifarResNet: pass 180 | def cifar100_resnet32(*args, **kwargs) -> CifarResNet: pass 181 | def cifar100_resnet44(*args, **kwargs) -> CifarResNet: pass 182 | def cifar100_resnet56(*args, **kwargs) -> CifarResNet: pass 183 | 184 | 185 | thismodule = sys.modules[__name__] 186 | for dataset in ["cifar10", "cifar100"]: 187 | for layers, model_name in zip([[3]*3, [5]*3, [7]*3, [9]*3], 188 | ["resnet20", "resnet32", "resnet44", "resnet56"]): 189 | method_name 
= f"{dataset}_{model_name}" 190 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 191 | num_classes = 10 if dataset == "cifar10" else 100 192 | setattr( 193 | thismodule, 194 | method_name, 195 | partial(_resnet, 196 | arch=model_name, 197 | layers=layers, 198 | model_urls=model_urls, 199 | num_classes=num_classes) 200 | ) 201 | -------------------------------------------------------------------------------- /pytorch_cifar_models/shufflenetv2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/shufflenetv2.py 3 | 4 | BSD 3-Clause License 5 | 6 | Copyright (c) Soumith Chintala 2016, 7 | All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
33 | ''' 34 | import sys 35 | import torch 36 | import torch.nn as nn 37 | from torch import Tensor 38 | 39 | try: 40 | from torch.hub import load_state_dict_from_url 41 | except ImportError: 42 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 43 | 44 | from functools import partial 45 | from typing import Dict, Type, Any, Callable, Union, List, Optional 46 | 47 | 48 | cifar10_pretrained_weight_urls = { 49 | 'shufflenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x0_5-1308b4e9.pt', 50 | 'shufflenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_0-98807be3.pt', 51 | 'shufflenetv2_x1_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_5-296694dd.pt', 52 | 'shufflenetv2_x2_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x2_0-ec31611c.pt', 53 | } 54 | 55 | cifar100_pretrained_weight_urls = { 56 | 'shufflenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x0_5-1977720f.pt', 57 | 'shufflenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_0-9ae22beb.pt', 58 | 'shufflenetv2_x1_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_5-e2c85ad8.pt', 59 | 'shufflenetv2_x2_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x2_0-e7e584cd.pt', 60 | } 61 | 62 | 63 | def channel_shuffle(x: Tensor, groups: int) -> Tensor: 64 | batchsize, num_channels, height, width = x.size() 65 | channels_per_group = num_channels // groups 66 | 67 | # reshape 68 | x = x.view(batchsize, groups, 69 | channels_per_group, height, width) 70 | 71 | x = torch.transpose(x, 1, 2).contiguous() 72 | 73 | # flatten 74 | x = x.view(batchsize, -1, height, width) 75 | 76 | return x 77 | 78 | 79 | class InvertedResidual(nn.Module): 80 | def __init__( 81 | self, 82 | inp: int, 83 | oup: int, 84 | stride: int 85 | ) -> None: 86 | super(InvertedResidual, self).__init__() 87 | 88 | if not (1 <= stride <= 3): 89 | raise ValueError('illegal stride value') 90 | self.stride = stride 91 | 92 | branch_features = oup // 2 93 | assert (self.stride != 1) or (inp == branch_features << 1) 94 | 95 | if self.stride > 1: 96 | self.branch1 = nn.Sequential( 97 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 98 | nn.BatchNorm2d(inp), 99 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 100 | nn.BatchNorm2d(branch_features), 101 | nn.ReLU(inplace=True), 102 | ) 103 | else: 104 | self.branch1 = nn.Sequential() 105 | 106 | self.branch2 = nn.Sequential( 107 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 108 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 109 | nn.BatchNorm2d(branch_features), 110 | nn.ReLU(inplace=True), 111 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 112 | nn.BatchNorm2d(branch_features), 113 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 114 | nn.BatchNorm2d(branch_features), 115 | nn.ReLU(inplace=True), 116 | ) 117 | 118 | @staticmethod 119 | def depthwise_conv( 120 | i: int, 121 | o: int, 122 | kernel_size: int, 123 | stride: 
int = 1, 124 | padding: int = 0, 125 | bias: bool = False 126 | ) -> nn.Conv2d: 127 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 128 | 129 | def forward(self, x: Tensor) -> Tensor: 130 | if self.stride == 1: 131 | x1, x2 = x.chunk(2, dim=1) 132 | out = torch.cat((x1, self.branch2(x2)), dim=1) 133 | else: 134 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 135 | 136 | out = channel_shuffle(out, 2) 137 | 138 | return out 139 | 140 | 141 | class ShuffleNetV2(nn.Module): 142 | def __init__( 143 | self, 144 | stages_repeats: List[int], 145 | stages_out_channels: List[int], 146 | num_classes: int = 1000, 147 | inverted_residual: Callable[..., nn.Module] = InvertedResidual 148 | ) -> None: 149 | super(ShuffleNetV2, self).__init__() 150 | 151 | if len(stages_repeats) != 3: 152 | raise ValueError('expected stages_repeats as list of 3 positive ints') 153 | if len(stages_out_channels) != 5: 154 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 155 | self._stage_out_channels = stages_out_channels 156 | 157 | input_channels = 3 158 | output_channels = self._stage_out_channels[0] 159 | self.conv1 = nn.Sequential( 160 | nn.Conv2d(input_channels, output_channels, 3, 1, 1, bias=False), # NOTE: change stride 2 -> 1 for CIFAR10/100 161 | nn.BatchNorm2d(output_channels), 162 | nn.ReLU(inplace=True), 163 | ) 164 | input_channels = output_channels 165 | 166 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) NOTE: remove this maxpool layer for CIFAR10/100 167 | 168 | # Static annotations for mypy 169 | self.stage2: nn.Sequential 170 | self.stage3: nn.Sequential 171 | self.stage4: nn.Sequential 172 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 173 | for name, repeats, output_channels in zip( 174 | stage_names, stages_repeats, self._stage_out_channels[1:]): 175 | seq = [inverted_residual(input_channels, output_channels, 2)] 176 | for i in range(repeats - 1): 177 | seq.append(inverted_residual(output_channels, output_channels, 1)) 178 | setattr(self, name, nn.Sequential(*seq)) 179 | input_channels = output_channels 180 | 181 | output_channels = self._stage_out_channels[-1] 182 | self.conv5 = nn.Sequential( 183 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 184 | nn.BatchNorm2d(output_channels), 185 | nn.ReLU(inplace=True), 186 | ) 187 | 188 | self.fc = nn.Linear(output_channels, num_classes) 189 | 190 | def _forward_impl(self, x: Tensor) -> Tensor: 191 | # See note [TorchScript super()] 192 | x = self.conv1(x) 193 | # x = self.maxpool(x) NOTE: remove this maxpool layer for CIFAR10/100 194 | x = self.stage2(x) 195 | x = self.stage3(x) 196 | x = self.stage4(x) 197 | x = self.conv5(x) 198 | x = x.mean([2, 3]) # globalpool 199 | x = self.fc(x) 200 | return x 201 | 202 | def forward(self, x: Tensor) -> Tensor: 203 | return self._forward_impl(x) 204 | 205 | 206 | def _shufflenet_v2( 207 | arch: str, 208 | stages_repeats: List[int], 209 | stages_out_channels: List[int], 210 | model_urls: Dict[str, str], 211 | progress: bool = True, 212 | pretrained: bool = False, 213 | **kwargs: Any 214 | ) -> ShuffleNetV2: 215 | model = ShuffleNetV2(stages_repeats=stages_repeats, stages_out_channels=stages_out_channels, ** kwargs) 216 | if pretrained: 217 | state_dict = load_state_dict_from_url(model_urls[arch], 218 | progress=progress) 219 | model.load_state_dict(state_dict) 220 | return model 221 | 222 | 223 | def cifar10_shufflenetv2_x0_5(*args, **kwargs) -> ShuffleNetV2: pass 224 | def 
cifar10_shufflenetv2_x1_0(*args, **kwargs) -> ShuffleNetV2: pass 225 | def cifar10_shufflenetv2_x1_5(*args, **kwargs) -> ShuffleNetV2: pass 226 | def cifar10_shufflenetv2_x2_0(*args, **kwargs) -> ShuffleNetV2: pass 227 | 228 | 229 | def cifar100_shufflenetv2_x0_5(*args, **kwargs) -> ShuffleNetV2: pass 230 | def cifar100_shufflenetv2_x1_0(*args, **kwargs) -> ShuffleNetV2: pass 231 | def cifar100_shufflenetv2_x1_5(*args, **kwargs) -> ShuffleNetV2: pass 232 | def cifar100_shufflenetv2_x2_0(*args, **kwargs) -> ShuffleNetV2: pass 233 | 234 | 235 | thismodule = sys.modules[__name__] 236 | for dataset in ["cifar10", "cifar100"]: 237 | for stages_repeats, stages_out_channels, model_name in \ 238 | zip([[4, 8, 4]]*4, 239 | [[24, 48, 96, 192, 1024], [24, 116, 232, 464, 1024], [24, 176, 352, 704, 1024], [24, 244, 488, 976, 2048]], 240 | ["shufflenetv2_x0_5", "shufflenetv2_x1_0", "shufflenetv2_x1_5", "shufflenetv2_x2_0"]): 241 | method_name = f"{dataset}_{model_name}" 242 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 243 | num_classes = 10 if dataset == "cifar10" else 100 244 | setattr( 245 | thismodule, 246 | method_name, 247 | partial(_shufflenet_v2, 248 | arch=model_name, 249 | stages_repeats=stages_repeats, 250 | stages_out_channels=stages_out_channels, 251 | model_urls=model_urls, 252 | num_classes=num_classes) 253 | ) 254 | -------------------------------------------------------------------------------- /pytorch_cifar_models/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/vgg.py 3 | 4 | BSD 3-Clause License 5 | 6 | Copyright (c) Soumith Chintala 2016, 7 | All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | 12 | * Redistributions of source code must retain the above copyright notice, this 13 | list of conditions and the following disclaimer. 14 | 15 | * Redistributions in binary form must reproduce the above copyright notice, 16 | this list of conditions and the following disclaimer in the documentation 17 | and/or other materials provided with the distribution. 18 | 19 | * Neither the name of the copyright holder nor the names of its 20 | contributors may be used to endorse or promote products derived from 21 | this software without specific prior written permission. 22 | 23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
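(Editor's note: this adaptation replaces torchvision's 4096-unit classifier layers with 512-unit ones. For 32x32 CIFAR inputs, the max-pool stages in make_layers leave a 512x1x1 feature map, so the flattened vector entering the classifier has exactly 512 elements.)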
33 | ''' 34 | 35 | import sys 36 | import torch 37 | import torch.nn as nn 38 | try: 39 | from torch.hub import load_state_dict_from_url 40 | except ImportError: 41 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 42 | from functools import partial 43 | from typing import Union, List, Dict, Any, cast 44 | 45 | cifar10_pretrained_weight_urls = { 46 | 'vgg11_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg11_bn-eaeebf42.pt', 47 | 'vgg13_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg13_bn-c01e4a43.pt', 48 | 'vgg16_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg16_bn-6ee7ea24.pt', 49 | 'vgg19_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg19_bn-57191229.pt', 50 | } 51 | 52 | cifar100_pretrained_weight_urls = { 53 | 'vgg11_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg11_bn-57d0759e.pt', 54 | 'vgg13_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg13_bn-5ebe5778.pt', 55 | 'vgg16_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg16_bn-7d8c4031.pt', 56 | 'vgg19_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg19_bn-b98f7bd7.pt', 57 | } 58 | 59 | 60 | class VGG(nn.Module): 61 | 62 | def __init__( 63 | self, 64 | features: nn.Module, 65 | num_classes: int = 10, 66 | init_weights: bool = True 67 | ) -> None: 68 | super(VGG, self).__init__() 69 | self.features = features 70 | self.classifier = nn.Sequential( 71 | nn.Linear(512, 512), 72 | nn.ReLU(True), 73 | nn.Dropout(), 74 | nn.Linear(512, 512), 75 | nn.ReLU(True), 76 | nn.Dropout(), 77 | nn.Linear(512, num_classes), 78 | ) 79 | if init_weights: 80 | self._initialize_weights() 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | x = self.features(x) 84 | x = torch.flatten(x, 1) 85 | x = self.classifier(x) 86 | return x 87 | 88 | def _initialize_weights(self) -> None: 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | if m.bias is not None: 93 | nn.init.constant_(m.bias, 0) 94 | elif isinstance(m, nn.BatchNorm2d): 95 | nn.init.constant_(m.weight, 1) 96 | nn.init.constant_(m.bias, 0) 97 | elif isinstance(m, nn.Linear): 98 | nn.init.normal_(m.weight, 0, 0.01) 99 | nn.init.constant_(m.bias, 0) 100 | 101 | 102 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 103 | layers: List[nn.Module] = [] 104 | in_channels = 3 105 | for v in cfg: 106 | if v == 'M': 107 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 108 | else: 109 | v = cast(int, v) 110 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 111 | if batch_norm: 112 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 113 | else: 114 | layers += [conv2d, nn.ReLU(inplace=True)] 115 | in_channels = v 116 | return nn.Sequential(*layers) 117 | 118 | 119 | cfgs: Dict[str, List[Union[str, int]]] = { 120 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 121 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 122 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 123 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 124 | } 125 | 
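# (Editor's note) Reading the cfgs table above: each integer is the output-channel
# count of a 3x3 conv (conv -> BN -> ReLU when batch_norm=True), and 'M' inserts a
# 2x2 max-pool that halves the spatial size. With 32x32 CIFAR inputs, the five 'M'
# entries reduce the feature map to 512x1x1, which is why the classifier in VGG
# starts from nn.Linear(512, 512). A quick shape sanity check (hypothetical
# snippet, not part of the original file):
#
#     feats = make_layers(cfgs['A'], batch_norm=True)
#     assert feats(torch.empty(1, 3, 32, 32)).shape == (1, 512, 1, 1)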
126 | 127 | def _vgg(arch: str, cfg: str, batch_norm: bool, 128 | model_urls: Dict[str, str], 129 | pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 130 | if pretrained: 131 | kwargs['init_weights'] = False 132 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 133 | if pretrained: 134 | state_dict = load_state_dict_from_url(model_urls[arch], 135 | progress=progress) 136 | model.load_state_dict(state_dict) 137 | return model 138 | 139 | 140 | def cifar10_vgg11_bn(*args, **kwargs) -> VGG: pass 141 | def cifar10_vgg13_bn(*args, **kwargs) -> VGG: pass 142 | def cifar10_vgg16_bn(*args, **kwargs) -> VGG: pass 143 | def cifar10_vgg19_bn(*args, **kwargs) -> VGG: pass 144 | 145 | 146 | def cifar100_vgg11_bn(*args, **kwargs) -> VGG: pass 147 | def cifar100_vgg13_bn(*args, **kwargs) -> VGG: pass 148 | def cifar100_vgg16_bn(*args, **kwargs) -> VGG: pass 149 | def cifar100_vgg19_bn(*args, **kwargs) -> VGG: pass 150 | 151 | 152 | thismodule = sys.modules[__name__] 153 | for dataset in ["cifar10", "cifar100"]: 154 | for cfg, model_name in zip(["A", "B", "D", "E"], ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]): 155 | method_name = f"{dataset}_{model_name}" 156 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 157 | num_classes = 10 if dataset == "cifar10" else 100 158 | setattr( 159 | thismodule, 160 | method_name, 161 | partial(_vgg, 162 | arch=model_name, 163 | cfg=cfg, 164 | batch_norm=True, 165 | model_urls=model_urls, 166 | num_classes=num_classes) 167 | ) 168 | -------------------------------------------------------------------------------- /pytorch_cifar_models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://raw.githubusercontent.com/jeonsworld/ViT-pytorch/main/models/modeling.py 3 | 4 | MIT License 5 | 6 | Copyright (c) 2020 jeonsworld 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 
25 | ''' 26 | 27 | import sys 28 | import copy 29 | import math 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 36 | from torch.nn.modules.utils import _pair 37 | try: 38 | from torch.hub import load_state_dict_from_url 39 | except ImportError: 40 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 41 | 42 | from functools import partial 43 | from typing import Dict, Type, Any, Callable, Union, List, Optional 44 | 45 | cifar10_pretrained_weight_urls = { 46 | 'vit_b16': '', 47 | 'vit_b32': '', 48 | 'vit_l16': '', 49 | 'vit_l32': '', 50 | 'vit_h14': '', 51 | } 52 | 53 | cifar100_pretrained_weight_urls = { 54 | 'vit_b16': '', 55 | 'vit_b32': '', 56 | 'vit_l16': '', 57 | 'vit_l32': '', 58 | 'vit_h14': '', 59 | } 60 | 61 | 62 | ATTENTION_Q = "MultiHeadDotProductAttention_1/query" 63 | ATTENTION_K = "MultiHeadDotProductAttention_1/key" 64 | ATTENTION_V = "MultiHeadDotProductAttention_1/value" 65 | ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" 66 | FC_0 = "MlpBlock_3/Dense_0" 67 | FC_1 = "MlpBlock_3/Dense_1" 68 | ATTENTION_NORM = "LayerNorm_0" 69 | MLP_NORM = "LayerNorm_2" 70 | 71 | 72 | def swish(x): 73 | return x * torch.sigmoid(x) 74 | 75 | 76 | ACT2FN = {"gelu": F.gelu, "relu": F.relu, "swish": swish} 77 | 78 | 79 | class Attention(nn.Module): 80 | def __init__(self, config, vis): 81 | super(Attention, self).__init__() 82 | self.vis = vis 83 | self.num_attention_heads = config.transformer["num_heads"] 84 | self.attention_head_size = int(config.hidden_size / self.num_attention_heads) 85 | self.all_head_size = self.num_attention_heads * self.attention_head_size 86 | 87 | self.query = Linear(config.hidden_size, self.all_head_size) 88 | self.key = Linear(config.hidden_size, self.all_head_size) 89 | self.value = Linear(config.hidden_size, self.all_head_size) 90 | 91 | self.out = Linear(config.hidden_size, config.hidden_size) 92 | self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) 93 | self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) 94 | 95 | self.softmax = Softmax(dim=-1) 96 | 97 | def transpose_for_scores(self, x): 98 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 99 | x = x.view(*new_x_shape) 100 | return x.permute(0, 2, 1, 3) 101 | 102 | def forward(self, hidden_states): 103 | mixed_query_layer = self.query(hidden_states) 104 | mixed_key_layer = self.key(hidden_states) 105 | mixed_value_layer = self.value(hidden_states) 106 | 107 | query_layer = self.transpose_for_scores(mixed_query_layer) 108 | key_layer = self.transpose_for_scores(mixed_key_layer) 109 | value_layer = self.transpose_for_scores(mixed_value_layer) 110 | 111 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 112 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 113 | attention_probs = self.softmax(attention_scores) 114 | weights = attention_probs if self.vis else None 115 | attention_probs = self.attn_dropout(attention_probs) 116 | 117 | context_layer = torch.matmul(attention_probs, value_layer) 118 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 119 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 120 | context_layer = context_layer.view(*new_context_layer_shape) 121 | attention_output = self.out(context_layer) 122 | attention_output = self.proj_dropout(attention_output) 123 | return 
attention_output, weights 124 | 125 | 126 | class MLP(nn.Module): 127 | def __init__(self, config): 128 | super(MLP, self).__init__() 129 | self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"]) 130 | self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) 131 | self.act_fn = ACT2FN["gelu"] 132 | self.dropout = Dropout(config.transformer["dropout_rate"]) 133 | 134 | self._init_weights() 135 | 136 | def _init_weights(self): 137 | nn.init.xavier_uniform_(self.fc1.weight) 138 | nn.init.xavier_uniform_(self.fc2.weight) 139 | nn.init.normal_(self.fc1.bias, std=1e-6) 140 | nn.init.normal_(self.fc2.bias, std=1e-6) 141 | 142 | def forward(self, x): 143 | x = self.fc1(x) 144 | x = self.act_fn(x) 145 | x = self.dropout(x) 146 | x = self.fc2(x) 147 | x = self.dropout(x) 148 | return x 149 | 150 | 151 | class Embeddings(nn.Module): 152 | def __init__(self, config, img_size, in_channels=3): 153 | super(Embeddings, self).__init__() 154 | img_size = _pair(img_size) 155 | 156 | patch_size = _pair(config.patches["size"]) 157 | n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) 158 | self.hybrid = False 159 | 160 | self.patch_embeddings = Conv2d(in_channels=in_channels, 161 | out_channels=config.hidden_size, 162 | kernel_size=patch_size, 163 | stride=patch_size) 164 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size)) 165 | self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) 166 | 167 | self.dropout = Dropout(config.transformer["dropout_rate"]) 168 | 169 | def forward(self, x: torch.Tensor): 170 | B = x.shape[0] 171 | cls_tokens = self.cls_token.expand(B, -1, -1) 172 | 173 | x = self.patch_embeddings(x) 174 | x = x.flatten(2) 175 | x = x.transpose(-1, -2) 176 | x = torch.cat((cls_tokens, x), dim=1) 177 | 178 | embeddings = x + self.position_embeddings 179 | embeddings = self.dropout(embeddings) 180 | return embeddings 181 | 182 | 183 | class Block(nn.Module): 184 | def __init__(self, config, vis): 185 | super(Block, self).__init__() 186 | self.hidden_size = config.hidden_size 187 | self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) 188 | self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) 189 | self.ffn = MLP(config) 190 | self.attn = Attention(config, vis) 191 | 192 | def forward(self, x): 193 | h = x 194 | x = self.attention_norm(x) 195 | x, weights = self.attn(x) 196 | x = x + h 197 | 198 | h = x 199 | x = self.ffn_norm(x) 200 | x = self.ffn(x) 201 | x = x + h 202 | return x, weights 203 | 204 | 205 | class Encoder(nn.Module): 206 | def __init__(self, config, vis): 207 | super(Encoder, self).__init__() 208 | self.vis = vis 209 | self.layer = nn.ModuleList() 210 | self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) 211 | for _ in range(config.transformer["num_layers"]): 212 | layer = Block(config, vis) 213 | self.layer.append(copy.deepcopy(layer)) 214 | 215 | def forward(self, hidden_states): 216 | attn_weights = [] 217 | for layer_block in self.layer: 218 | hidden_states, weights = layer_block(hidden_states) 219 | if self.vis: 220 | attn_weights.append(weights) 221 | encoded = self.encoder_norm(hidden_states) 222 | return encoded, attn_weights 223 | 224 | 225 | class Transformer(nn.Module): 226 | def __init__(self, config, img_size, vis): 227 | super(Transformer, self).__init__() 228 | self.embeddings = Embeddings(config, img_size=img_size) 229 | self.encoder = Encoder(config, vis) 230 | 231 | def forward(self, input_ids): 232 | embedding_output = self.embeddings(input_ids) 
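        # (Editor's note) embedding_output has shape (B, n_patches + 1, hidden_size):
        # Embeddings prepends the learnable [CLS] token to the patch tokens and adds
        # position embeddings before dropout.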
233 | encoded, attn_weights = self.encoder(embedding_output) 234 | return encoded, attn_weights 235 | 236 | 237 | class VisionTransformer(nn.Module): 238 | def __init__(self, config, img_size=224, num_classes=10, zero_head=False, vis=False): 239 | super(VisionTransformer, self).__init__() 240 | self.num_classes = num_classes 241 | self.classifier = config.classifier 242 | 243 | self.transformer = Transformer(config, img_size, vis) 244 | self.head = Linear(config.hidden_size, num_classes) 245 | 246 | if zero_head: 247 | with torch.no_grad(): 248 | nn.init.zeros_(self.head.weight) 249 | nn.init.zeros_(self.head.bias) 250 | 251 | def forward(self, x): 252 | x, attn_weights = self.transformer(x) 253 | logits = self.head(x[:, 0]) 254 | 255 | return logits 256 | 257 | 258 | class TestConfig: 259 | patches = dict(size=(16, 16)) 260 | hidden_size = 1 261 | transformer = dict( 262 | mlp_dim=1, 263 | num_heads=1, 264 | num_layers=1, 265 | attention_dropout_rate=0.0, 266 | dropout_rate=0.1 267 | ) 268 | classifier = 'token' 269 | representation_size = None 270 | 271 | 272 | class VitB16Config: 273 | patches = dict(size=(16, 16)) 274 | hidden_size = 768 275 | transformer = dict( 276 | mlp_dim=3072, 277 | num_heads=12, 278 | num_layers=12, 279 | attention_dropout_rate=0.0, 280 | dropout_rate=0.1 281 | ) 282 | classifier = 'token' 283 | representation_size = None 284 | 285 | 286 | class VitB32Config(VitB16Config): 287 | patches = dict(size=(32, 32)) 288 | 289 | 290 | class VitL16Config: 291 | patches = dict(size=(16, 16)) 292 | hidden_size = 1024 293 | transformer = dict( 294 | mlp_dim=4096, 295 | num_heads=16, 296 | num_layers=24, 297 | attention_dropout_rate=0.0, 298 | dropout_rate=0.1 299 | ) 300 | classifier = 'token' 301 | representation_size = None 302 | 303 | 304 | class VitL32Config(VitL16Config): 305 | patches = dict(size=(32, 32)) 306 | 307 | 308 | class VitH14Config: 309 | patches = dict(size=(14, 14)) 310 | hidden_size = 1280 311 | transformer = dict( 312 | mlp_dim=5120, 313 | num_heads=16, 314 | num_layers=32, 315 | attention_dropout_rate=0.0, 316 | dropout_rate=0.1 317 | ) 318 | classifier = 'token' 319 | representation_size = None 320 | 321 | 322 | def _vit( 323 | arch: str, 324 | config: Any, 325 | model_urls: Dict[str, str], 326 | progress: bool = True, 327 | pretrained: bool = False, 328 | **kwargs: Any 329 | ) -> VisionTransformer: 330 | model = VisionTransformer(config=config, **kwargs) 331 | if pretrained: 332 | state_dict = load_state_dict_from_url(model_urls[arch], 333 | progress=progress) 334 | model.load_state_dict(state_dict) 335 | return model 336 | 337 | 338 | def cifar10_vit_b16(*args, **kwargs) -> VisionTransformer: pass 339 | def cifar10_vit_b32(*args, **kwargs) -> VisionTransformer: pass 340 | def cifar10_vit_l16(*args, **kwargs) -> VisionTransformer: pass 341 | def cifar10_vit_l32(*args, **kwargs) -> VisionTransformer: pass 342 | def cifar10_vit_h14(*args, **kwargs) -> VisionTransformer: pass 343 | 344 | 345 | def cifar100_vit_b16(*args, **kwargs) -> VisionTransformer: pass 346 | def cifar100_vit_b32(*args, **kwargs) -> VisionTransformer: pass 347 | def cifar100_vit_l16(*args, **kwargs) -> VisionTransformer: pass 348 | def cifar100_vit_l32(*args, **kwargs) -> VisionTransformer: pass 349 | def cifar100_vit_h14(*args, **kwargs) -> VisionTransformer: pass 350 | 351 | 352 | thismodule = sys.modules[__name__] 353 | for dataset in ["cifar10", "cifar100"]: 354 | for config, model_name in zip([VitB16Config, VitB32Config, VitL16Config, VitL32Config, VitH14Config], 355 | 
["vit_b16", "vit_b32", "vit_l16", "vit_l32", "vit_h14"]): 356 | method_name = f"{dataset}_{model_name}" 357 | model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls 358 | num_classes = 10 if dataset == "cifar10" else 100 359 | setattr( 360 | thismodule, 361 | method_name, 362 | partial(_vit, 363 | arch=model_name, 364 | config=config, 365 | model_urls=model_urls, 366 | num_classes=num_classes) 367 | ) 368 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | from setuptools import setup, find_packages 5 | 6 | 7 | def read(*names, **kwargs): 8 | with io.open( 9 | os.path.join(os.path.dirname(__file__), *names), 10 | encoding=kwargs.get("encoding", "utf8") 11 | ) as fp: 12 | return fp.read() 13 | 14 | 15 | def find_version(*file_paths): 16 | version_file = read(*file_paths) 17 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 18 | if version_match: 19 | return version_match.group(1) 20 | else: 21 | raise RuntimeError("Unable to find version string.") 22 | 23 | 24 | VERSION = find_version('pytorch_cifar_models', '__init__.py') 25 | 26 | 27 | def find_requirements(file_path): 28 | with open(file_path) as f: 29 | return f.read().splitlines() 30 | 31 | 32 | requirements = find_requirements("requirements.txt") 33 | 34 | setup( 35 | name="pytorch_cifar_models", 36 | version=VERSION, 37 | description="Pretrained models for CIFAR10/100 in PyTorch", 38 | url="https://github.com/chenyaofo/pytorch-cifar-models", 39 | author="chenyaofo", 40 | author_email="chenyaofo@gmail.com", 41 | packages=find_packages(exclude=['test']), 42 | install_requires=requirements, 43 | ) -------------------------------------------------------------------------------- /tests/pytorch_cifar_models/test_mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import pytorch_cifar_models.mobilenetv2 as mobilenetv2 5 | 6 | 7 | @pytest.mark.parametrize("dataset", ["cifar10", "cifar100"]) 8 | @pytest.mark.parametrize("model_name", ["mobilenetv2_x0_5", "mobilenetv2_x0_75", "mobilenetv2_x1_0", "mobilenetv2_x1_4"]) 9 | def test_resnet(dataset, model_name): 10 | num_classes = 10 if dataset == "cifar10" else 100 11 | model = getattr(mobilenetv2, f"{dataset}_{model_name}")() 12 | x = torch.empty((1, 3, 32, 32)) 13 | y = model(x) 14 | assert y.shape == (1, num_classes) -------------------------------------------------------------------------------- /tests/pytorch_cifar_models/test_repvgg.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import pytorch_cifar_models.repvgg as repvgg 5 | 6 | 7 | @pytest.mark.parametrize("dataset", ["cifar10", "cifar100"]) 8 | @pytest.mark.parametrize("model_name", ["repvgg_a0", "repvgg_a1", "repvgg_a2"]) 9 | def test_resnet(dataset, model_name): 10 | num_classes = 10 if dataset == "cifar10" else 100 11 | model = getattr(repvgg, f"{dataset}_{model_name}")() 12 | x = torch.empty((1, 3, 32, 32)) 13 | y = model(x) 14 | assert y.shape == (1, num_classes) 15 | -------------------------------------------------------------------------------- /tests/pytorch_cifar_models/test_resnet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | 
5 |
6 |
7 | @pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
8 | @pytest.mark.parametrize("model_name", ["resnet20", "resnet32", "resnet44", "resnet56"])
9 | def test_resnet(dataset, model_name):
10 | num_classes = 10 if dataset == "cifar10" else 100
11 | model = getattr(resnet, f"{dataset}_{model_name}")()
12 | x = torch.empty((1, 3, 32, 32))
13 | y = model(x)
14 | assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_shufflenetv2.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.shufflenetv2 as shufflenetv2
5 |
6 |
7 | @pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
8 | @pytest.mark.parametrize("model_name", ["shufflenetv2_x0_5", "shufflenetv2_x1_0", "shufflenetv2_x1_5", "shufflenetv2_x2_0"])
9 | def test_shufflenetv2(dataset, model_name):
10 | num_classes = 10 if dataset == "cifar10" else 100
11 | model = getattr(shufflenetv2, f"{dataset}_{model_name}")()
12 | x = torch.empty((1, 3, 32, 32))
13 | y = model(x)
14 | assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_vgg.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.vgg as vgg
5 |
6 |
7 | @pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
8 | @pytest.mark.parametrize("model_name", ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"])
9 | def test_vgg(dataset, model_name):
10 | num_classes = 10 if dataset == "cifar10" else 100
11 | model = getattr(vgg, f"{dataset}_{model_name}")()
12 | x = torch.empty((1, 3, 32, 32))
13 | y = model(x)
14 | assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_vit.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.vit as vit
5 |
6 |
7 | @pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
8 | @pytest.mark.parametrize("model_name", ["vit_b16", "vit_b32", "vit_l16", "vit_l32", "vit_h14"])
9 | def test_vit(dataset, model_name):
10 | num_classes = 10 if dataset == "cifar10" else 100
11 | model = getattr(vit, f"{dataset}_{model_name}")()
12 | x = torch.empty((1, 3, 224, 224))
13 | y = model(x)
14 | assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------
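Editor's appendix: a minimal usage sketch (not part of the repository). Each model file registers its factory functions via setattr, and the underlying builders accept pretrained and progress keyword arguments, so a pretrained CIFAR-10 ResNet-20 can be loaded and run as follows. This assumes the package is installed via setup.py and the GitHub release URLs are reachable; the ViT factories cannot load pretrained weights because their URL tables are empty.

    import torch
    from pytorch_cifar_models.resnet import cifar10_resnet20

    model = cifar10_resnet20(pretrained=True)  # downloads the released checkpoint on first use
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 32, 32))  # CIFAR-sized input
    print(logits.shape)  # torch.Size([1, 10])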