├── .gitignore
├── README.md
├── environment.yml
└── models
    ├── data
    │   └── .gitignore
    ├── dataset_nn
    │   └── dataset_neural_nets.ipynb
    └── densenet_web
        ├── LICENSE
        ├── __init__.py
        ├── app.py
        ├── densenet.ipynb
        ├── densenet_post.png
        └── model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # .idea
10 | .idea/
11 |
12 | # checkpoint files
13 | models/*/checkpoint/
14 |
15 | # ONNX Proto files
16 | models/*/*.proto
17 |
18 | .DS_Store
19 | models/.DS_Store
20 |
21 | # Distribution / packaging
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # hackpack-ml
2 |
3 | Learn how to make and deploy ML hacks in PyTorch.
4 |
5 | ### Features
6 | This hackpack aims to make iterating on your models easy while also enabling quick integration into apps.
7 |
8 | We include:
9 | 1. An introductory notebook in which we load our own dataset into PyTorch and build a simple neural network for regression on tabular data:
10 | `/models/dataset_nn/dataset_neural_nets.ipynb`
11 |
12 | 2. An application-focused notebook in which we load and train a DenseNet model for image classification on the SVHN and CIFAR10 datasets,
13 | and then export the model into a Flask server as our own API:
14 | `/models/densenet_web/densenet.ipynb`
15 |
16 | ## Getting Started
17 |
18 | This hackpack assumes some prior understanding of Python. Understanding of Deep Learning techniques is greatly beneficial but not required. If you are interested in learning, [Deep Learning by Goodfellow et al.](https://www.deeplearningbook.org/) is a good starting point.
19 |
20 | ### Google Colab
21 |
22 | Both notebooks are also hosted on Google Colab, a free Jupyter notebook environment with GPU/TPU support.
23 | They are accessible as `dataset_neural_nets` and `densenet_web`, respectively. Feel free to open them in the playground environment or make a copy!
24 |
25 | ### Installing
26 | If you wish to work on the notebooks locally, the dependencies are listed in `environment.yml`.
27 | Create a new Conda environment with:
28 | ```
29 | conda env create --file=environment.yml
30 | ```
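31 | Then activate the environment (its name, `hackpackml`, comes from `environment.yml`):
32 | ```
33 | conda activate hackpackml
34 | ```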
31 | ### License
32 | MIT
33 |
34 | # About HackPacks 🌲
35 |
36 | HackPacks are built by the [TreeHacks](https://www.treehacks.com/) team to help hackers build great projects at our hackathon that happens every February at Stanford. We believe that everyone at every skill level can learn to make awesome things, and this is one way we help facilitate hacker culture. We open source our hackpacks (along with our internal tech) so everyone can learn from and use them! Feel free to use these at your own hackathons, workshops, and anything else that promotes building :)
37 |
38 | If you're interested in attending TreeHacks, you can apply on our [website](https://www.treehacks.com/) during the application period.
39 |
40 | You can follow us here on [GitHub](https://github.com/treehacks) to see all the open source work we do (we love issues, contributions, and feedback of any kind!), and on [Facebook](https://facebook.com/treehacks), [Twitter](https://twitter.com/hackwithtrees), and [Instagram](https://instagram.com/hackwithtrees) to see general updates from TreeHacks.
41 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: hackpackml
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - appnope=0.1.0=py36_1000
8 | - attrs=18.2.0=py_0
9 | - backcall=0.1.0=py_0
10 | - blas=1.0=mkl
11 | - ca-certificates=2018.11.29=ha4d7672_0
12 | - certifi=2018.11.29=py36_1000
13 | - cffi=1.11.5=py36h342bebf_1001
14 | - click=7.0=py_0
15 | - cycler=0.10.0=py_1
16 | - decorator=4.3.2=py_0
17 | - entrypoints=0.3=py36_1000
18 | - flask=1.0.2=py_2
19 | - freetype=2.9.1=h597ad8a_1005
20 | - icu=58.2=h0a44026_1000
21 | - intel-openmp=2019.1=144
22 | - ipykernel=5.1.0=py36h24bf2e0_1002
23 | - ipython=7.2.0=py36h24bf2e0_1000
24 | - ipython_genutils=0.2.0=py_1
25 | - ipywidgets=7.4.2=py_0
26 | - itsdangerous=1.1.0=py_0
27 | - jedi=0.13.2=py36_1000
28 | - jinja2=2.10=py_1
29 | - jpeg=9c=h1de35cc_1001
30 | - jsonschema=3.0.0a3=py36_1000
31 | - jupyter=1.0.0=py_1
32 | - jupyter_client=5.2.4=py_1
33 | - jupyter_console=6.0.0=py_0
34 | - jupyter_core=4.4.0=py_0
35 | - kiwisolver=1.0.1=py36h04f5b5a_1002
36 | - libcxx=7.0.0=h2d50403_2
37 | - libffi=3.2.1=h0a44026_1005
38 | - libgfortran=3.0.1=h93005f0_2
39 | - libpng=1.6.36=ha441bb4_1000
40 | - libsodium=1.0.16=h1de35cc_1001
41 | - libtiff=4.0.10=h79f4b77_1001
42 | - llvm-meta=7.0.0=0
43 | - markupsafe=1.1.0=py36h1de35cc_1000
44 | - matplotlib=3.0.2=py36_1002
45 | - matplotlib-base=3.0.2=py36hf043ca5_1002
46 | - mistune=0.8.4=py36h1de35cc_1000
47 | - mkl=2019.1=144
48 | - mkl_fft=1.0.10=py36h1de35cc_1
49 | - mkl_random=1.0.2=py36h1702cab_2
50 | - nbconvert=5.3.1=py_1
51 | - nbformat=4.4.0=py_1
52 | - ncurses=6.1=h0a44026_1002
53 | - ninja=1.9.0=h04f5b5a_0
54 | - notebook=5.7.4=py36_1000
55 | - numpy=1.15.4=py36hacdab7b_0
56 | - numpy-base=1.15.4=py36h6575580_0
57 | - olefile=0.46=py_0
58 | - openssl=1.1.1a=h1de35cc_1000
59 | - pandas=0.24.0=py36h0a44026_0
60 | - pandoc=2.6=0
61 | - pandocfilters=1.4.2=py_1
62 | - parso=0.3.2=py_0
63 | - pexpect=4.6.0=py36_1000
64 | - pickleshare=0.7.5=py36_1000
65 | - pillow=5.4.1=py36hbddbef0_1000
66 | - pip=19.0.1=py36_0
67 | - prometheus_client=0.5.0=py_0
68 | - prompt_toolkit=2.0.8=py_0
69 | - ptyprocess=0.6.0=py36_1000
70 | - pycparser=2.19=py_0
71 | - pygments=2.3.1=py_0
72 | - pyparsing=2.3.1=py_0
73 | - pyqt=5.6.0=py36hc26a216_1008
74 | - pyrsistent=0.14.9=py36h1de35cc_1000
75 | - python=3.6.8=haf84260_0
76 | - python-dateutil=2.7.5=py_0
77 | - pytorch=1.0.0=py3.6_1
78 | - pytz=2018.9=py_0
79 | - pyzmq=17.1.2=py36h111632d_1001
80 | - qt=5.6.2=h822fa55_1013
81 | - qtconsole=4.4.3=py_0
82 | - readline=7.0=hcfe32e1_1001
83 | - send2trash=1.5.0=py_0
84 | - setuptools=40.7.1=py36_0
85 | - sip=4.18.1=py36h0a44026_1000
86 | - six=1.12.0=py36_1000
87 | - sqlite=3.26.0=h1765d9f_1000
88 | - terminado=0.8.1=py36_1001
89 | - testpath=0.4.2=py36_1000
90 | - tk=8.6.9=ha441bb4_1000
91 | - torchvision=0.2.1=py_2
92 | - tornado=5.1.1=py36h1de35cc_1000
93 | - tqdm=4.30.0=py_0
94 | - traitlets=4.3.2=py36_1000
95 | - wcwidth=0.1.7=py_1
96 | - webencodings=0.5.1=py_1
97 | - werkzeug=0.14.1=py_0
98 | - wheel=0.32.3=py36_0
99 | - widgetsnbextension=3.4.2=py36_1000
100 | - xz=5.2.4=h1de35cc_1001
101 | - zeromq=4.2.5=h0a44026_1006
102 | - zlib=1.2.11=h1de35cc_1004
103 | - pip:
104 | - bleach==3.1.0
105 |
--------------------------------------------------------------------------------
/models/data/.gitignore:
--------------------------------------------------------------------------------
1 | /*
2 | !/.gitignore
3 |
--------------------------------------------------------------------------------
/models/dataset_nn/dataset_neural_nets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Datasets and Neural Networks\n",
8 | "This notebook will step through the process of loading an arbitrary dataset in PyTorch, and creating a simple neural network for regression."
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "# Datasets\n",
16 | "We will first work through loading an arbitrary dataset in PyTorch. For this project, we chose the delve abalone dataset. \n",
17 | "\n",
18 | "First, download and unzip the dataset from the link above, then unzip `Dataset.data.gz` and move `Dataset.data` into `hackpack-ml/models/data`.\n",
19 | "We are given the following attribute information in the spec:\n",
20 | "```\n",
21 | "Attributes:\n",
22 | " 1 sex u M F I\t# Gender or Infant (I)\n",
23 | " 2 length u (0,Inf]\t# Longest shell measurement (mm)\n",
24 | " 3 diameter u (0,Inf]\t# perpendicular to length (mm)\n",
25 | " 4 height u (0,Inf]\t# with meat in shell (mm)\n",
26 | " 5 whole_weight u (0,Inf]\t# whole abalone (gr)\n",
27 | " 6 shucked_weight u (0,Inf]\t# weight of meat (gr) \n",
28 | " 7 viscera_weight u (0,Inf]\t# gut weight (after bleeding) (gr)\n",
29 | " 8 shell_weight u (0,Inf]\t# after being dried (gr)\n",
30 | " 9 rings u 0..29\t# +1.5 gives the age in years\n",
31 | "```"
32 | ]
33 | },
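34 |   {
35 |    "cell_type": "markdown",
36 |    "metadata": {},
37 |    "source": [
38 |     "The next cell is a minimal helper sketch: it assumes you have already downloaded `Dataset.data.gz` into this notebook's folder, and it decompresses the file into `../data/Dataset.data`."
39 |    ]
40 |   },
41 |   {
42 |    "cell_type": "code",
43 |    "execution_count": null,
44 |    "metadata": {},
45 |    "outputs": [],
46 |    "source": [
47 |     "import gzip\n",
48 |     "import shutil\n",
49 |     "\n",
50 |     "# Decompress the downloaded archive into the shared data folder\n",
51 |     "with gzip.open('Dataset.data.gz', 'rb') as f_in:\n",
52 |     "    with open('../data/Dataset.data', 'wb') as f_out:\n",
53 |     "        shutil.copyfileobj(f_in, f_out)"
54 |    ]
55 |   },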
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "import math\n",
41 | "from tqdm import tqdm\n",
42 | "import torch\n",
43 | "import torch.nn as nn\n",
44 | "import torch.optim as optim\n",
45 | "import torch.utils.data as data\n",
46 | "import torch.nn.functional as F\n",
47 | "import pandas as pd\n",
48 | "\n",
49 | "from torch.utils.data import Dataset, DataLoader"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "Pandas is a data manipulation library that works really well with structured data. We can use Pandas DataFrames to load the dataset."
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "col_names = ['sex', 'length', 'diameter', 'height', 'whole_weight', \n",
66 | " 'shucked_weight', 'viscera_weight', 'shell_weight', 'rings']\n",
67 | "abalone_df = pd.read_csv('../data/Dataset.data', sep=' ', names=col_names)\n",
68 | "abalone_df.head(n=3)"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "We define a subclass of PyTorch Dataset for our Abalone dataset."
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "class AbaloneDataset(data.Dataset):\n",
85 | " \"\"\"Abalone dataset. Provides quick iteration over rows of data.\"\"\"\n",
86 | "\n",
87 | " def __init__(self, csv):\n",
88 | " \"\"\"\n",
89 | " Args: csv (string): Path to the Abalone dataset.\n",
90 | " \"\"\"\n",
91 | " self.features = ['sex', 'length', 'diameter', 'height', 'whole_weight', \n",
92 | " 'shucked_weight', 'viscera_weight', 'shell_weight']\n",
93 | " self.y = ['rings']\n",
94 | " self.abalone_df = pd.read_csv(csv, sep=' ', names=(self.features + self.y))\n",
95 | " \n",
96 | " # Turn categorical data into machine interpretable format (one hot)\n",
97 | " self.abalone_df['sex'] = pd.get_dummies(self.abalone_df['sex'])\n",
98 | "\n",
99 | " def __len__(self):\n",
100 | " return len(self.abalone_df)\n",
101 | "\n",
102 | " def __getitem__(self, idx):\n",
103 | " \"\"\"Return (x,y) pair where x are abalone features and y is age.\"\"\"\n",
104 | " features = self.abalone_df.iloc[idx][self.features].values\n",
105 | " y = self.abalone_df.iloc[idx][self.y]\n",
106 | " return torch.Tensor(features).float(), torch.Tensor(y).float()"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "# Neural Networks\n",
114 | "\n",
115 | "The task is to predict the age (number of rings) of abalone from physical measurements. We build a simple neural network with one hidden layer to model the regression."
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "class Net(nn.Module):\n",
125 | "\n",
126 | " def __init__(self, feature_size):\n",
127 | " super(Net, self).__init__()\n",
128 | " # feature_size input channels (8), 1 output channels\n",
129 | " self.fc1 = nn.Linear(feature_size, 4)\n",
130 | " self.fc2 = nn.Linear(4, 1)\n",
131 | "\n",
132 | " def forward(self, x):\n",
133 | " x = F.relu(self.fc1(x))\n",
134 | " x = self.fc2(x)\n",
135 | " return x"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {},
141 | "source": [
142 | "We instantiate an Abalone dataset instance and create DataLoaders for train and test sets."
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "dataset = AbaloneDataset('../data/Dataset.data')\n",
152 | "train_split, test_split = math.floor(len(dataset) * 0.8), math.ceil(len(dataset) * 0.2)\n",
153 | "\n",
154 | "trainset = [dataset[i] for i in range(train_split)]\n",
155 | "testset = [dataset[train_split + j] for j in range(test_split)]\n",
156 | "batch_sz = len(trainset) # Compact data allows for big batch size\n",
157 | "trainloader = data.DataLoader(trainset, batch_size=batch_sz, shuffle=True, num_workers=4)\n",
158 | "testloader = data.DataLoader(testset, batch_size=batch_sz, shuffle=False, num_workers=4)"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {},
164 | "source": [
165 | "Now, we can initialize our network and define train and test functions"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "net = Net(len(dataset.features))\n",
175 | "loss_fn = nn.MSELoss()\n",
176 | "optimizer = optim.Adam(net.parameters(), lr=0.1)\n",
177 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
178 | "gpu_ids = [0] # On Colab, we have access to one GPU. Change this value as you see fit\n",
179 | "\n",
180 | "def train(epoch):\n",
181 | " \"\"\"\n",
182 | " Trains our net on data from the trainloader for a single epoch\n",
183 | " \"\"\"\n",
184 | " net.train()\n",
185 | " with tqdm(total=len(trainloader.dataset)) as progress_bar:\n",
186 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
187 | " inputs, targets = inputs.to(device), targets.to(device)\n",
188 | " optimizer.zero_grad() # Clear any stored gradients for new step\n",
189 | " outputs = net(inputs.float())\n",
190 | " loss = loss_fn(outputs, targets) # Calculate loss between prediction and label \n",
191 | " loss.backward() # Backpropagate gradient updates through net based on loss\n",
192 | " optimizer.step() # Update net weights based on gradients\n",
193 | " progress_bar.set_postfix(loss=loss.item())\n",
194 | " progress_bar.update(inputs.size(0))\n",
195 | " \n",
196 | " \n",
197 | "def test(epoch):\n",
198 | " \"\"\"\n",
199 | " Run net in inference mode on test data. \n",
200 | " \"\"\" \n",
201 | " net.eval()\n",
202 | " # Ensures the net will not update weights\n",
203 | " with torch.no_grad():\n",
204 | " with tqdm(total=len(testloader.dataset)) as progress_bar:\n",
205 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n",
206 | " inputs, targets = inputs.to(device).float(), targets.to(device).float()\n",
207 | " outputs = net(inputs)\n",
208 | " loss = loss_fn(outputs, targets)\n",
209 | " progress_bar.set_postfix(testloss=loss.item())\n",
210 | " progress_bar.update(inputs.size(0))\n"
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {},
216 | "source": [
217 | "Now that everything is prepared, it's time to train!"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "test_freq = 5 # Frequency to run model on validation data\n",
227 | "\n",
228 | "for epoch in range(0, 200):\n",
229 | " train(epoch)\n",
230 | " if epoch % test_freq == 0:\n",
231 | " test(epoch)"
232 | ]
233 | },
234 | {
235 | "cell_type": "markdown",
236 | "metadata": {},
237 | "source": [
238 | "We use the network's eval mode to do a sample prediction to see how well it does."
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": null,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "net.eval()\n",
248 | "sample = testset[0]\n",
249 | "predicted_age = net(sample[0])\n",
250 | "true_age = sample[1]\n",
251 | "\n",
252 | "print(f'Input features: {sample[0]}')\n",
253 | "print(f'Predicted age: {predicted_age.item()}, True age: {true_age[0]}')"
254 | ]
255 | },
256 | {
257 | "cell_type": "markdown",
258 | "metadata": {},
259 | "source": [
260 | "Congratulations! You now know how to load your own datasets into PyTorch and run models on it. For an example of Computer Vision, check out the DenseNet notebook. Happy hacking!"
261 | ]
262 | }
263 | ],
264 | "metadata": {
265 | "kernelspec": {
266 | "display_name": "Python 3",
267 | "language": "python",
268 | "name": "python3"
269 | },
270 | "language_info": {
271 | "codemirror_mode": {
272 | "name": "ipython",
273 | "version": 3
274 | },
275 | "file_extension": ".py",
276 | "mimetype": "text/x-python",
277 | "name": "python",
278 | "nbconvert_exporter": "python",
279 | "pygments_lexer": "ipython3",
280 | "version": "3.6.8"
281 | }
282 | },
283 | "nbformat": 4,
284 | "nbformat_minor": 2
285 | }
286 |
--------------------------------------------------------------------------------
/models/densenet_web/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) Soumith Chintala 2016, 2017
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/models/densenet_web/__init__.py:
--------------------------------------------------------------------------------
1 | # Re-export the model constructor so `from models.densenet_web import densenet121` works
2 | from .model import densenet121
--------------------------------------------------------------------------------
/models/densenet_web/app.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple Flask server to return predictions from trained densenet_web model
3 | """
4 |
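5 | # To launch the server (one option; run from the repo root so the relative
6 | # checkpoint path below resolves, and use the Flask CLI):
7 | #   FLASK_APP=models.densenet_web.app flask run
8 | 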
5 | from flask import Flask, request, redirect, flash
6 | import torch
7 | from torchvision.transforms.functional import to_tensor
8 | from PIL import Image
9 |
10 | from .model import densenet121
11 |
12 | app = Flask(__name__)
13 | app.secret_key = 'change-me'  # flash() requires a secret key; replace in production
14 | # Load the checkpoint on CPU so the server also works on machines without a GPU
15 | cp = torch.load('models/densenet_web/checkpoint/best_model.pth.tar', map_location='cpu')
16 | net = densenet121()
17 | # Strip the 'module.' prefix that DataParallel adds to state dict keys during training
18 | net.load_state_dict({k.replace('module.', '', 1): v for k, v in cp['net'].items()})
19 | net.eval()
17 |
18 |
19 | def classify_image(x):
20 | """
21 |     :param x: batch of image tensors
22 |     :return: class prediction from densenet_web model
23 |     """
24 |     with torch.no_grad():  # Inference only; no gradients needed
25 |         outputs = net(x)
26 |     _, predicted = outputs.max(1)
27 |     return predicted
27 |
28 |
29 | @app.route('/predict', methods=['POST'])
30 | def predict():
31 | """
32 | :return: prediction for image from POST request.
33 | """
34 | if request.method == 'POST':
35 | if 'img' not in request.files:
36 | flash('No image in request')
37 | return redirect(request.url)
38 |
39 | # preprocess image
40 | img = request.files['img']
41 | img = Image.open(img).convert("RGB")
42 | x = to_tensor(img)
43 |
44 | prediction = classify_image(x.unsqueeze(0))
45 | return str(int(prediction[0].data))
46 |
--------------------------------------------------------------------------------
/models/densenet_web/densenet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# DenseNet\n",
8 | "In this notebook, we train a DenseNet classifier for SVHN and CIFAR10 datasets, and launch it in a web API."
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "metadata": {},
15 | "outputs": [],
16 | "source": [
17 | "\"\"\"\n",
18 | "Script adapted from: https://github.com/kuangliu/pytorch-cifar\n",
19 | "\"\"\"\n",
20 | "import torch\n",
21 | "import torch.nn as nn\n",
22 | "import torch.optim as optim\n",
23 | "import torch.utils.data as data\n",
24 | "from torch.autograd import Variable\n",
25 | "from torch.optim import lr_scheduler\n",
26 | "import torchvision\n",
27 | "import torchvision.transforms as transforms\n",
28 | "from torchvision import datasets, models, transforms\n",
29 | "import sys\n",
30 | "import os\n",
31 | "from tqdm import tqdm\n",
32 | "import matplotlib.pyplot as plt\n",
33 | "\n",
34 | "from model import densenet121\n"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "## Dataloader\n",
42 | "\n",
43 | "Here, we load either the SVHN or CIFAR10 datasets, which are provided through torchvision. If you wish to use your own, check out the datasets notebook to see how to create your own dataset class.\n",
44 | "\n",
45 | "Note that we are treating the test set as a validation set."
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "# Transform from PIL image format to tensor format\n",
55 | "transform_train = transforms.Compose([\n",
56 | " # You can add more data augmentation techniques in series: \n",
57 | " # https://pytorch.org/docs/stable/torchvision/transforms.html\n",
58 | " transforms.ToTensor()\n",
59 | "])\n",
60 | "\n",
61 | "transform_test = transforms.Compose([\n",
62 | " transforms.ToTensor()\n",
63 | "])\n",
64 | "\n",
65 | "# CIFAR10 Dataset: https://www.cs.toronto.edu/~kriz/cifar.html\n",
66 | "trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)\n",
67 | "testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)\n",
68 | "\n",
69 | "# SVHN Dataset: http://ufldl.stanford.edu/housenumbers/\n",
70 | "# trainset = torchvision.datasets.SVHN(root='../data', split='train', transform=transform_train, download=True)\n",
71 | "# testset = torchvision.datasets.SVHN(root='../data', split='test', transform=transform_test, download=True)"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "If making a proof of concept application, we can choose to overfit on a data subset for quick training."
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "train_ct = 128 # Size of train data\n",
88 | "test_ct = 32 # Size of test data\n",
89 | "batch_sz = 32 # Size of batch for one gradient step (bigger batches take more memory but are faster)\n",
90 | "num_workers = 4 # 4 * number of GPUs\n",
91 | "\n",
92 | "if train_ct:\n",
93 | " trainset = data.dataset.Subset(trainset, range(train_ct))\n",
94 | "\n",
95 | "if test_ct:\n",
96 | " testset = data.dataset.Subset(testset, range(test_ct))\n",
97 | "\n",
98 | "trainloader = data.DataLoader(trainset, batch_size=batch_sz, shuffle=True, num_workers=num_workers, )\n",
99 | "testloader = data.DataLoader(testset, batch_size=batch_sz, shuffle=False, num_workers=num_workers)"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "## Training\n",
107 | "\n",
108 | "First, we configure the model for training, and define our loss and optimizer."
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
118 | "gpu_ids = [0] # On Colab, we have access to one GPU. Change this value as you see fit\n",
119 | "\n",
120 | "net = densenet121()\n",
121 | "net = net.to(device)\n",
122 | "\n",
123 | "if device == 'cuda':\n",
124 | " net = torch.nn.DataParallel(net, gpu_ids)\n",
125 | " \n",
126 | "resume = False # Resume training from a saved checkpoint\n",
127 | "\n",
128 | "if resume:\n",
129 | " print('Resuming from checkpoint at ./checkpoint/best_model.pth.tar')\n",
130 | " assert os.path.isdir('ckpts'), 'Error: no checkpoint directory found!'\n",
131 | " checkpoint = torch.load('./checkpoint/best_model.pth.tar')\n",
132 | " net.load_state_dict(checkpoint['net'])\n",
133 | " global best_loss\n",
134 | " best_loss = checkpoint['test_loss']\n",
135 | " start_epoch = checkpoint['epoch']\n",
136 | " \n",
137 | "# Loss function: https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss\n",
138 | "loss_fn = nn.CrossEntropyLoss() \n",
139 | "optimizer = optim.Adam(net.parameters(), lr=0.1)"
140 | ]
141 | },
142 | {
143 | "cell_type": "markdown",
144 | "metadata": {},
145 | "source": [
146 | "We define `train`, which performs a forward/back propagation pass on our dataset per epoch. Similarly, `test` performs evaluation on the test set."
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "def train(epoch):\n",
156 | " \"\"\"\n",
157 | " Trains our net on data from the trainloader for a single epoch\n",
158 | " \"\"\"\n",
159 | " net.train()\n",
160 | " train_loss = 0\n",
161 | " correct = 0\n",
162 | " total = 0\n",
163 | " with tqdm(total=len(trainloader.dataset)) as progress_bar:\n",
164 | " for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
165 | " inputs, targets = inputs.to(device), targets.to(device)\n",
166 | " \n",
167 | " optimizer.zero_grad() # Clear any stored gradients for new step\n",
168 | " outputs = net(inputs)\n",
169 | " \n",
170 | " loss = loss_fn(outputs, targets) # Calculate loss between prediction and label \n",
171 | " loss.backward() # Backpropagate gradient updates through net based on loss\n",
172 | " optimizer.step() # Update net weights based on gradients\n",
173 | "\n",
174 | " train_loss += loss.item()\n",
175 | " _, predicted = outputs.max(1)\n",
176 | " total += targets.size(0)\n",
177 | " correct += predicted.eq(targets).sum().item()\n",
178 | " acc = (100. * correct / total)\n",
179 | " \n",
180 | " progress_bar.set_postfix(loss=train_loss/(batch_idx+1), accuracy=f'{acc}%')\n",
181 | " progress_bar.update(inputs.size(0))\n",
182 | " \n",
183 | " \n",
184 | "def test(epoch):\n",
185 | " \"\"\"\n",
186 | " Run net in inference mode on test data. \n",
187 | " \"\"\" \n",
188 | " global best_acc\n",
189 | " net.eval()\n",
190 | " test_loss = 0\n",
191 | " correct = 0\n",
192 | " total = 0\n",
193 | " best_acc = 0\n",
194 | " # Ensures the net will not update weights\n",
195 | " with torch.no_grad():\n",
196 | " with tqdm(total=len(testloader.dataset)) as progress_bar:\n",
197 | " for batch_idx, (inputs, targets) in enumerate(testloader):\n",
198 | " inputs, targets = inputs.to(device), targets.to(device)\n",
199 | " outputs = net(inputs)\n",
200 | " loss = loss_fn(outputs, targets)\n",
201 | " \n",
202 | " test_loss += loss.item()\n",
203 | " _, predicted = outputs.max(1)\n",
204 | " total += targets.size(0)\n",
205 | " correct += predicted.eq(targets).sum().item()\n",
206 | " \n",
207 | " acc = (100. * correct / total)\n",
208 | " progress_bar.set_postfix(loss=test_loss/(batch_idx+1), accuracy=f'{acc}%')\n",
209 | " progress_bar.update(inputs.size(0))\n",
210 | " \n",
211 | " # Save best model\n",
212 | " if acc > best_acc:\n",
213 | " print(\"Saving...\")\n",
214 | " save_state(net, acc, epoch)\n",
215 | " best_acc = acc\n",
216 | "\n",
217 | "def save_state(net, acc, epoch):\n",
218 | " \"\"\"\n",
219 | " Save the current net state, accuracy and epoch\n",
220 | " \"\"\"\n",
221 | " state = {\n",
222 | " 'net': net.state_dict(),\n",
223 | " 'acc': acc,\n",
224 | " 'epoch': epoch,\n",
225 | " }\n",
226 | " if not os.path.isdir('checkpoint'):\n",
227 | " os.mkdir('checkpoint')\n",
228 | " torch.save(state, './checkpoint/best_model.pth.tar')"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "metadata": {},
235 | "outputs": [],
236 | "source": [
237 | "test_freq = 5 # Frequency to run model on validation data\n",
238 | "\n",
239 | "for epoch in range(0, 100):\n",
240 | " train(epoch)\n",
241 | " if epoch % test_freq == 0:\n",
242 | " test(epoch)"
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "metadata": {},
248 | "source": [
249 | "## Inference\n",
250 | "Now that we have trained a model, we can use it for inference on new data! With each run, we save the best model weights at:\n",
251 | "`./checkpoint/best_model.pth.tar`"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "metadata": {},
258 | "outputs": [],
259 | "source": [
260 | "def classify_image(x):\n",
261 | " \"\"\"\n",
262 | " Return model classification for an image\n",
263 | " \"\"\"\n",
264 | " outputs = net(x)\n",
265 | " _, predicted = outputs.max(1)\n",
266 | " return predicted"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": null,
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "# Load the checkpoint weights and set the model weights\n",
276 | "cp = torch.load('./checkpoint/best_model.pth.tar')\n",
277 | "net.load_state_dict(cp['net'])\n",
278 | "net.eval()\n",
279 | "\n",
280 | "# Some preprocessing to get the channels in the right order\n",
281 | "sample = trainset[1][0]\n",
282 | "plt.imshow(sample.permute(1,2,0))\n",
283 | "sample = sample.unsqueeze(0)\n",
284 | "\n",
285 | "y = classify_image(sample)[0]\n",
286 | "print(f'Predicted class: {y}')"
287 | ]
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "metadata": {},
292 | "source": [
293 | "(If you are running on CIFAR, you can get the associated class labels at https://www.cs.toronto.edu/~kriz/cifar.html.\n",
294 | "0 = airplane, 1 = automobile, etc.)"
295 | ]
296 | },
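297 |   {
298 |    "cell_type": "markdown",
299 |    "metadata": {},
300 |    "source": [
301 |     "The next cell is a small convenience sketch: it maps the predicted index to a name using the standard CIFAR10 label order."
302 |    ]
303 |   },
304 |   {
305 |    "cell_type": "code",
306 |    "execution_count": null,
307 |    "metadata": {},
308 |    "outputs": [],
309 |    "source": [
310 |     "# CIFAR10 class names, indexed by label\n",
311 |     "cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
312 |     "                   'dog', 'frog', 'horse', 'ship', 'truck']\n",
313 |     "print(f'Predicted class name: {cifar10_classes[int(y)]}')"
314 |    ]
315 |   },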
297 | {
298 | "cell_type": "markdown",
299 | "metadata": {},
300 | "source": [
301 | "## Export (API)\n",
302 | "We want to use our models inside our apps. One way to do this is to wrap our calls in an API. We included an `app.py` that takes the model saved at `models/densenet/checkpoint/best_model.pth.tar` and wraps it in a simple Flask server that receives images and returns the classification. Check out our web API hackpack for more information on how this works! (https://github.com/TreeHacks/hackpack-web-api)\n",
303 | "\n",
304 | "To post an image to your server, you can use Postman to send a request. Here is a screenshot of an example post to get a response classification for some image:\n",
305 | "\n",
306 | "\n",
307 | "\n",
308 | "Congratulations! You have just trained and launched your own ML model."
309 | ]
310 | }
311 | ],
312 | "metadata": {
313 | "kernelspec": {
314 | "display_name": "Python 3",
315 | "language": "python",
316 | "name": "python3"
317 | },
318 | "language_info": {
319 | "codemirror_mode": {
320 | "name": "ipython",
321 | "version": 3
322 | },
323 | "file_extension": ".py",
324 | "mimetype": "text/x-python",
325 | "name": "python",
326 | "nbconvert_exporter": "python",
327 | "pygments_lexer": "ipython3",
328 | "version": "3.6.8"
329 | }
330 | },
331 | "nbformat": 4,
332 | "nbformat_minor": 2
333 | }
334 |
--------------------------------------------------------------------------------
/models/densenet_web/densenet_post.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TreeHacks/hackpack-ml/c5fed6a047d6082c2f10592ef35ac999a4e4e393/models/densenet_web/densenet_post.png
--------------------------------------------------------------------------------
/models/densenet_web/model.py:
--------------------------------------------------------------------------------
1 | """
2 | DenseNet models adapted from:
3 | https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
4 | """
5 | import re
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.utils.model_zoo as model_zoo
10 | from collections import OrderedDict
11 | 
12 | # Checkpoint URLs from torchvision, required when pretrained=True
13 | model_urls = {
14 |     'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
15 |     'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
16 |     'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
17 |     'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
18 | }
19 | 
20 | 
13 | class _DenseLayer(nn.Sequential):
14 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
15 | super(_DenseLayer, self).__init__()
16 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
17 | self.add_module('relu1', nn.ReLU(inplace=True)),
18 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
19 | growth_rate, kernel_size=1, stride=1,
20 | bias=False)),
21 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
22 | self.add_module('relu2', nn.ReLU(inplace=True)),
23 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
24 | kernel_size=3, stride=1, padding=1,
25 | bias=False)),
26 | self.drop_rate = drop_rate
27 |
28 | def forward(self, x):
29 | new_features = super(_DenseLayer, self).forward(x)
30 | if self.drop_rate > 0:
31 | new_features = F.dropout(new_features, p=self.drop_rate,
32 | training=self.training)
33 | return torch.cat([x, new_features], 1)
34 |
35 |
36 | class _DenseBlock(nn.Sequential):
37 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
38 | drop_rate):
39 | super(_DenseBlock, self).__init__()
40 | for i in range(num_layers):
41 | layer = _DenseLayer(num_input_features + i * growth_rate,
42 | growth_rate, bn_size, drop_rate)
43 | self.add_module('denselayer%d' % (i + 1), layer)
44 |
45 |
46 | class _Transition(nn.Sequential):
47 | def __init__(self, num_input_features, num_output_features):
48 | super(_Transition, self).__init__()
49 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
50 | self.add_module('relu', nn.ReLU(inplace=True))
51 | self.add_module('conv',
52 | nn.Conv2d(num_input_features, num_output_features,
53 | kernel_size=1, stride=1, bias=False))
54 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
55 |
56 |
57 | class DenseNet(nn.Module):
58 | r"""Densenet-BC model class, based on
59 | `"Densely Connected Convolutional Networks" `_
60 |
61 | Args:
62 | growth_rate (int) - how many filters to add each layer (`k` in paper)
63 | block_config (list of 4 ints) - how many layers in each pooling block
64 | num_init_features (int) - the number of filters to learn in the first convolution layer
65 | bn_size (int) - multiplicative factor for number of bottle neck layers
66 | (i.e. bn_size * k features in the bottleneck layer)
67 | drop_rate (float) - dropout rate after each dense layer
68 | num_classes (int) - number of classification classes
69 | """
70 |
71 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
72 | num_init_features=64, bn_size=4, drop_rate=0,
73 | num_classes=1000):
74 |
75 | super(DenseNet, self).__init__()
76 |
77 | # First convolution
78 | self.features = nn.Sequential(OrderedDict([
79 | ('conv0',
80 | nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3,
81 | bias=False)),
82 | ('norm0', nn.BatchNorm2d(num_init_features)),
83 | ('relu0', nn.ReLU(inplace=True)),
84 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
85 | ]))
86 |
87 | # Each denseblock
88 | num_features = num_init_features
89 | for i, num_layers in enumerate(block_config):
90 | block = _DenseBlock(num_layers=num_layers,
91 | num_input_features=num_features,
92 | bn_size=bn_size, growth_rate=growth_rate,
93 | drop_rate=drop_rate)
94 | self.features.add_module('denseblock%d' % (i + 1), block)
95 | num_features = num_features + num_layers * growth_rate
96 | if i != len(block_config) - 1:
97 | trans = _Transition(num_input_features=num_features,
98 | num_output_features=num_features // 2)
99 | self.features.add_module('transition%d' % (i + 1), trans)
100 | num_features = num_features // 2
101 |
102 | # Final batch norm
103 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
104 |
105 | # Linear layer
106 | self.classifier = nn.Linear(num_features, num_classes)
107 |
108 | # Official init from torch repo.
109 | for m in self.modules():
110 | if isinstance(m, nn.Conv2d):
111 | nn.init.kaiming_normal_(m.weight)
112 | elif isinstance(m, nn.BatchNorm2d):
113 | nn.init.constant_(m.weight, 1)
114 | nn.init.constant_(m.bias, 0)
115 | elif isinstance(m, nn.Linear):
116 | nn.init.constant_(m.bias, 0)
117 |
118 | def forward(self, x):
119 | features = self.features(x)
120 | out = F.relu(features, inplace=True)
121 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
122 | out = self.classifier(out)
123 | return out
124 |
125 |
126 | def densenet121(pretrained=False, **kwargs):
127 | r"""Densenet-121 model from
128 | `"Densely Connected Convolutional Networks" `_
129 |
130 | Args:
131 | pretrained (bool): If True, returns a model pre-trained on ImageNet
132 | """
133 | model = DenseNet(num_init_features=64, growth_rate=32,
134 | block_config=(6, 12, 24, 16),
135 | **kwargs)
136 | if pretrained:
137 |     # '.'s are no longer allowed in module names, but previous _DenseLayer
138 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
139 | # They are also in the checkpoints in model_urls. This pattern is used
140 | # to find such keys.
141 | pattern = re.compile(
142 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
143 | state_dict = model_zoo.load_url(model_urls['densenet121'])
144 | for key in list(state_dict.keys()):
145 | res = pattern.match(key)
146 | if res:
147 | new_key = res.group(1) + res.group(2)
148 | state_dict[new_key] = state_dict[key]
149 | del state_dict[key]
150 | model.load_state_dict(state_dict)
151 | return model
152 |
153 |
154 | def densenet169(pretrained=False, **kwargs):
155 | r"""Densenet-169 model from
156 | `"Densely Connected Convolutional Networks" `_
157 |
158 | Args:
159 | pretrained (bool): If True, returns a model pre-trained on ImageNet
160 | """
161 | model = DenseNet(num_init_features=64, growth_rate=32,
162 | block_config=(6, 12, 32, 32),
163 | **kwargs)
164 | if pretrained:
165 |     # '.'s are no longer allowed in module names, but previous _DenseLayer
166 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
167 | # They are also in the checkpoints in model_urls. This pattern is used
168 | # to find such keys.
169 | pattern = re.compile(
170 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
171 | state_dict = model_zoo.load_url(model_urls['densenet169'])
172 | for key in list(state_dict.keys()):
173 | res = pattern.match(key)
174 | if res:
175 | new_key = res.group(1) + res.group(2)
176 | state_dict[new_key] = state_dict[key]
177 | del state_dict[key]
178 | model.load_state_dict(state_dict)
179 | return model
180 |
181 |
182 | def densenet201(pretrained=False, **kwargs):
183 | r"""Densenet-201 model from
184 | `"Densely Connected Convolutional Networks" `_
185 |
186 | Args:
187 | pretrained (bool): If True, returns a model pre-trained on ImageNet
188 | """
189 | model = DenseNet(num_init_features=64, growth_rate=32,
190 | block_config=(6, 12, 48, 32),
191 | **kwargs)
192 | if pretrained:
193 |     # '.'s are no longer allowed in module names, but previous _DenseLayer
194 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
195 | # They are also in the checkpoints in model_urls. This pattern is used
196 | # to find such keys.
197 | pattern = re.compile(
198 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
199 | state_dict = model_zoo.load_url(model_urls['densenet201'])
200 | for key in list(state_dict.keys()):
201 | res = pattern.match(key)
202 | if res:
203 | new_key = res.group(1) + res.group(2)
204 | state_dict[new_key] = state_dict[key]
205 | del state_dict[key]
206 | model.load_state_dict(state_dict)
207 | return model
208 |
209 |
210 | def densenet161(pretrained=False, **kwargs):
211 | r"""Densenet-161 model from
212 | `"Densely Connected Convolutional Networks" `_
213 |
214 | Args:
215 | pretrained (bool): If True, returns a model pre-trained on ImageNet
216 | """
217 | model = DenseNet(num_init_features=96, growth_rate=48,
218 | block_config=(6, 12, 36, 24),
219 | **kwargs)
220 | if pretrained:
221 |     # '.'s are no longer allowed in module names, but previous _DenseLayer
222 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
223 | # They are also in the checkpoints in model_urls. This pattern is used
224 | # to find such keys.
225 | pattern = re.compile(
226 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
227 | state_dict = model_zoo.load_url(model_urls['densenet161'])
228 | for key in list(state_dict.keys()):
229 | res = pattern.match(key)
230 | if res:
231 | new_key = res.group(1) + res.group(2)
232 | state_dict[new_key] = state_dict[key]
233 | del state_dict[key]
234 | model.load_state_dict(state_dict)
235 |
236 | return model
237 |
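238 | 
239 | # Example usage (a sketch; num_classes is a DenseNet constructor kwarg):
240 | #   net = densenet121()                # randomly initialized, 1000-way classifier
241 | #   net = densenet121(num_classes=10)  # e.g. for CIFAR10; note pretrained=True
242 | #                                      # expects the default 1000-class head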
--------------------------------------------------------------------------------