├── .gitignore
├── LICENSE
├── README.md
├── colab
└── start_on_colab.ipynb
├── hubconf.py
├── pytorch_cifar_models
├── __init__.py
├── mobilenetv2.py
├── repvgg.py
├── resnet.py
├── shufflenetv2.py
├── vgg.py
└── vit.py
├── setup.py
└── tests
└── pytorch_cifar_models
├── test_mobilenetv2.py
├── test_repvgg.py
├── test_resnet.py
├── test_shufflenetv2.py
├── test_vgg.py
└── test_vit.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.gitignore.io/api/code,python
3 | # Edit at https://www.gitignore.io/?templates=code,python
4 |
5 | ### Code ###
6 | .vscode/*
7 | # !.vscode/settings.json
8 | # !.vscode/tasks.json
9 | # !.vscode/launch.json
10 | # !.vscode/extensions.json
11 |
12 | ### Python ###
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # pipenv
82 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
83 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
84 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
85 | # install all needed dependencies.
86 | #Pipfile.lock
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # Mr Developer
102 | .mr.developer.cfg
103 | .project
104 | .pydevproject
105 |
106 | # mkdocs documentation
107 | /site
108 |
109 | # mypy
110 | .mypy_cache/
111 | .dmypy.json
112 | dmypy.json
113 |
114 | # Pyre type checker
115 | .pyre/
116 |
117 | # End of https://www.gitignore.io/api/code,python
118 |
119 | *.pt
120 | *.pth
121 | *.log
122 | *.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, chenyaofo
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch CIFAR Models
2 |
3 | ## Introduction
4 |
The goal of this project is to provide some neural network examples and a simple training codebase for beginners.
6 |
7 | ## Get Started with Google Colab
8 |
9 | **Train Models**: Open the notebook to train the models from scratch on CIFAR10/100.
It will take several hours depending on the complexity of the model and the allocated GPU type.
11 |
12 | **Test Models**: Open the notebook to measure the validation accuracy on CIFAR10/100 with pretrained models.
It will only take a few seconds.
14 |
15 | ## Use Models with Pytorch Hub
16 |
17 | You can simply use the pretrained models in your project with `torch.hub` API.
18 | It will automatically load the code and the pretrained weights from GitHub
19 | (If you cannot directly access GitHub, please check [this issue](https://github.com/chenyaofo/pytorch-cifar-models/issues/14) for solution).
20 |
21 | ``` python
22 | import torch
23 | model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
24 | ```
25 |
To list all available model entries, you can run:
27 |
28 | ```python
29 | import torch
30 | from pprint import pprint
31 | pprint(torch.hub.list("chenyaofo/pytorch-cifar-models", force_reload=True))
32 | ```
33 |
34 |
35 | ## Model Zoo
36 |
37 | ### CIFAR-10
38 |
39 | | Model | Top-1 Acc.(%) | Top-5 Acc.(%) | #Params.(M) | #MAdds(M) | |
40 | |----------|----------------|---------------|-------------|-----------|--------------------|
41 | | resnet20 | 92.60 | 99.81 | 0.27 | 40.81 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet20-4118986f.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet20/default.log)
42 | | resnet32 | 93.53 | 99.77 | 0.47 | 69.12 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet32-ef93fc4d.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet32/default.log)
43 | | resnet44 | 94.01 | 99.77 | 0.66 | 97.44 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet44-2a3cabcb.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet44/default.log)
44 | | resnet56 | 94.37 | 99.83 | 0.86 | 125.75 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet56-187c023a.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/resnet56/default.log)
45 | | vgg11_bn | 92.79 | 99.72 | 9.76 | 153.29 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg11_bn-eaeebf42.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg11_bn/default.log)
46 | | vgg13_bn | 94.00 | 99.77 | 9.94 | 228.79 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg13_bn-c01e4a43.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg13_bn/default.log)
47 | | vgg16_bn | 94.16 | 99.71 | 15.25 | 313.73 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg16_bn-6ee7ea24.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg16_bn/default.log)
48 | | vgg19_bn | 93.91 | 99.64 | 20.57 | 398.66 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg19_bn-57191229.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/vgg19_bn/default.log)
49 | | mobilenetv2_x0_5 | 92.88 | 99.86 | 0.70 | 27.97 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_5-ca14ced9.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x0_5/default.log)
50 | | mobilenetv2_x0_75 | 93.72 | 99.79 | 1.37 | 59.31 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_75-a53c314e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x0_75/default.log)
51 | | mobilenetv2_x1_0 | 93.79 | 99.73 | 2.24 | 87.98 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_0-fe6a5b48.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x1_0/default.log)
52 | | mobilenetv2_x1_4 | 94.22 | 99.80 | 4.33 | 170.07 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_4-3bbbd6e2.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/mobilenetv2_x1_4/default.log)
53 | | shufflenetv2_x0_5 | 90.13 | 99.70 | 0.35 | 10.90 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x0_5-1308b4e9.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x0_5/default.log)
54 | | shufflenetv2_x1_0 | 92.98 | 99.73 | 1.26 | 45.00 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_0-98807be3.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x1_0/default.log)
55 | | shufflenetv2_x1_5 | 93.55 | 99.77 | 2.49 | 94.26 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_5-296694dd.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x1_5/default.log)
56 | | shufflenetv2_x2_0 | 93.81 | 99.79 | 5.37 | 187.81 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x2_0-ec31611c.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/shufflenetv2_x2_0/default.log)
57 | | repvgg_a0 | 94.39 | 99.82 | 7.84 | 489.08 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a0-ef08a50e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/repvgg_a0/default.log)
58 | | repvgg_a1 | 94.89 | 99.83 | 12.82 | 851.33 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a1-38d2431b.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/repvgg_a1/default.log)
59 | | repvgg_a2 | 94.98 | 99.82 | 26.82 | 1850.10 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a2-09488915.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar10/repvgg_a2/default.log)
60 |
61 | ### CIFAR-100
62 |
63 | | Model | Top-1 Acc.(%) | Top-5 Acc.(%) | #Params.(M) | #MAdds(M) | |
64 | |----------|----------------|---------------|-------------|-----------|--------------------|
65 | | resnet20 | 68.83 | 91.01 | 0.28 | 40.82 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet20-23dac2f1.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet20/default.log)
66 | | resnet32 | 70.16 | 90.89 | 0.47 | 69.13 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet32-84213ce6.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet32/default.log)
67 | | resnet44 | 71.63 | 91.58 | 0.67 | 97.44 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet44-ffe32858.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet44/default.log)
68 | | resnet56 | 72.63 | 91.94 | 0.86 | 125.75 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet56-f2eff4c8.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/resnet56/default.log)
69 | | vgg11_bn | 70.78 | 88.87 | 9.80 | 153.34 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg11_bn-57d0759e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg11_bn/default.log)
70 | | vgg13_bn | 74.63 | 91.09 | 9.99 | 228.84 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg13_bn-5ebe5778.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg13_bn/default.log)
71 | | vgg16_bn | 74.00 | 90.56 | 15.30 | 313.77 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg16_bn-7d8c4031.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg16_bn/default.log)
72 | | vgg19_bn | 73.87 | 90.13 | 20.61 | 398.71 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg19_bn-b98f7bd7.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/vgg19_bn/default.log)
73 | | mobilenetv2_x0_5 | 70.88 | 91.72 | 0.82 | 28.08 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_5-9f915757.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x0_5/default.log)
74 | | mobilenetv2_x0_75 | 73.61 | 92.61 | 1.48 | 59.43 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_75-d7891e60.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x0_75/default.log)
75 | | mobilenetv2_x1_0 | 74.20 | 92.82 | 2.35 | 88.09 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_0-1311f9ff.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x1_0/default.log)
76 | | mobilenetv2_x1_4 | 75.98 | 93.44 | 4.50 | 170.23 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_4-8a269f5e.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/mobilenetv2_x1_4/default.log)
77 | | shufflenetv2_x0_5 | 67.82 | 89.93 | 0.44 | 10.99 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x0_5-1977720f.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x0_5/default.log)
78 | | shufflenetv2_x1_0 | 72.39 | 91.46 | 1.36 | 45.09 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_0-9ae22beb.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x1_0/default.log)
79 | | shufflenetv2_x1_5 | 73.91 | 92.13 | 2.58 | 94.35 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_5-e2c85ad8.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x1_5/default.log)
80 | | shufflenetv2_x2_0 | 75.35 | 92.62 | 5.55 | 188.00 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x2_0-e7e584cd.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/shufflenetv2_x2_0/default.log)
81 | | repvgg_a0 | 75.22 | 92.93 | 7.96 | 489.19 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a0-2df1edd0.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/repvgg_a0/default.log)
82 | | repvgg_a1 | 76.12 | 92.71 | 12.94 | 851.44 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a1-c06b21a7.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/repvgg_a1/default.log)
83 | | repvgg_a2 | 77.18 | 93.51 | 26.94 | 1850.22 | [model](https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a2-8e71b1f8.pt) \| [log](https://cdn.jsdelivr.net/gh/chenyaofo/pytorch-cifar-models@logs/logs/cifar100/repvgg_a2/default.log)
84 |
85 | If you want to cite this repo:
86 |
87 | ```
88 | @misc{chenyaofo_pytorch_cifar_models,
89 | author = {Yaofo Chen},
90 | title = {PyTorch CIFAR Models},
howpublished = {\url{https://github.com/chenyaofo/pytorch-cifar-models}},
92 | note = {Accessed: 2025-5-17},
93 | abstract = {This repository provides pretrained neural network models trained on CIFAR-10 and CIFAR-100 datasets, including ResNet, VGG, MobileNetV2, ShuffleNetV2, and RepVGG variants.}
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/colab/start_on_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "start_on_colab.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyPIpbiKdUC+5q3/Vlm8dk4L"
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | },
18 | "accelerator": "GPU"
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "code",
23 | "metadata": {
24 | "id": "PwHOLr9y7S3Q"
25 | },
26 | "source": [
27 | "from IPython.display import HTML, display\n",
28 | "\n",
29 | "def set_css():\n",
30 | " display(HTML('''\n",
31 | " \n",
36 | " '''))\n",
37 | "get_ipython().events.register('pre_run_cell', set_css)"
38 | ],
39 | "execution_count": null,
40 | "outputs": []
41 | },
42 | {
43 | "cell_type": "code",
44 | "metadata": {
45 | "id": "zXkQlZUJ8Lnm"
46 | },
47 | "source": [
48 | "!git clone --recursive https://github.com/chenyaofo/image-classification-codebase\n",
49 | "%cd image-classification-codebase\n",
50 | "\n",
51 | "%pip install -qr requirements.txt\n",
52 | "\n",
53 | "import torch\n",
54 | "from IPython.display import clear_output\n",
55 | "\n",
56 | "clear_output()\n",
57 | "print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))"
58 | ],
59 | "execution_count": null,
60 | "outputs": []
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {
65 | "id": "Ijy9oBmN6ovz"
66 | },
67 | "source": [
68 | "The following command list all the available models."
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "metadata": {
74 | "colab": {
75 | "base_uri": "https://localhost:8080/",
76 | "height": 69
77 | },
78 | "id": "y-AG3Ne76dOv",
79 | "outputId": "5607e9cb-31aa-48d6-a465-65662eebb653"
80 | },
81 | "source": [
82 | "import torch\n",
83 | "print(torch.hub.list(\"chenyaofo/pytorch-cifar-models\", force_reload=True))"
84 | ],
85 | "execution_count": null,
86 | "outputs": [
87 | {
88 | "output_type": "display_data",
89 | "data": {
90 | "text/html": [
91 | "\n",
92 | " \n",
97 | " "
98 | ],
99 | "text/plain": [
100 | ""
101 | ]
102 | },
103 | "metadata": {
104 | "tags": []
105 | }
106 | },
107 | {
108 | "output_type": "stream",
109 | "text": [
110 | "Downloading: \"https://github.com/chenyaofo/pytorch-cifar-models/archive/master.zip\" to /root/.cache/torch/hub/master.zip\n"
111 | ],
112 | "name": "stderr"
113 | },
114 | {
115 | "output_type": "stream",
116 | "text": [
117 | "['cifar100_resnet20', 'cifar100_resnet32', 'cifar100_resnet44', 'cifar100_resnet56', 'cifar100_vgg11_bn', 'cifar100_vgg13_bn', 'cifar100_vgg16_bn', 'cifar100_vgg19_bn', 'cifar10_resnet20', 'cifar10_resnet32', 'cifar10_resnet44', 'cifar10_resnet56', 'cifar10_vgg11_bn', 'cifar10_vgg13_bn', 'cifar10_vgg16_bn', 'cifar10_vgg19_bn']\n"
118 | ],
119 | "name": "stdout"
120 | }
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {
126 | "id": "4jofqHNwAgvS"
127 | },
128 | "source": [
129 | "To train/evaluate different models, please set `model.model_name` to available model names listed in the previous block."
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "id": "YA5-PFL8BPUq"
136 | },
137 | "source": [
138 | "**Train on CIFAR-10:**"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "metadata": {
144 | "id": "BM_9IlZ72D9l"
145 | },
146 | "source": [
147 | "!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20"
148 | ],
149 | "execution_count": null,
150 | "outputs": []
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {
155 | "id": "QwrIVjWQNV82"
156 | },
157 | "source": [
158 | "**Evaluate on CIFAR-10:**"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "metadata": {
164 | "id": "DMd_USWH81Ov"
165 | },
166 | "source": [
167 | "!python -m entry.run --conf conf/cifar10.conf -o output/cifar10/resnet20 -M model.name=cifar10_resnet20 model.pretrained=true only_evaluate=true"
168 | ],
169 | "execution_count": null,
170 | "outputs": []
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {
175 | "id": "aKYW0yNINoSk"
176 | },
177 | "source": [
178 | "**Train on CIFAR-100:**"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "metadata": {
184 | "id": "m2bVfrEjNoHY"
185 | },
186 | "source": [
187 | "!python -m entry.run --conf conf/cifar100.conf -o output/cifar100/resnet20 -M model.name=cifar100_resnet20"
188 | ],
189 | "execution_count": null,
190 | "outputs": []
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {
195 | "id": "PgcF8E8NNxgp"
196 | },
197 | "source": [
198 | "**Evaluate on CIFAR-100:**"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "metadata": {
204 | "id": "5k-sCfhqNxDc"
205 | },
206 | "source": [
207 | "!python -m entry.run --conf conf/cifar100.conf -o output/cifar100/resnet20 -M model.name=cifar100_resnet20 model.pretrained=true only_evaluate=true"
208 | ],
209 | "execution_count": null,
210 | "outputs": []
211 | }
212 | ]
213 | }
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
import pytorch_cifar_models

dependencies = ['torch']

# Re-export every "cifar*" model constructor from the package at module level
# so that torch.hub.list()/torch.hub.load() can discover them as entry points.
models = [name for name in dir(pytorch_cifar_models) if name.startswith("cifar")]
for model in models:
    globals()[model] = getattr(pytorch_cifar_models, model)
7 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import cifar10_resnet20
2 | from .resnet import cifar10_resnet32
3 | from .resnet import cifar10_resnet44
4 | from .resnet import cifar10_resnet56
5 |
6 | from .resnet import cifar100_resnet20
7 | from .resnet import cifar100_resnet32
8 | from .resnet import cifar100_resnet44
9 | from .resnet import cifar100_resnet56
10 |
11 | from .vgg import cifar10_vgg11_bn
12 | from .vgg import cifar10_vgg13_bn
13 | from .vgg import cifar10_vgg16_bn
14 | from .vgg import cifar10_vgg19_bn
15 |
16 | from .vgg import cifar100_vgg11_bn
17 | from .vgg import cifar100_vgg13_bn
18 | from .vgg import cifar100_vgg16_bn
19 | from .vgg import cifar100_vgg19_bn
20 |
21 | from .mobilenetv2 import cifar10_mobilenetv2_x0_5
22 | from .mobilenetv2 import cifar10_mobilenetv2_x0_75
23 | from .mobilenetv2 import cifar10_mobilenetv2_x1_0
24 | from .mobilenetv2 import cifar10_mobilenetv2_x1_4
25 |
26 | from .mobilenetv2 import cifar100_mobilenetv2_x0_5
27 | from .mobilenetv2 import cifar100_mobilenetv2_x0_75
28 | from .mobilenetv2 import cifar100_mobilenetv2_x1_0
29 | from .mobilenetv2 import cifar100_mobilenetv2_x1_4
30 |
31 | from .shufflenetv2 import cifar10_shufflenetv2_x0_5
32 | from .shufflenetv2 import cifar10_shufflenetv2_x1_0
33 | from .shufflenetv2 import cifar10_shufflenetv2_x1_5
34 | from .shufflenetv2 import cifar10_shufflenetv2_x2_0
35 |
36 | from .shufflenetv2 import cifar100_shufflenetv2_x0_5
37 | from .shufflenetv2 import cifar100_shufflenetv2_x1_0
38 | from .shufflenetv2 import cifar100_shufflenetv2_x1_5
39 | from .shufflenetv2 import cifar100_shufflenetv2_x2_0
40 |
41 | from .repvgg import cifar10_repvgg_a0
42 | from .repvgg import cifar10_repvgg_a1
43 | from .repvgg import cifar10_repvgg_a2
44 |
45 | from .repvgg import cifar100_repvgg_a0
46 | from .repvgg import cifar100_repvgg_a1
47 | from .repvgg import cifar100_repvgg_a2
48 |
49 | from .vit import cifar10_vit_b16
50 | from .vit import cifar10_vit_b32
51 | from .vit import cifar10_vit_l16
52 | from .vit import cifar10_vit_l32
53 | from .vit import cifar10_vit_h14
54 |
55 | from .vit import cifar100_vit_b16
56 | from .vit import cifar100_vit_b32
57 | from .vit import cifar100_vit_l16
58 | from .vit import cifar100_vit_l32
59 | from .vit import cifar100_vit_h14
60 |
61 | __version__ = "0.1.0-alpha"
62 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/mobilenetv2.py
3 |
4 | BSD 3-Clause License
5 |
6 | Copyright (c) Soumith Chintala 2016,
7 | All rights reserved.
8 |
9 | Redistribution and use in source and binary forms, with or without
10 | modification, are permitted provided that the following conditions are met:
11 |
12 | * Redistributions of source code must retain the above copyright notice, this
13 | list of conditions and the following disclaimer.
14 |
15 | * Redistributions in binary form must reproduce the above copyright notice,
16 | this list of conditions and the following disclaimer in the documentation
17 | and/or other materials provided with the distribution.
18 |
19 | * Neither the name of the copyright holder nor the names of its
20 | contributors may be used to endorse or promote products derived from
21 | this software without specific prior written permission.
22 |
23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 | '''
34 | import sys
35 | import torch
36 | from torch import nn
37 | from torch import Tensor
38 | try:
39 | from torch.hub import load_state_dict_from_url
40 | except ImportError:
41 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
42 |
43 | from functools import partial
44 | from typing import Dict, Type, Any, Callable, Union, List, Optional
45 |
46 |
47 | cifar10_pretrained_weight_urls = {
48 | 'mobilenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_5-ca14ced9.pt',
49 | 'mobilenetv2_x0_75': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x0_75-a53c314e.pt',
50 | 'mobilenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_0-fe6a5b48.pt',
51 | 'mobilenetv2_x1_4': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar10_mobilenetv2_x1_4-3bbbd6e2.pt',
52 | }
53 |
54 | cifar100_pretrained_weight_urls = {
55 | 'mobilenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_5-9f915757.pt',
56 | 'mobilenetv2_x0_75': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x0_75-d7891e60.pt',
57 | 'mobilenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_0-1311f9ff.pt',
58 | 'mobilenetv2_x1_4': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/mobilenetv2/cifar100_mobilenetv2_x1_4-8a269f5e.pt',
59 | }
60 |
61 |
62 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
63 | """
64 | This function is taken from the original tf repo.
65 | It ensures that all layers have a channel number that is divisible by 8
66 | It can be seen here:
67 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
68 | """
69 | if min_value is None:
70 | min_value = divisor
71 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
72 | # Make sure that round down does not go down by more than 10%.
73 | if new_v < 0.9 * v:
74 | new_v += divisor
75 | return new_v
76 |
77 |
class ConvBNActivation(nn.Sequential):
    """Conv2d -> normalization -> activation, fused as a single sequential block.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size; padding preserves spatial size at stride 1.
        stride: convolution stride.
        groups: number of convolution groups (``groups == in_planes`` gives a depthwise conv).
        norm_layer: normalization-layer constructor; defaults to ``nn.BatchNorm2d``.
        activation_layer: activation constructor; defaults to ``nn.ReLU6``.
        dilation: convolution dilation.
    """
    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
    ) -> None:
        # "same" padding for odd kernel sizes, accounting for dilation.
        padding = (kernel_size - 1) // 2 * dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU6
        # BUGFIX: use zero-argument super() instead of super(ConvBNReLU, self).
        # The original referenced the ConvBNReLU alias defined *after* this
        # class; it only worked because the name is resolved lazily at call
        # time and would raise NameError if this class were used standalone.
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
                      bias=False),
            norm_layer(out_planes),
            activation_layer(inplace=True)
        )
        self.out_channels = out_planes
102 |
103 |
104 | # necessary for backwards compatibility
105 | ConvBNReLU = ConvBNActivation
106 |
107 |
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block: expand -> depthwise -> linear project.

    Adds an identity shortcut when the block keeps both the spatial resolution
    (stride 1) and the channel width (``inp == oup``).
    """
    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int,
        expand_ratio: int,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer

        hidden_dim = int(round(inp * expand_ratio))
        # Residual shortcut only when input and output tensors are shape-compatible.
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pointwise expansion
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        # depthwise convolution
        layers.append(ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer))
        # linear pointwise projection (no activation)
        layers.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
        layers.append(norm_layer(oup))
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        return x + out if self.use_res_connect else out
147 |
148 |
class MobileNetV2(nn.Module):
    # CIFAR adaptation of torchvision's MobileNetV2: the stem conv and the
    # second inverted-residual stage use stride 1 (instead of 2) so that
    # 32x32 inputs are not downsampled too aggressively.
    def __init__(
        self,
        num_classes: int = 10,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
              Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use

        """
        super(MobileNetV2, self).__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t (expand ratio), c (output channels), n (repeats), s (stride of first repeat)
                [1, 16, 1, 1],
                [6, 24, 2, 1],  # NOTE: change stride 2 -> 1 for CIFAR10/100
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        # Channel counts are rounded to multiples of round_nearest; the head
        # width never shrinks below its base value for width_mult < 1.
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=1, norm_layer=norm_layer)]  # NOTE: change stride 2 -> 1 for CIFAR10/100
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # only the first block in each stage downsamples
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        # Extract features, global-average-pool, then classify.
        return self._forward_impl(x)
247 |
248 |
def _mobilenet_v2(
    arch: str,
    width_mult: float,
    model_urls: Dict[str, str],
    progress: bool = True,
    pretrained: bool = False,
    **kwargs: Any
) -> MobileNetV2:
    """Construct a MobileNetV2 and optionally load pretrained weights.

    Args:
        arch: key into ``model_urls`` identifying the checkpoint to fetch.
        width_mult: channel width multiplier forwarded to ``MobileNetV2``.
            (Fixed annotation: callers pass a single float such as 0.5 or 1.4,
            not a ``List[int]`` as previously annotated.)
        model_urls: mapping from arch name to checkpoint URL.
        progress: show a download progress bar when fetching weights.
        pretrained: if True, download and load the matching state dict.
        **kwargs: forwarded to ``MobileNetV2`` (e.g. ``num_classes``).

    Returns:
        The constructed (and possibly pretrained) model.
    """
    model = MobileNetV2(width_mult=width_mult, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model
263 |
264 |
# Placeholder stubs: they give IDEs and static analyzers a named symbol with
# the correct return annotation. Each one is overwritten at import time by the
# registration loop at the bottom of this module with a functools.partial of
# _mobilenet_v2 bound to the matching arch/width/weights/num_classes.
def cifar10_mobilenetv2_x0_5(*args, **kwargs) -> MobileNetV2: pass
def cifar10_mobilenetv2_x0_75(*args, **kwargs) -> MobileNetV2: pass
def cifar10_mobilenetv2_x1_0(*args, **kwargs) -> MobileNetV2: pass
def cifar10_mobilenetv2_x1_4(*args, **kwargs) -> MobileNetV2: pass


def cifar100_mobilenetv2_x0_5(*args, **kwargs) -> MobileNetV2: pass
def cifar100_mobilenetv2_x0_75(*args, **kwargs) -> MobileNetV2: pass
def cifar100_mobilenetv2_x1_0(*args, **kwargs) -> MobileNetV2: pass
def cifar100_mobilenetv2_x1_4(*args, **kwargs) -> MobileNetV2: pass
275 |
276 |
# Register the real constructors: for every (dataset, width) pair, bind
# _mobilenet_v2 to the right arch/URL table/num_classes and overwrite the
# placeholder stub of the same name on this module.
thismodule = sys.modules[__name__]
for dataset in ["cifar10", "cifar100"]:
    for width_mult, model_name in zip([0.5, 0.75, 1.0, 1.4],
                                      ["mobilenetv2_x0_5", "mobilenetv2_x0_75", "mobilenetv2_x1_0", "mobilenetv2_x1_4"]):
        method_name = f"{dataset}_{model_name}"
        model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
        num_classes = 10 if dataset == "cifar10" else 100
        setattr(
            thismodule,
            method_name,
            partial(_mobilenet_v2,
                    arch=model_name,
                    width_mult=width_mult,
                    model_urls=model_urls,
                    num_classes=num_classes)
        )
293 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/repvgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://raw.githubusercontent.com/DingXiaoH/RepVGG/main/repvgg.py
3 |
4 | MIT License
5 |
6 | Copyright (c) 2020 DingXiaoH
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 | '''
26 | import sys
27 | import copy
28 | import torch
29 |
30 | import torch.nn as nn
31 | import numpy as np
32 |
33 | try:
34 | from torch.hub import load_state_dict_from_url
35 | except ImportError:
36 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
37 | from functools import partial
38 | from typing import Union, List, Dict, Any, cast
39 |
# GitHub release-asset checkpoint URLs for RepVGG models trained on CIFAR-10.
cifar10_pretrained_weight_urls = {
    'repvgg_a0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a0-ef08a50e.pt',
    'repvgg_a1': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a1-38d2431b.pt',
    'repvgg_a2': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar10_repvgg_a2-09488915.pt',
}

# GitHub release-asset checkpoint URLs for RepVGG models trained on CIFAR-100.
cifar100_pretrained_weight_urls = {
    'repvgg_a0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a0-2df1edd0.pt',
    'repvgg_a1': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a1-c06b21a7.pt',
    'repvgg_a2': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/repvgg/cifar100_repvgg_a2-8e71b1f8.pt',
}
51 |
52 |
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    """Return a named Sequential of a bias-free Conv2d ('conv') and BatchNorm2d ('bn')."""
    block = nn.Sequential()
    block.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                       kernel_size=kernel_size, stride=stride,
                                       padding=padding, groups=groups, bias=False))
    block.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return block
59 |
60 |
class RepVGGBlock(nn.Module):
    """A RepVGG building block.

    In training form it runs three parallel branches — 3x3 conv+BN, 1x1
    conv+BN, and (when in/out shapes match at stride 1) an identity BN — and
    sums them before the ReLU. In deploy form the branches are fused into a
    single 3x3 conv (`rbr_reparam`) with identical output.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        # The fusion math below assumes a 3x3 kernel with "same" padding.
        assert kernel_size == 3
        assert padding == 1

        # Padding for the 1x1 branch so its output aligns with the 3x3 branch.
        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if deploy:
            # Single fused conv (with bias) replacing all three branches.
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

        else:
            # Identity branch only exists when shapes are preserved.
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
            # print('RepVGG Block, identity = ', self.rbr_identity)

    def forward(self, inputs):
        # Deploy path: one fused conv followed by ReLU.
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        # 0 is an additive no-op when the identity branch is absent.
        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)


    # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
    # You can get the equivalent kernel and bias at any time and do whatever you want,
    # for example, apply some penalties or constraints during training, just like you do to the other models.
    # May be useful for quantization or pruning.


    def get_equivalent_kernel_bias(self):
        # Fold BN into each branch, pad the 1x1 kernel to 3x3, and sum.
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        # Zero-pad a 1x1 kernel to 3x3 so it can be summed with the dense branch.
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        # Fold a branch's BatchNorm into an equivalent conv kernel/bias pair:
        # w' = w * gamma/std, b' = beta - running_mean * gamma/std.
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            # Identity branch: synthesize a (grouped) identity 3x3 kernel once
            # and cache it, so the BN-only branch can be treated as a conv.
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        # Irreversibly fuse the training-time branches into one 3x3 conv.
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        # Detach so the fused parameters become autograd leaves.
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
160 |
161 |
class RepVGG(nn.Module):
    """RepVGG backbone adapted for CIFAR: stage0 and stage1 use stride 1 so a
    32x32 input is downsampled only by stages 2-4."""

    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False):
        """
        Args:
            num_blocks: number of RepVGGBlocks in each of stages 1-4.
            num_classes: size of the final linear classifier.
            width_multiplier: four floats scaling the base widths 64/128/256/512.
            override_groups_map: optional {layer_index: groups} for grouped convs.
            deploy: build blocks directly in fused (inference) form.
        """
        super(RepVGG, self).__init__()

        assert len(width_multiplier) == 4

        self.deploy = deploy
        self.override_groups_map = override_groups_map or dict()

        # Layer index 0 is stage0, which is never grouped.
        assert 0 not in self.override_groups_map

        self.in_planes = min(64, int(64 * width_multiplier[0]))

        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes,
                                  kernel_size=3, stride=1, padding=1, deploy=self.deploy)  # NOTE: change stride 2 -> 1 for CIFAR10/100
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=1)  # NOTE: change stride 2 -> 1 for CIFAR10/100
        self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
        self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
        self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

    def _make_stage(self, planes, num_blocks, stride):
        """Stack `num_blocks` RepVGGBlocks; only the first one downsamples."""
        strides = [stride] + [1]*(num_blocks-1)
        blocks = []
        for stride in strides:
            # Per-layer group override (used by the gN RepVGG variants).
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                      stride=stride, padding=1, groups=cur_groups, deploy=self.deploy))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.stage0(x)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def convert_to_inference_model(self, do_copy=False):
        """Fuse every RepVGGBlock into its single-conv deploy form.

        Args:
            do_copy: if True, convert and return a deep copy, leaving this
                model untouched; otherwise convert in place.

        Returns:
            The converted model. (Bug fix: this method previously returned
            nothing, so with ``do_copy=True`` the converted copy was created
            and then silently discarded — the caller could never obtain it.)
        """
        model = copy.deepcopy(self) if do_copy else self
        for module in model.modules():
            if hasattr(module, 'switch_to_deploy'):
                module.switch_to_deploy()
        return model
213 |
214 |
# Layer indices eligible for grouped convolutions in the gN RepVGG variants;
# the maps assign 2 or 4 groups to each of those layers (unused by the A-series
# models built below, which pass override_groups_map=None).
optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {l: 2 for l in optional_groupwise_layers}
g4_map = {l: 4 for l in optional_groupwise_layers}
218 |
219 |
def _repvgg(arch: str, num_blocks: List[int], width_multiplier: List[float],
            model_urls: Dict[str, str],
            pretrained: bool = False, progress: bool = True, **kwargs: Any) -> RepVGG:
    """Build a plain (ungrouped) RepVGG; optionally load pretrained weights.

    `arch` selects the checkpoint URL from `model_urls`; `num_blocks` and
    `width_multiplier` define the architecture; remaining kwargs (e.g.
    num_classes) are forwarded to the RepVGG constructor.
    """
    net = RepVGG(num_blocks=num_blocks, width_multiplier=width_multiplier,
                 override_groups_map=None, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
230 |
231 |
# Placeholder stubs for IDEs/static analyzers; each is overwritten at import
# time by the registration loop below with a functools.partial of _repvgg.
def cifar10_repvgg_a0(*args, **kwargs) -> RepVGG: pass
def cifar10_repvgg_a1(*args, **kwargs) -> RepVGG: pass
def cifar10_repvgg_a2(*args, **kwargs) -> RepVGG: pass


def cifar100_repvgg_a0(*args, **kwargs) -> RepVGG: pass
def cifar100_repvgg_a1(*args, **kwargs) -> RepVGG: pass
def cifar100_repvgg_a2(*args, **kwargs) -> RepVGG: pass
240 |
241 |
# Register the real constructors: bind _repvgg to each A-series configuration
# and overwrite the placeholder stub of the same name on this module.
thismodule = sys.modules[__name__]
for dataset in ["cifar10", "cifar100"]:
    for num_blocks, width_multiplier, model_name in\
        zip([[2, 4, 14, 1], [2, 4, 14, 1], [2, 4, 14, 1]],
            [[0.75, 0.75, 0.75, 2.5], [1, 1, 1, 2.5], [1.5, 1.5, 1.5, 2.75]],
            ["repvgg_a0", "repvgg_a1", "repvgg_a2"]):
        method_name = f"{dataset}_{model_name}"
        model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
        num_classes = 10 if dataset == "cifar10" else 100
        setattr(
            thismodule,
            method_name,
            partial(_repvgg,
                    arch=model_name,
                    num_blocks=num_blocks,
                    width_multiplier=width_multiplier,
                    model_urls=model_urls,
                    num_classes=num_classes)
        )
261 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/resnet.py
3 |
4 | BSD 3-Clause License
5 |
6 | Copyright (c) Soumith Chintala 2016,
7 | All rights reserved.
8 |
9 | Redistribution and use in source and binary forms, with or without
10 | modification, are permitted provided that the following conditions are met:
11 |
12 | * Redistributions of source code must retain the above copyright notice, this
13 | list of conditions and the following disclaimer.
14 |
15 | * Redistributions in binary form must reproduce the above copyright notice,
16 | this list of conditions and the following disclaimer in the documentation
17 | and/or other materials provided with the distribution.
18 |
19 | * Neither the name of the copyright holder nor the names of its
20 | contributors may be used to endorse or promote products derived from
21 | this software without specific prior written permission.
22 |
23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 | '''
34 | import sys
35 | import torch.nn as nn
36 | try:
37 | from torch.hub import load_state_dict_from_url
38 | except ImportError:
39 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
40 |
41 | from functools import partial
42 | from typing import Dict, Type, Any, Callable, Union, List, Optional
43 |
44 |
# GitHub release-asset checkpoint URLs for ResNets trained on CIFAR-10.
cifar10_pretrained_weight_urls = {
    'resnet20': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet20-4118986f.pt',
    'resnet32': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet32-ef93fc4d.pt',
    'resnet44': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet44-2a3cabcb.pt',
    'resnet56': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar10_resnet56-187c023a.pt',
}

# GitHub release-asset checkpoint URLs for ResNets trained on CIFAR-100.
cifar100_pretrained_weight_urls = {
    'resnet20': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet20-23dac2f1.pt',
    'resnet32': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet32-84213ce6.pt',
    'resnet44': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet44-ffe32858.pt',
    'resnet56': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/resnet/cifar100_resnet56-f2eff4c8.pt',
}
58 |
59 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution, bias-free, with padding 1 (keeps spatial size at stride 1)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
63 |
64 |
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution, bias-free (used for shortcut projections)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
68 |
69 |
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convs; `downsample` projects the shortcut
    when the block changes resolution or width."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # The first conv carries the (possible) stride; the second keeps size.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Project the shortcut if needed, otherwise pass x through unchanged.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
100 |
101 |
class CifarResNet(nn.Module):
    # CIFAR-style ResNet (He et al. 2015, sec. 4.2): a single 3x3 stem and
    # three stages of width 16/32/64 on 32x32 inputs, instead of the
    # ImageNet stem/stage layout.

    def __init__(self, block, layers, num_classes=10):
        super(CifarResNet, self).__init__()
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Stages 2 and 3 halve resolution and double width.
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        # Kaiming init for convs; BN starts as identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        # A 1x1-conv + BN projection is needed whenever the first block of the
        # stage changes resolution or channel width.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        # Global average pool, flatten, classify.
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x
155 |
156 |
def _resnet(
    arch: str,
    layers: List[int],
    model_urls: Dict[str, str],
    progress: bool = True,
    pretrained: bool = False,
    **kwargs: Any
) -> CifarResNet:
    """Instantiate a BasicBlock CifarResNet; optionally load pretrained weights.

    `arch` selects the checkpoint URL from `model_urls`; `layers` gives the
    per-stage block counts; remaining kwargs (e.g. num_classes) go to the
    CifarResNet constructor.
    """
    net = CifarResNet(BasicBlock, layers, **kwargs)
    if not pretrained:
        return net
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(checkpoint)
    return net
171 |
172 |
# Placeholder stubs for IDEs/static analyzers; each is overwritten at import
# time by the registration loop below with a functools.partial of _resnet.
def cifar10_resnet20(*args, **kwargs) -> CifarResNet: pass
def cifar10_resnet32(*args, **kwargs) -> CifarResNet: pass
def cifar10_resnet44(*args, **kwargs) -> CifarResNet: pass
def cifar10_resnet56(*args, **kwargs) -> CifarResNet: pass


def cifar100_resnet20(*args, **kwargs) -> CifarResNet: pass
def cifar100_resnet32(*args, **kwargs) -> CifarResNet: pass
def cifar100_resnet44(*args, **kwargs) -> CifarResNet: pass
def cifar100_resnet56(*args, **kwargs) -> CifarResNet: pass
183 |
184 |
# Register the real constructors: depth 6n+2 with n blocks per stage, so
# [3]*3 -> resnet20, [5]*3 -> resnet32, [7]*3 -> resnet44, [9]*3 -> resnet56.
thismodule = sys.modules[__name__]
for dataset in ["cifar10", "cifar100"]:
    for layers, model_name in zip([[3]*3, [5]*3, [7]*3, [9]*3],
                                  ["resnet20", "resnet32", "resnet44", "resnet56"]):
        method_name = f"{dataset}_{model_name}"
        model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
        num_classes = 10 if dataset == "cifar10" else 100
        setattr(
            thismodule,
            method_name,
            partial(_resnet,
                    arch=model_name,
                    layers=layers,
                    model_urls=model_urls,
                    num_classes=num_classes)
        )
201 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/shufflenetv2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/shufflenetv2.py
3 |
4 | BSD 3-Clause License
5 |
6 | Copyright (c) Soumith Chintala 2016,
7 | All rights reserved.
8 |
9 | Redistribution and use in source and binary forms, with or without
10 | modification, are permitted provided that the following conditions are met:
11 |
12 | * Redistributions of source code must retain the above copyright notice, this
13 | list of conditions and the following disclaimer.
14 |
15 | * Redistributions in binary form must reproduce the above copyright notice,
16 | this list of conditions and the following disclaimer in the documentation
17 | and/or other materials provided with the distribution.
18 |
19 | * Neither the name of the copyright holder nor the names of its
20 | contributors may be used to endorse or promote products derived from
21 | this software without specific prior written permission.
22 |
23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 | '''
34 | import sys
35 | import torch
36 | import torch.nn as nn
37 | from torch import Tensor
38 |
39 | try:
40 | from torch.hub import load_state_dict_from_url
41 | except ImportError:
42 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
43 |
44 | from functools import partial
45 | from typing import Dict, Type, Any, Callable, Union, List, Optional
46 |
47 |
# GitHub release-asset checkpoint URLs for ShuffleNetV2 trained on CIFAR-10.
cifar10_pretrained_weight_urls = {
    'shufflenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x0_5-1308b4e9.pt',
    'shufflenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_0-98807be3.pt',
    'shufflenetv2_x1_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x1_5-296694dd.pt',
    'shufflenetv2_x2_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar10_shufflenetv2_x2_0-ec31611c.pt',
}

# GitHub release-asset checkpoint URLs for ShuffleNetV2 trained on CIFAR-100.
cifar100_pretrained_weight_urls = {
    'shufflenetv2_x0_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x0_5-1977720f.pt',
    'shufflenetv2_x1_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_0-9ae22beb.pt',
    'shufflenetv2_x1_5': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x1_5-e2c85ad8.pt',
    'shufflenetv2_x2_0': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/shufflenetv2/cifar100_shufflenetv2_x2_0-e7e584cd.pt',
}
61 |
62 |
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave channels across `groups` so information mixes between branches.

    Channel g*c + i (of `groups` groups, each `c` channels wide) moves to
    position i*groups + g; spatial dimensions are untouched.
    """
    n, c, h, w = x.size()
    per_group = c // groups

    # (n, c, h, w) -> (n, groups, c/groups, h, w), swap the two channel axes,
    # then flatten back to (n, c, h, w).
    shuffled = x.view(n, groups, per_group, h, w)
    shuffled = shuffled.transpose(1, 2).contiguous()
    return shuffled.view(n, -1, h, w)
77 |
78 |
class InvertedResidual(nn.Module):
    # ShuffleNetV2 unit. With stride 1 the input is split in half along the
    # channel axis: one half passes through untouched, the other goes through
    # branch2; with stride > 1 both branches process (and downsample) the full
    # input and the concatenation doubles the channel count. A final channel
    # shuffle mixes information between the two halves.
    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int
    ) -> None:
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # For stride 1 the input must already be oup (= 2 * branch_features)
        # wide, since half of it is forwarded unchanged.
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            # Downsampling shortcut: depthwise 3x3 (strided) + pointwise 1x1.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Empty placeholder: the stride-1 shortcut is the untouched split.
            self.branch1 = nn.Sequential()

        # Main branch: 1x1 -> depthwise 3x3 (possibly strided) -> 1x1.
        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(
        i: int,
        o: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False
    ) -> nn.Conv2d:
        # groups=i makes this a depthwise convolution (one filter per channel).
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            # Split channels: pass half through, transform the other half.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out
139 |
140 |
class ShuffleNetV2(nn.Module):
    # CIFAR adaptation of torchvision's ShuffleNetV2: conv1 uses stride 1 and
    # the stem max-pool is removed so 32x32 inputs keep enough resolution.
    def __init__(
        self,
        stages_repeats: List[int],
        stages_out_channels: List[int],
        num_classes: int = 1000,
        inverted_residual: Callable[..., nn.Module] = InvertedResidual
    ) -> None:
        """ShuffleNetV2 backbone + linear classifier.

        Args:
            stages_repeats: number of units in each of stages 2-4 (3 ints).
            stages_out_channels: widths of conv1, stages 2-4, and conv5 (5 ints).
            num_classes: size of the final classifier.
            inverted_residual: unit class, injectable for testing/variants.
        """
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 1, 1, bias=False),  # NOTE: change stride 2 -> 1 for CIFAR10/100
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) NOTE: remove this maxpool layer for CIFAR10/100

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        # Each stage: one strided (downsampling) unit, then `repeats - 1`
        # stride-1 units at the stage's output width.
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        # x = self.maxpool(x) NOTE: remove this maxpool layer for CIFAR10/100
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # globalpool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        # Run the backbone, global-average-pool, then classify.
        return self._forward_impl(x)
204 |
205 |
def _shufflenet_v2(
    arch: str,
    stages_repeats: List[int],
    stages_out_channels: List[int],
    model_urls: Dict[str, str],
    progress: bool = True,
    pretrained: bool = False,
    **kwargs: Any
) -> ShuffleNetV2:
    """Construct a ShuffleNetV2 and optionally load pretrained CIFAR weights.

    ``arch`` keys into ``model_urls``; any extra keyword arguments are
    forwarded to the ShuffleNetV2 constructor.
    """
    model = ShuffleNetV2(
        stages_repeats=stages_repeats,
        stages_out_channels=stages_out_channels,
        **kwargs,
    )
    if pretrained:
        model.load_state_dict(
            load_state_dict_from_url(model_urls[arch], progress=progress))
    return model
221 |
222 |
# Placeholder signatures for IDEs and type checkers only: the registration
# loop at the bottom of this module overwrites each of these names with a
# functools.partial of _shufflenet_v2 via setattr().
def cifar10_shufflenetv2_x0_5(*args, **kwargs) -> ShuffleNetV2: pass
def cifar10_shufflenetv2_x1_0(*args, **kwargs) -> ShuffleNetV2: pass
def cifar10_shufflenetv2_x1_5(*args, **kwargs) -> ShuffleNetV2: pass
def cifar10_shufflenetv2_x2_0(*args, **kwargs) -> ShuffleNetV2: pass
227 |
228 |
# Placeholder signatures (CIFAR-100 variants); replaced at import time by the
# registration loop below via setattr().
def cifar100_shufflenetv2_x0_5(*args, **kwargs) -> ShuffleNetV2: pass
def cifar100_shufflenetv2_x1_0(*args, **kwargs) -> ShuffleNetV2: pass
def cifar100_shufflenetv2_x1_5(*args, **kwargs) -> ShuffleNetV2: pass
def cifar100_shufflenetv2_x2_0(*args, **kwargs) -> ShuffleNetV2: pass
233 |
234 |
# Register the concrete constructors: for each dataset and width multiplier,
# bind _shufflenet_v2 with the matching stage configuration and overwrite the
# placeholder function of the same name on this module.
thismodule = sys.modules[__name__]
for dataset in ["cifar10", "cifar100"]:
    # Stage widths correspond to the x0.5 / x1.0 / x1.5 / x2.0 multipliers.
    for stages_repeats, stages_out_channels, model_name in \
        zip([[4, 8, 4]]*4,
            [[24, 48, 96, 192, 1024], [24, 116, 232, 464, 1024], [24, 176, 352, 704, 1024], [24, 244, 488, 976, 2048]],
            ["shufflenetv2_x0_5", "shufflenetv2_x1_0", "shufflenetv2_x1_5", "shufflenetv2_x2_0"]):
        method_name = f"{dataset}_{model_name}"
        model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
        num_classes = 10 if dataset == "cifar10" else 100
        setattr(
            thismodule,
            method_name,
            partial(_shufflenet_v2,
                    arch=model_name,
                    stages_repeats=stages_repeats,
                    stages_out_channels=stages_out_channels,
                    model_urls=model_urls,
                    num_classes=num_classes)
        )
254 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/vgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://raw.githubusercontent.com/pytorch/vision/v0.9.1/torchvision/models/vgg.py
3 |
4 | BSD 3-Clause License
5 |
6 | Copyright (c) Soumith Chintala 2016,
7 | All rights reserved.
8 |
9 | Redistribution and use in source and binary forms, with or without
10 | modification, are permitted provided that the following conditions are met:
11 |
12 | * Redistributions of source code must retain the above copyright notice, this
13 | list of conditions and the following disclaimer.
14 |
15 | * Redistributions in binary form must reproduce the above copyright notice,
16 | this list of conditions and the following disclaimer in the documentation
17 | and/or other materials provided with the distribution.
18 |
19 | * Neither the name of the copyright holder nor the names of its
20 | contributors may be used to endorse or promote products derived from
21 | this software without specific prior written permission.
22 |
23 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
26 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
27 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
29 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
31 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 | '''
34 |
35 | import sys
36 | import torch
37 | import torch.nn as nn
38 | try:
39 | from torch.hub import load_state_dict_from_url
40 | except ImportError:
41 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
42 | from functools import partial
43 | from typing import Union, List, Dict, Any, cast
44 |
# Download URLs for CIFAR-10 pretrained VGG-BN checkpoints, keyed by arch name.
cifar10_pretrained_weight_urls = {
    'vgg11_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg11_bn-eaeebf42.pt',
    'vgg13_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg13_bn-c01e4a43.pt',
    'vgg16_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg16_bn-6ee7ea24.pt',
    'vgg19_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar10_vgg19_bn-57191229.pt',
}

# Download URLs for CIFAR-100 pretrained VGG-BN checkpoints, keyed by arch name.
cifar100_pretrained_weight_urls = {
    'vgg11_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg11_bn-57d0759e.pt',
    'vgg13_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg13_bn-5ebe5778.pt',
    'vgg16_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg16_bn-7d8c4031.pt',
    'vgg19_bn': 'https://github.com/chenyaofo/pytorch-cifar-models/releases/download/vgg/cifar100_vgg19_bn-b98f7bd7.pt',
}
58 |
59 |
class VGG(nn.Module):
    """VGG feature extractor with a compact 512-unit classifier head.

    The head is sized for CIFAR: ``features`` must flatten to exactly 512
    values (e.g. a 512-channel 1x1 map after five 2x poolings of a 32x32
    input).
    """

    def __init__(
        self,
        features: nn.Module,
        num_classes: int = 10,
        init_weights: bool = True
    ) -> None:
        super(VGG, self).__init__()
        self.features = features
        # Three-layer MLP head with dropout between the hidden layers.
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.features(x)
        flat = torch.flatten(feats, 1)
        return self.classifier(flat)

    def _initialize_weights(self) -> None:
        """He-init convolutions, unit/zero batch-norm, small-normal linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
100 |
101 |
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    """Translate a VGG config list into a Sequential feature extractor.

    Each integer entry adds a 3x3 conv (optionally followed by BatchNorm)
    plus an in-place ReLU; the string 'M' adds a 2x2 max-pool.
    """
    layers: List[nn.Module] = []
    in_channels = 3
    for entry in cfg:
        if entry == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        out_channels = cast(int, entry)
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        in_channels = out_channels
    return nn.Sequential(*layers)
117 |
118 |
# Layer configurations: integers are conv output channels, 'M' is a max-pool.
# 'A' -> VGG11, 'B' -> VGG13, 'D' -> VGG16, 'E' -> VGG19.
cfgs: Dict[str, List[Union[str, int]]] = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
125 |
126 |
def _vgg(arch: str, cfg: str, batch_norm: bool,
         model_urls: Dict[str, str],
         pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
    """Build a VGG from a config key, optionally loading pretrained weights.

    When a checkpoint is requested, random weight init is skipped since the
    parameters are about to be overwritten anyway.
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        model.load_state_dict(
            load_state_dict_from_url(model_urls[arch], progress=progress))
    return model
138 |
139 |
# Placeholder signatures for IDEs and type checkers only: the registration
# loop at the bottom of this module overwrites each name with a
# functools.partial of _vgg via setattr().
def cifar10_vgg11_bn(*args, **kwargs) -> VGG: pass
def cifar10_vgg13_bn(*args, **kwargs) -> VGG: pass
def cifar10_vgg16_bn(*args, **kwargs) -> VGG: pass
def cifar10_vgg19_bn(*args, **kwargs) -> VGG: pass
144 |
145 |
# Placeholder signatures (CIFAR-100 variants); replaced at import time by the
# registration loop below via setattr().
def cifar100_vgg11_bn(*args, **kwargs) -> VGG: pass
def cifar100_vgg13_bn(*args, **kwargs) -> VGG: pass
def cifar100_vgg16_bn(*args, **kwargs) -> VGG: pass
def cifar100_vgg19_bn(*args, **kwargs) -> VGG: pass
150 |
151 |
# Register the concrete constructors: each (dataset, arch) pair gets a partial
# of _vgg bound to its config letter, overwriting the placeholder above.
thismodule = sys.modules[__name__]
for dataset in ["cifar10", "cifar100"]:
    for cfg, model_name in zip(["A", "B", "D", "E"], ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]):
        method_name = f"{dataset}_{model_name}"
        model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
        num_classes = 10 if dataset == "cifar10" else 100
        setattr(
            thismodule,
            method_name,
            partial(_vgg,
                    arch=model_name,
                    cfg=cfg,
                    batch_norm=True,
                    model_urls=model_urls,
                    num_classes=num_classes)
        )
168 |
--------------------------------------------------------------------------------
/pytorch_cifar_models/vit.py:
--------------------------------------------------------------------------------
1 | '''
2 | Modified from https://raw.githubusercontent.com/jeonsworld/ViT-pytorch/main/models/modeling.py
3 |
4 | MIT License
5 |
6 | Copyright (c) 2020 jeonsworld
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 | '''
26 |
27 | import sys
28 | import copy
29 | import math
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 |
35 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
36 | from torch.nn.modules.utils import _pair
37 | try:
38 | from torch.hub import load_state_dict_from_url
39 | except ImportError:
40 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
41 |
42 | from functools import partial
43 | from typing import Dict, Type, Any, Callable, Union, List, Optional
44 |
# Checkpoint URLs for pretrained ViT variants.
# NOTE(review): all entries are empty strings, so calling any constructor with
# pretrained=True will fail until checkpoints are published.
cifar10_pretrained_weight_urls = {
    'vit_b16': '',
    'vit_b32': '',
    'vit_l16': '',
    'vit_l32': '',
    'vit_h14': '',
}

cifar100_pretrained_weight_urls = {
    'vit_b16': '',
    'vit_b32': '',
    'vit_l16': '',
    'vit_l32': '',
    'vit_h14': '',
}
60 |
61 |
# Parameter-name fragments matching the original JAX/Flax ViT checkpoints.
# NOTE(review): these constants are not referenced anywhere in the code
# visible in this module — presumably kept for a weight-porting helper from
# the upstream repository; confirm before removing.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"
70 |
71 |
def swish(x):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``."""
    return torch.sigmoid(x) * x


# Activation registry used by the MLP block; keys match config strings.
ACT2FN = {"gelu": F.gelu, "relu": F.relu, "swish": swish}
77 |
78 |
class Attention(nn.Module):
    """Multi-head self-attention with separate Q/K/V projections.

    When ``vis`` is true, the post-softmax attention maps are returned
    alongside the output for visualization; otherwise ``None`` is returned
    in their place.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Reshape (B, T, all_head) -> (B, heads, T, head_size)."""
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # Project and split into heads.
        q = self.transpose_for_scores(self.query(hidden_states))
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.softmax(scores)
        weights = probs if self.vis else None
        probs = self.attn_dropout(probs)

        # Merge heads back into a single hidden dimension.
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*(context.size()[:-2] + (self.all_head_size,)))
        output = self.out(context)
        output = self.proj_dropout(output)
        return output, weights
124 |
125 |
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> dropout -> Linear -> dropout."""

    def __init__(self, config):
        super(MLP, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        # ACT2FN["gelu"] resolves to F.gelu; bind the same function directly.
        self.act_fn = F.gelu
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier weights, near-zero biases (as in the reference ViT code)."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
149 |
150 |
class Embeddings(nn.Module):
    """Patchify an image and prepend a [CLS] token plus learned position embeddings."""

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        img_size = _pair(img_size)

        patch_size = _pair(config.patches["size"])
        n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.hybrid = False  # hybrid (CNN-backbone) variant is not supported here

        # Non-overlapping patch projection via a strided convolution.
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # +1 slot for the class token.
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x: torch.Tensor):
        batch = x.shape[0]
        cls = self.cls_token.expand(batch, -1, -1)

        # (B, C, H, W) -> (B, hidden, gh, gw) -> (B, gh*gw, hidden)
        patches = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        tokens = torch.cat((cls, patches), dim=1)

        return self.dropout(tokens + self.position_embeddings)
181 |
182 |
class Block(nn.Module):
    """Pre-norm transformer block: LN -> attention -> residual, LN -> MLP -> residual."""

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = MLP(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        attn_out, weights = self.attn(self.attention_norm(x))
        x = attn_out + x

        # Feed-forward sub-layer with residual connection.
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights
203 |
204 |
class Encoder(nn.Module):
    """Stack of transformer blocks followed by a final LayerNorm.

    When ``vis`` is true, the per-layer attention maps are collected and
    returned alongside the encoded sequence (the list is empty otherwise).
    """

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            # Each Block is freshly constructed per iteration, so the former
            # copy.deepcopy() of it was redundant work; append directly.
            self.layer.append(Block(config, vis))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights
223 |
224 |
class Transformer(nn.Module):
    """Patch embeddings followed by the transformer encoder."""

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedded = self.embeddings(input_ids)
        # Returns (encoded sequence, attention weights list).
        return self.encoder(embedded)
235 |
236 |
class VisionTransformer(nn.Module):
    """ViT classifier: transformer encoder plus a linear head on the [CLS] token."""

    def __init__(self, config, img_size=224, num_classes=10, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

        if zero_head:
            # Start the classifier from zero logits (useful for fine-tuning).
            with torch.no_grad():
                nn.init.zeros_(self.head.weight)
                nn.init.zeros_(self.head.bias)

    def forward(self, x):
        encoded, _ = self.transformer(x)
        # Classify from the [CLS] token (position 0); attention maps discarded.
        return self.head(encoded[:, 0])
256 |
257 |
class TestConfig:
    """Minimal ViT configuration for cheap smoke-testing of the model code."""
    patches = dict(size=(16, 16))
    hidden_size = 1
    transformer = dict(
        mlp_dim=1,
        num_heads=1,
        num_layers=1,
        attention_dropout_rate=0.0,
        dropout_rate=0.1
    )
    classifier = 'token'
    representation_size = None
270 |
271 |
class VitB16Config:
    """ViT-Base with 16x16 patches (12 layers, 12 heads, hidden size 768)."""
    patches = dict(size=(16, 16))
    hidden_size = 768
    transformer = dict(
        mlp_dim=3072,
        num_heads=12,
        num_layers=12,
        attention_dropout_rate=0.0,
        dropout_rate=0.1
    )
    classifier = 'token'
    representation_size = None
284 |
285 |
class VitB32Config(VitB16Config):
    """ViT-Base variant of VitB16Config with 32x32 patches."""
    patches = dict(size=(32, 32))
288 |
289 |
class VitL16Config:
    """ViT-Large with 16x16 patches (24 layers, 16 heads, hidden size 1024)."""
    patches = dict(size=(16, 16))
    hidden_size = 1024
    transformer = dict(
        mlp_dim=4096,
        num_heads=16,
        num_layers=24,
        attention_dropout_rate=0.0,
        dropout_rate=0.1
    )
    classifier = 'token'
    representation_size = None
302 |
303 |
class VitL32Config(VitL16Config):
    """ViT-Large variant of VitL16Config with 32x32 patches."""
    patches = dict(size=(32, 32))
306 |
307 |
class VitH14Config:
    """ViT-Huge with 14x14 patches (32 layers, 16 heads, hidden size 1280)."""
    patches = dict(size=(14, 14))
    hidden_size = 1280
    transformer = dict(
        mlp_dim=5120,
        num_heads=16,
        num_layers=32,
        attention_dropout_rate=0.0,
        dropout_rate=0.1
    )
    classifier = 'token'
    representation_size = None
320 |
321 |
def _vit(
    arch: str,
    config: Any,
    model_urls: Dict[str, str],
    progress: bool = True,
    pretrained: bool = False,
    **kwargs: Any
) -> VisionTransformer:
    """Build a VisionTransformer for ``config``, optionally loading the
    checkpoint registered under ``arch`` in ``model_urls``."""
    model = VisionTransformer(config=config, **kwargs)
    if pretrained:
        model.load_state_dict(
            load_state_dict_from_url(model_urls[arch], progress=progress))
    return model
336 |
337 |
# Placeholder signatures for IDEs and type checkers only: the registration
# loop at the bottom of this module overwrites each name with a
# functools.partial of _vit via setattr().
def cifar10_vit_b16(*args, **kwargs) -> VisionTransformer: pass
def cifar10_vit_b32(*args, **kwargs) -> VisionTransformer: pass
def cifar10_vit_l16(*args, **kwargs) -> VisionTransformer: pass
def cifar10_vit_l32(*args, **kwargs) -> VisionTransformer: pass
def cifar10_vit_h14(*args, **kwargs) -> VisionTransformer: pass
343 |
344 |
# Placeholder signatures (CIFAR-100 variants); replaced at import time by the
# registration loop below via setattr().
def cifar100_vit_b16(*args, **kwargs) -> VisionTransformer: pass
def cifar100_vit_b32(*args, **kwargs) -> VisionTransformer: pass
def cifar100_vit_l16(*args, **kwargs) -> VisionTransformer: pass
def cifar100_vit_l32(*args, **kwargs) -> VisionTransformer: pass
def cifar100_vit_h14(*args, **kwargs) -> VisionTransformer: pass
350 |
351 |
# Register the concrete constructors: each (dataset, arch) pair gets a partial
# of _vit bound to its config class, overwriting the placeholder above.
thismodule = sys.modules[__name__]
for dataset in ["cifar10", "cifar100"]:
    for config, model_name in zip([VitB16Config, VitB32Config, VitL16Config, VitL32Config, VitH14Config],
                                  ["vit_b16", "vit_b32", "vit_l16", "vit_l32", "vit_h14"]):
        method_name = f"{dataset}_{model_name}"
        model_urls = cifar10_pretrained_weight_urls if dataset == "cifar10" else cifar100_pretrained_weight_urls
        num_classes = 10 if dataset == "cifar10" else 100
        setattr(
            thismodule,
            method_name,
            partial(_vit,
                    arch=model_name,
                    config=config,
                    model_urls=model_urls,
                    num_classes=num_classes)
        )
368 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import re
4 | from setuptools import setup, find_packages
5 |
6 |
def read(*names, **kwargs):
    """Return the text of a file located relative to this setup script.

    Path segments are passed as positional arguments; the encoding may be
    overridden via the ``encoding`` keyword (defaults to UTF-8).
    """
    path = os.path.join(os.path.dirname(__file__), *names)
    with io.open(path, encoding=kwargs.get("encoding", "utf8")) as fp:
        return fp.read()
13 |
14 |
def find_version(*file_paths):
    """Extract ``__version__`` from a source file, raising if it is absent."""
    version_file = read(*file_paths)
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
    if not match:
        raise RuntimeError("Unable to find version string.")
    return match.group(1)
22 |
23 |
24 | VERSION = find_version('pytorch_cifar_models', '__init__.py')
25 |
26 |
def find_requirements(file_path):
    """Read a requirements file and return its lines without trailing newlines."""
    with open(file_path) as f:
        content = f.read()
    return content.splitlines()
30 |
31 |
# Runtime dependencies are maintained in requirements.txt at the repo root.
requirements = find_requirements("requirements.txt")

setup(
    name="pytorch_cifar_models",
    version=VERSION,
    description="Pretrained models for CIFAR10/100 in PyTorch",
    url="https://github.com/chenyaofo/pytorch-cifar-models",
    author="chenyaofo",
    author_email="chenyaofo@gmail.com",
    packages=find_packages(exclude=['test']),
    install_requires=requirements,
)
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.mobilenetv2 as mobilenetv2
5 |
6 |
@pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
@pytest.mark.parametrize("model_name", ["mobilenetv2_x0_5", "mobilenetv2_x0_75", "mobilenetv2_x1_0", "mobilenetv2_x1_4"])
def test_mobilenetv2(dataset, model_name):
    """Each registered MobileNetV2 variant maps a CIFAR image to per-class logits.

    Renamed from the copy-pasted ``test_resnet``; the name now matches the
    module under test. Shape-only check, so an uninitialized input is fine.
    """
    num_classes = 10 if dataset == "cifar10" else 100
    model = getattr(mobilenetv2, f"{dataset}_{model_name}")()
    x = torch.empty((1, 3, 32, 32))
    y = model(x)
    assert y.shape == (1, num_classes)
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_repvgg.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.repvgg as repvgg
5 |
6 |
@pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
@pytest.mark.parametrize("model_name", ["repvgg_a0", "repvgg_a1", "repvgg_a2"])
def test_repvgg(dataset, model_name):
    """Each registered RepVGG variant maps a CIFAR image to per-class logits.

    Renamed from the copy-pasted ``test_resnet`` to match the module under test.
    """
    num_classes = 10 if dataset == "cifar10" else 100
    model = getattr(repvgg, f"{dataset}_{model_name}")()
    x = torch.empty((1, 3, 32, 32))
    y = model(x)
    assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_resnet.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.resnet as resnet
5 |
6 |
@pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
@pytest.mark.parametrize("model_name", ["resnet20", "resnet32", "resnet44", "resnet56"])
def test_resnet(dataset, model_name):
    """Each registered ResNet variant maps a CIFAR image to per-class logits."""
    expected_classes = 10 if dataset == "cifar10" else 100
    builder = getattr(resnet, f"{dataset}_{model_name}")
    logits = builder()(torch.empty((1, 3, 32, 32)))
    assert logits.shape == (1, expected_classes)
15 |
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_shufflenetv2.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.shufflenetv2 as shufflenetv2
5 |
6 |
@pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
@pytest.mark.parametrize("model_name", ["shufflenetv2_x0_5", "shufflenetv2_x1_0", "shufflenetv2_x1_5", "shufflenetv2_x2_0"])
def test_shufflenetv2(dataset, model_name):
    """Each registered ShuffleNetV2 variant maps a CIFAR image to per-class logits.

    Renamed from the copy-pasted ``test_resnet`` to match the module under test.
    """
    num_classes = 10 if dataset == "cifar10" else 100
    model = getattr(shufflenetv2, f"{dataset}_{model_name}")()
    x = torch.empty((1, 3, 32, 32))
    y = model(x)
    assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_vgg.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.vgg as vgg
5 |
6 |
@pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
@pytest.mark.parametrize("model_name", ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"])
def test_vgg(dataset, model_name):
    """Each registered VGG-BN variant maps a CIFAR image to per-class logits.

    Renamed from the copy-pasted ``test_resnet`` to match the module under test.
    """
    num_classes = 10 if dataset == "cifar10" else 100
    model = getattr(vgg, f"{dataset}_{model_name}")()
    x = torch.empty((1, 3, 32, 32))
    y = model(x)
    assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------
/tests/pytorch_cifar_models/test_vit.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | import pytorch_cifar_models.vit as vit
5 |
6 |
@pytest.mark.parametrize("dataset", ["cifar10", "cifar100"])
@pytest.mark.parametrize("model_name", ["vit_b16", "vit_b32", "vit_l16", "vit_l32", "vit_h14"])
def test_vit(dataset, model_name):
    """Each registered ViT variant maps a 224x224 image to per-class logits.

    Renamed from the copy-pasted ``test_resnet`` to match the module under
    test. ViT configs default to img_size=224, hence the larger input.
    """
    num_classes = 10 if dataset == "cifar10" else 100
    model = getattr(vit, f"{dataset}_{model_name}")()
    x = torch.empty((1, 3, 224, 224))
    y = model(x)
    assert y.shape == (1, num_classes)
15 |
--------------------------------------------------------------------------------