├── .gitignore
├── README.md
├── environment.yml
└── models
    ├── data
    │   └── .gitignore
    ├── dataset_nn
    │   └── dataset_neural_nets.ipynb
    └── densenet_web
        ├── LICENSE
        ├── __init__.py
        ├── app.py
        ├── densenet.ipynb
        ├── densenet_post.png
        └── model.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # .idea
10 | .idea/
11 | 
12 | # checkpoint files
13 | models/cv/checkpoint/
14 | 
15 | # ONNX Proto files
16 | models/*/*.proto
17 | 
18 | .DS_Store
19 | models/.DS_Store
20 | 
21 | # Distribution / packaging
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 | 
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 | 
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 | 
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 | .pytest_cache/
60 | 
61 | # Translations
62 | *.mo
63 | *.pot
64 | 
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | 
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 | 
74 | # Scrapy stuff:
75 | .scrapy
76 | 
77 | # Sphinx documentation
78 | docs/_build/
79 | 
80 | # PyBuilder
81 | target/
82 | 
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 | 
86 | # pyenv
87 | .python-version
88 | 
89 | # celery beat schedule file
90 | celerybeat-schedule
91 | 
92 | # SageMath parsed files
93 | *.sage.py
94 | 
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 | 
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 | 
108 | # Rope project settings
109 | .ropeproject
110 | 
111 | # mkdocs documentation
112 | /site
113 | 
114 | # mypy
115 | .mypy_cache/
116 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # hackpack-ml
2 | 
3 | Learn how to make and deploy ML hacks in PyTorch.
4 | 
5 | ### Features
6 | This hackpack aims to make iterating on your models easy while also making them quick to integrate into apps.
7 | 
8 | We include:
9 | 1. An introductory notebook in which we load our own dataset into PyTorch and build a simple neural network for regression on tabular data:
10 | `/models/dataset_nn/dataset_neural_nets.ipynb`
11 | 
12 | 2. An application-focused notebook in which we load and train a DenseNet model for image classification on the SVHN and CIFAR10 datasets,
13 | and then serve the model from a Flask server as our own API:
14 | `/models/densenet_web/densenet.ipynb`
15 | 
16 | ## Getting Started
17 | 
18 | This hackpack assumes some prior understanding of Python. Understanding of Deep Learning techniques is greatly beneficial but not required. If you are interested in learning, [Deep Learning by Goodfellow et al.](https://www.deeplearningbook.org/) is a good starting point.
19 | 
20 | ### Google Colab
21 | 
22 | Both notebooks are also hosted on Google Colab, a free Jupyter notebook environment with GPU/TPU support.
23 | They are accessible at dataset_neural_nets and densenet_web respectively. Feel free to open them in the playground environment or copy them!
24 | 
25 | ### Installing
26 | If you wish to work on the notebooks locally, the dependencies are listed in `environment.yml`.
27 | Create a new Conda environment with:
28 | ```
29 | conda env create --file=environment.yml
30 | ```
31 | ### License
32 | MIT
33 | 
34 | # About HackPacks 🌲
35 | 
36 | HackPacks are built by the [TreeHacks](https://www.treehacks.com/) team to help hackers build great projects at our hackathon, which happens every February at Stanford. We believe that everyone of every skill level can learn to make awesome things, and this is one way we help facilitate hacker culture. We open source our hackpacks (along with our internal tech) so everyone can learn from and use them! Feel free to use these at your own hackathons, workshops, and anything else that promotes building :)
37 | 
38 | If you're interested in attending TreeHacks, you can apply on our [website](https://www.treehacks.com/) during the application period.
39 | 
40 | You can follow us here on [GitHub](https://github.com/treehacks) to see all the open source work we do (we love issues, contributions, and feedback of any kind!), and on [Facebook](https://facebook.com/treehacks), [Twitter](https://twitter.com/hackwithtrees), and [Instagram](https://instagram.com/hackwithtrees) to see general updates from TreeHacks.
41 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: hackpackml
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 |   - defaults
6 | dependencies:
7 |   - appnope=0.1.0=py36_1000
8 |   - attrs=18.2.0=py_0
9 |   - backcall=0.1.0=py_0
10 |   - blas=1.0=mkl
11 |   - ca-certificates=2018.11.29=ha4d7672_0
12 |   - certifi=2018.11.29=py36_1000
13 |   - cffi=1.11.5=py36h342bebf_1001
14 |   - click=7.0=py_0
15 |   - cycler=0.10.0=py_1
16 |   - decorator=4.3.2=py_0
17 |   - entrypoints=0.3=py36_1000
18 |   - flask=1.0.2=py_2
19 |   - freetype=2.9.1=h597ad8a_1005
20 |   - icu=58.2=h0a44026_1000
21 |   - intel-openmp=2019.1=144
22 |   - ipykernel=5.1.0=py36h24bf2e0_1002
23 |   - ipython=7.2.0=py36h24bf2e0_1000
24 |   - ipython_genutils=0.2.0=py_1
25 |   - ipywidgets=7.4.2=py_0
26 |   - itsdangerous=1.1.0=py_0
27 |   - jedi=0.13.2=py36_1000
28 |   - jinja2=2.10=py_1
29 |   - jpeg=9c=h1de35cc_1001
30 |   - jsonschema=3.0.0a3=py36_1000
31 |   - jupyter=1.0.0=py_1
32 |   - jupyter_client=5.2.4=py_1
33 |   - jupyter_console=6.0.0=py_0
34 |   - jupyter_core=4.4.0=py_0
35 |   - kiwisolver=1.0.1=py36h04f5b5a_1002
36 |   - libcxx=7.0.0=h2d50403_2
37 |   - libffi=3.2.1=h0a44026_1005
38 |   - libgfortran=3.0.1=h93005f0_2
39 |   - libpng=1.6.36=ha441bb4_1000
40 |   - libsodium=1.0.16=h1de35cc_1001
41 |   - libtiff=4.0.10=h79f4b77_1001
42 |   - llvm-meta=7.0.0=0
43 |   - markupsafe=1.1.0=py36h1de35cc_1000
44 |   - matplotlib=3.0.2=py36_1002
45 |   - matplotlib-base=3.0.2=py36hf043ca5_1002
46 |   - mistune=0.8.4=py36h1de35cc_1000
47 |   - mkl=2019.1=144
48 |   - mkl_fft=1.0.10=py36h1de35cc_1
49 |   - mkl_random=1.0.2=py36h1702cab_2
50 |   - nbconvert=5.3.1=py_1
51 |   - nbformat=4.4.0=py_1
52 |   - ncurses=6.1=h0a44026_1002
53 |   - ninja=1.9.0=h04f5b5a_0
54 |   - notebook=5.7.4=py36_1000
55 |   - numpy=1.15.4=py36hacdab7b_0
56 |   - numpy-base=1.15.4=py36h6575580_0
57 |   - olefile=0.46=py_0
58 |   - openssl=1.1.1a=h1de35cc_1000
59 |   - pandas=0.24.0=py36h0a44026_0
60 |   - pandoc=2.6=0
61 |   - pandocfilters=1.4.2=py_1
62 |   - parso=0.3.2=py_0
63 |   - pexpect=4.6.0=py36_1000
64 |   - pickleshare=0.7.5=py36_1000
65 |   - pillow=5.4.1=py36hbddbef0_1000
66 |   - pip=19.0.1=py36_0
67 |   - prometheus_client=0.5.0=py_0
68 |   - prompt_toolkit=2.0.8=py_0
69 |   - ptyprocess=0.6.0=py36_1000
70 |   - pycparser=2.19=py_0
71 |   - pygments=2.3.1=py_0
72 |   - pyparsing=2.3.1=py_0
73 |   - pyqt=5.6.0=py36hc26a216_1008
74 |   - pyrsistent=0.14.9=py36h1de35cc_1000
75 |   - python=3.6.8=haf84260_0
76 |   - python-dateutil=2.7.5=py_0
77 |   - pytorch=1.0.0=py3.6_1
78 |   - pytz=2018.9=py_0
79 |   - pyzmq=17.1.2=py36h111632d_1001
80 |   - qt=5.6.2=h822fa55_1013
81 |   - qtconsole=4.4.3=py_0
82 |   - readline=7.0=hcfe32e1_1001
83 |   - send2trash=1.5.0=py_0
84 |   - setuptools=40.7.1=py36_0
85 |   - sip=4.18.1=py36h0a44026_1000
86 |   - six=1.12.0=py36_1000
87 |   - sqlite=3.26.0=h1765d9f_1000
88 |   - terminado=0.8.1=py36_1001
89 |   - testpath=0.4.2=py36_1000
90 |   - tk=8.6.9=ha441bb4_1000
91 |   - torchvision=0.2.1=py_2
92 |   - tornado=5.1.1=py36h1de35cc_1000
93 |   - tqdm=4.30.0=py_0
94 |   - traitlets=4.3.2=py36_1000
95 |   - wcwidth=0.1.7=py_1
96 |   - webencodings=0.5.1=py_1
97 |   - werkzeug=0.14.1=py_0
98 |   - wheel=0.32.3=py36_0
99 |   - widgetsnbextension=3.4.2=py36_1000
100 |   - xz=5.2.4=h1de35cc_1001
101 |   - zeromq=4.2.5=h0a44026_1006
102 |   - zlib=1.2.11=h1de35cc_1004
103 |   - pip:
104 |     - bleach==3.1.0
--------------------------------------------------------------------------------
/models/data/.gitignore:
--------------------------------------------------------------------------------
1 | ./*
2 | !/.gitignore
--------------------------------------------------------------------------------
/models/dataset_nn/dataset_neural_nets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Datasets and Neural Networks\n",
8 |     "This notebook will step through the process of loading an arbitrary dataset in PyTorch, and creating a simple neural network for regression."
9 |    ]
10 |   },
11 |   {
12 |    "cell_type": "markdown",
13 |    "metadata": {},
14 |    "source": [
15 |     "# Datasets\n",
16 |     "We will first work through loading an arbitrary dataset in PyTorch. For this project, we chose the Delve abalone dataset.\n",
\n", 17 | "\n", 18 | "First, download and unzip the dataset from the link above, then unzip `Dataset.data.gz` and move `Dataset.data` into `hackpack-ml/models/data`.\n", 19 | "We are given the following attribute information in the spec:\n", 20 | "```\n", 21 | "Attributes:\n", 22 | " 1 sex u M F I\t# Gender or Infant (I)\n", 23 | " 2 length u (0,Inf]\t# Longest shell measurement (mm)\n", 24 | " 3 diameter u (0,Inf]\t# perpendicular to length (mm)\n", 25 | " 4 height u (0,Inf]\t# with meat in shell (mm)\n", 26 | " 5 whole_weight u (0,Inf]\t# whole abalone (gr)\n", 27 | " 6 shucked_weight u (0,Inf]\t# weight of meat (gr) \n", 28 | " 7 viscera_weight u (0,Inf]\t# gut weight (after bleeding) (gr)\n", 29 | " 8 shell_weight u (0,Inf]\t# after being dried (gr)\n", 30 | " 9 rings u 0..29\t# +1.5 gives the age in years\n", 31 | "```" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import math\n", 41 | "from tqdm import tqdm\n", 42 | "import torch\n", 43 | "import torch.nn as nn\n", 44 | "import torch.optim as optim\n", 45 | "import torch.utils.data as data\n", 46 | "import torch.nn.functional as F\n", 47 | "import pandas as pd\n", 48 | "\n", 49 | "from torch.utils.data import Dataset, DataLoader" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "Pandas is a data manipulation library that works really well with structured data. We can use Pandas DataFrames to load the dataset." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "col_names = ['sex', 'length', 'diameter', 'height', 'whole_weight', \n", 66 | " 'shucked_weight', 'viscera_weight', 'shell_weight', 'rings']\n", 67 | "abalone_df = pd.read_csv('../data/Dataset.data', sep=' ', names=col_names)\n", 68 | "abalone_df.head(n=3)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "We define a subclass of PyTorch Dataset for our Abalone dataset." 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "class AbaloneDataset(data.Dataset):\n", 85 | " \"\"\"Abalone dataset. 
85 |     "    \"\"\"Abalone dataset. Provides quick iteration over rows of data.\"\"\"\n",
86 |     "\n",
87 |     "    def __init__(self, csv):\n",
88 |     "        \"\"\"\n",
89 |     "        Args: csv (string): Path to the Abalone dataset.\n",
90 |     "        \"\"\"\n",
91 |     "        self.features = ['sex', 'length', 'diameter', 'height', 'whole_weight', \n",
92 |     "                         'shucked_weight', 'viscera_weight', 'shell_weight']\n",
93 |     "        self.y = ['rings']\n",
94 |     "        self.abalone_df = pd.read_csv(csv, sep=' ', names=(self.features + self.y))\n",
95 |     "        \n",
96 |     "        # Encode the categorical sex column as integer codes (one-hot would change the feature count)\n",
97 |     "        self.abalone_df['sex'] = self.abalone_df['sex'].astype('category').cat.codes\n",
98 |     "\n",
99 |     "    def __len__(self):\n",
100 |     "        return len(self.abalone_df)\n",
101 |     "\n",
102 |     "    def __getitem__(self, idx):\n",
103 |     "        \"\"\"Return (x,y) pair where x are abalone features and y is age.\"\"\"\n",
104 |     "        features = self.abalone_df.iloc[idx][self.features].values\n",
105 |     "        y = self.abalone_df.iloc[idx][self.y]\n",
106 |     "        return torch.Tensor(features).float(), torch.Tensor(y).float()"
107 |    ]
108 |   },
109 |   {
110 |    "cell_type": "markdown",
111 |    "metadata": {},
112 |    "source": [
113 |     "# Neural Networks\n",
114 |     "\n",
115 |     "The task is to predict the age (number of rings) of abalone from physical measurements. We build a simple neural network with one hidden layer to model the regression."
116 |    ]
117 |   },
118 |   {
119 |    "cell_type": "code",
120 |    "execution_count": null,
121 |    "metadata": {},
122 |    "outputs": [],
123 |    "source": [
124 |     "class Net(nn.Module):\n",
125 |     "\n",
126 |     "    def __init__(self, feature_size):\n",
127 |     "        super(Net, self).__init__()\n",
128 |     "        # feature_size (8) input features, one output (the predicted age)\n",
129 |     "        self.fc1 = nn.Linear(feature_size, 4)\n",
130 |     "        self.fc2 = nn.Linear(4, 1)\n",
131 |     "\n",
132 |     "    def forward(self, x):\n",
133 |     "        x = F.relu(self.fc1(x))\n",
134 |     "        x = self.fc2(x)\n",
135 |     "        return x"
136 |    ]
137 |   },
138 |   {
139 |    "cell_type": "markdown",
140 |    "metadata": {},
141 |    "source": [
142 |     "We instantiate an Abalone dataset instance and create DataLoaders for train and test sets."
143 |    ]
144 |   },
145 |   {
146 |    "cell_type": "code",
147 |    "execution_count": null,
148 |    "metadata": {},
149 |    "outputs": [],
150 |    "source": [
151 |     "dataset = AbaloneDataset('../data/Dataset.data')\n",
152 |     "train_split, test_split = math.floor(len(dataset) * 0.8), math.ceil(len(dataset) * 0.2)\n",
153 |     "\n",
154 |     "trainset = [dataset[i] for i in range(train_split)]\n",
155 |     "testset = [dataset[train_split + j] for j in range(test_split)]\n",
156 |     "batch_sz = len(trainset)  # The dataset is small enough to train in one big batch\n",
157 |     "trainloader = data.DataLoader(trainset, batch_size=batch_sz, shuffle=True, num_workers=4)\n",
158 |     "testloader = data.DataLoader(testset, batch_size=batch_sz, shuffle=False, num_workers=4)"
159 |    ]
160 |   },
161 |   {
162 |    "cell_type": "markdown",
163 |    "metadata": {},
164 |    "source": [
165 |     "Now, we can initialize our network and define train and test functions."
166 |    ]
167 |   },
168 |   {
169 |    "cell_type": "code",
170 |    "execution_count": null,
171 |    "metadata": {},
172 |    "outputs": [],
173 |    "source": [
174 |     "net = Net(len(dataset.features))\n",
175 |     "loss_fn = nn.MSELoss()\n",
176 |     "optimizer = optim.Adam(net.parameters(), lr=0.1)\n",
177 |     "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
178 |     "gpu_ids = [0]  # On Colab, we have access to one GPU. Change this value as you see fit.\n",
179 |     "\n",
180 |     "def train(epoch):\n",
181 |     "    \"\"\"\n",
182 |     "    Trains our net on data from the trainloader for a single epoch\n",
183 |     "    \"\"\"\n",
184 |     "    net.train()\n",
185 |     "    with tqdm(total=len(trainloader.dataset)) as progress_bar:\n",
186 |     "        for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
187 |     "            inputs, targets = inputs.to(device), targets.to(device)\n",
188 |     "            optimizer.zero_grad()  # Clear any stored gradients for new step\n",
189 |     "            outputs = net(inputs.float())\n",
190 |     "            loss = loss_fn(outputs, targets)  # Calculate loss between prediction and label\n",
191 |     "            loss.backward()  # Backpropagate gradient updates through net based on loss\n",
192 |     "            optimizer.step()  # Update net weights based on gradients\n",
193 |     "            progress_bar.set_postfix(loss=loss.item())\n",
194 |     "            progress_bar.update(inputs.size(0))\n",
195 |     "    \n",
196 |     "    \n",
197 |     "def test(epoch):\n",
198 |     "    \"\"\"\n",
199 |     "    Run net in inference mode on test data.\n",
200 |     "    \"\"\"\n",
201 |     "    net.eval()\n",
202 |     "    # Disable gradient tracking; the net will not update weights here\n",
203 |     "    with torch.no_grad():\n",
204 |     "        with tqdm(total=len(testloader.dataset)) as progress_bar:\n",
205 |     "            for batch_idx, (inputs, targets) in enumerate(testloader):\n",
206 |     "                inputs, targets = inputs.to(device).float(), targets.to(device).float()\n",
207 |     "                outputs = net(inputs)\n",
208 |     "                loss = loss_fn(outputs, targets)\n",
209 |     "                progress_bar.set_postfix(testloss=loss.item())\n",
210 |     "                progress_bar.update(inputs.size(0))\n"
211 |    ]
212 |   },
213 |   {
214 |    "cell_type": "markdown",
215 |    "metadata": {},
216 |    "source": [
217 |     "Now that everything is prepared, it's time to train!"
218 |    ]
219 |   },
220 |   {
221 |    "cell_type": "code",
222 |    "execution_count": null,
223 |    "metadata": {},
224 |    "outputs": [],
225 |    "source": [
226 |     "test_freq = 5  # Frequency to run model on validation data\n",
227 |     "\n",
228 |     "for epoch in range(0, 200):\n",
229 |     "    train(epoch)\n",
230 |     "    if epoch % test_freq == 0:\n",
231 |     "        test(epoch)"
232 |    ]
233 |   },
234 |   {
235 |    "cell_type": "markdown",
236 |    "metadata": {},
237 |    "source": [
238 |     "We use the network's eval mode to do a sample prediction to see how well it does."
239 |    ]
240 |   },
241 |   {
242 |    "cell_type": "code",
243 |    "execution_count": null,
244 |    "metadata": {},
245 |    "outputs": [],
246 |    "source": [
247 |     "net.eval()\n",
248 |     "sample = testset[0]\n",
249 |     "predicted_age = net(sample[0])\n",
250 |     "true_age = sample[1]\n",
251 |     "\n",
252 |     "print(f'Input features: {sample[0]}')\n",
253 |     "print(f'Predicted age: {predicted_age.item()}, True age: {true_age[0]}')"
254 |    ]
255 |   },
256 |   {
257 |    "cell_type": "markdown",
258 |    "metadata": {},
259 |    "source": [
260 |     "Congratulations! You now know how to load your own datasets into PyTorch and run models on them. For a computer vision example, check out the DenseNet notebook. Happy hacking!"
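One caveat on the notebook above: it builds the train/test sets by slicing the dataset in index order, so any ordering in `Dataset.data` leaks into the split. A minimal sketch (not part of the original notebook) of a shuffled split using `torch.utils.data.random_split`, which is available in the PyTorch version pinned in `environment.yml`:

```
import torch
from torch.utils.data import random_split

dataset = AbaloneDataset('../data/Dataset.data')  # class defined in the notebook above
n_train = int(len(dataset) * 0.8)
# random_split shuffles indices before partitioning, unlike in-order slicing
train_ds, test_ds = random_split(dataset, [n_train, len(dataset) - n_train])
```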
261 |    ]
262 |   }
263 |  ],
264 |  "metadata": {
265 |   "kernelspec": {
266 |    "display_name": "Python 3",
267 |    "language": "python",
268 |    "name": "python3"
269 |   },
270 |   "language_info": {
271 |    "codemirror_mode": {
272 |     "name": "ipython",
273 |     "version": 3
274 |    },
275 |    "file_extension": ".py",
276 |    "mimetype": "text/x-python",
277 |    "name": "python",
278 |    "nbconvert_exporter": "python",
279 |    "pygments_lexer": "ipython3",
280 |    "version": "3.6.8"
281 |   }
282 |  },
283 |  "nbformat": 4,
284 |  "nbformat_minor": 2
285 | }
286 | 
--------------------------------------------------------------------------------
/models/densenet_web/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) Soumith Chintala 2016,
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | * Redistributions of source code must retain the above copyright notice, this
10 |   list of conditions and the following disclaimer.
11 | 
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 |   this list of conditions and the following disclaimer in the documentation
14 |   and/or other materials provided with the distribution.
15 | 
16 | * Neither the name of the copyright holder nor the names of its
17 |   contributors may be used to endorse or promote products derived from
18 |   this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 
--------------------------------------------------------------------------------
/models/densenet_web/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import densenet121
--------------------------------------------------------------------------------
/models/densenet_web/app.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple Flask server to return predictions from the trained densenet_web model
3 | """
4 | 
5 | from flask import Flask, request, redirect, flash
6 | import torch
7 | from torchvision.transforms.functional import to_tensor
8 | from PIL import Image
9 | 
10 | from .model import densenet121
11 | 
12 | app = Flask(__name__)
13 | cp = torch.load('models/densenet_web/checkpoint/best_model.pth.tar')
14 | net = densenet121()
15 | net.load_state_dict(cp['net'])
16 | net.eval()
17 | 
18 | 
19 | def classify_image(x):
20 |     """
21 |     :param x: image
22 |     :return: class prediction from densenet_web model
23 |     """
24 |     outputs = net(x)
25 |     _, predicted = outputs.max(1)
26 |     return predicted
27 | 
28 | 
29 | @app.route('/predict', methods=['POST'])
30 | def predict():
31 |     """
32 |     :return: prediction for image from POST request.
33 |     """
34 |     if request.method == 'POST':
35 |         if 'img' not in request.files:
36 |             flash('No image in request')
37 |             return redirect(request.url)
38 | 
39 |         # preprocess image
40 |         img = request.files['img']
41 |         img = Image.open(img).convert("RGB")
42 |         x = to_tensor(img)
43 | 
44 |         prediction = classify_image(x.unsqueeze(0))
45 |         return str(int(prediction[0].data))
--------------------------------------------------------------------------------
/models/densenet_web/densenet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# DenseNet\n",
8 |     "In this notebook, we train a DenseNet classifier on the SVHN and CIFAR10 datasets, and serve it behind a web API."
9 |    ]
10 |   },
11 |   {
12 |    "cell_type": "code",
13 |    "execution_count": null,
14 |    "metadata": {},
15 |    "outputs": [],
16 |    "source": [
17 |     "\"\"\"\n",
18 |     "Script adapted from: https://github.com/kuangliu/pytorch-cifar\n",
19 |     "\"\"\"\n",
20 |     "import torch\n",
21 |     "import torch.nn as nn\n",
22 |     "import torch.optim as optim\n",
23 |     "import torch.utils.data as data\n",
24 |     "from torch.autograd import Variable\n",
25 |     "from torch.optim import lr_scheduler\n",
26 |     "import torchvision\n",
27 |     "import torchvision.transforms as transforms\n",
28 |     "from torchvision import datasets, models, transforms\n",
29 |     "import sys\n",
30 |     "import os\n",
31 |     "from tqdm import tqdm\n",
32 |     "import matplotlib.pyplot as plt\n",
33 |     "\n",
34 |     "from model import densenet121\n"
35 |    ]
36 |   },
37 |   {
38 |    "cell_type": "markdown",
39 |    "metadata": {},
40 |    "source": [
41 |     "## Dataloader\n",
42 |     "\n",
43 |     "Here, we load either the SVHN or CIFAR10 dataset, both of which are provided through torchvision. If you wish to use your own, check out the datasets notebook to see how to create your own dataset class.\n",
44 |     "\n",
45 |     "Note that we are treating the test set as a validation set."
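Since the notebook reuses the test set for validation, here is a minimal sketch (not part of the notebook; it assumes the torchvision imports and `transform_train` defined above) of carving a proper validation split out of the 50,000 CIFAR10 training images instead:

```
from torch.utils.data import random_split

full_train = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)
val_ct = 5000  # hypothetical split size; anything that leaves enough training data works
train_subset, val_subset = random_split(full_train, [len(full_train) - val_ct, val_ct])
```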
46 |    ]
47 |   },
48 |   {
49 |    "cell_type": "code",
50 |    "execution_count": null,
51 |    "metadata": {},
52 |    "outputs": [],
53 |    "source": [
54 |     "# Transform from PIL image format to tensor format\n",
55 |     "transform_train = transforms.Compose([\n",
56 |     "    # You can add more data augmentation techniques in series:\n",
57 |     "    # https://pytorch.org/docs/stable/torchvision/transforms.html\n",
58 |     "    transforms.ToTensor()\n",
59 |     "])\n",
60 |     "\n",
61 |     "transform_test = transforms.Compose([\n",
62 |     "    transforms.ToTensor()\n",
63 |     "])\n",
64 |     "\n",
65 |     "# CIFAR10 Dataset: https://www.cs.toronto.edu/~kriz/cifar.html\n",
66 |     "trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train)\n",
67 |     "testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)\n",
68 |     "\n",
69 |     "# SVHN Dataset: http://ufldl.stanford.edu/housenumbers/\n",
70 |     "# trainset = torchvision.datasets.SVHN(root='../data', split='train', transform=transform_train, download=True)\n",
71 |     "# testset = torchvision.datasets.SVHN(root='../data', split='test', transform=transform_test, download=True)"
72 |    ]
73 |   },
74 |   {
75 |    "cell_type": "markdown",
76 |    "metadata": {},
77 |    "source": [
78 |     "If you are making a proof-of-concept application, you can choose to overfit on a small data subset for quick training."
79 |    ]
80 |   },
81 |   {
82 |    "cell_type": "code",
83 |    "execution_count": null,
84 |    "metadata": {},
85 |    "outputs": [],
86 |    "source": [
87 |     "train_ct = 128  # Size of train data\n",
88 |     "test_ct = 32  # Size of test data\n",
89 |     "batch_sz = 32  # Size of batch for one gradient step (bigger batches take more memory but are faster)\n",
90 |     "num_workers = 4  # A common heuristic is 4 workers per GPU\n",
91 |     "\n",
92 |     "if train_ct:\n",
93 |     "    trainset = data.dataset.Subset(trainset, range(train_ct))\n",
94 |     "\n",
95 |     "if test_ct:\n",
96 |     "    testset = data.dataset.Subset(testset, range(test_ct))\n",
97 |     "\n",
98 |     "trainloader = data.DataLoader(trainset, batch_size=batch_sz, shuffle=True, num_workers=num_workers)\n",
99 |     "testloader = data.DataLoader(testset, batch_size=batch_sz, shuffle=False, num_workers=num_workers)"
100 |    ]
101 |   },
102 |   {
103 |    "cell_type": "markdown",
104 |    "metadata": {},
105 |    "source": [
106 |     "## Training\n",
107 |     "\n",
108 |     "First, we configure the model for training, and define our loss and optimizer."
109 |    ]
110 |   },
111 |   {
112 |    "cell_type": "code",
113 |    "execution_count": null,
114 |    "metadata": {},
115 |    "outputs": [],
116 |    "source": [
117 |     "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
118 |     "gpu_ids = [0]  # On Colab, we have access to one GPU. Change this value as you see fit.\n",
119 |     "\n",
120 |     "net = densenet121()\n",
121 |     "net = net.to(device)\n",
122 |     "\n",
123 |     "if device == 'cuda':\n",
124 |     "    net = torch.nn.DataParallel(net, gpu_ids)\n",
125 |     "    \n",
126 |     "resume = False  # Resume training from a saved checkpoint\n",
127 |     "best_acc = 0  # Best test accuracy seen so far; test() below updates this\n",
128 |     "if resume:\n",
129 |     "    print('Resuming from checkpoint at ./checkpoint/best_model.pth.tar')\n",
130 |     "    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'\n",
131 |     "    checkpoint = torch.load('./checkpoint/best_model.pth.tar')\n",
132 |     "    net.load_state_dict(checkpoint['net'])\n",
133 |     "    best_acc = checkpoint['acc']  # Key written by save_state below\n",
134 |     "    start_epoch = checkpoint['epoch']\n",
135 |     "    \n",
136 |     "    \n",
137 |     "# Loss function: https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss\n",
138 |     "loss_fn = nn.CrossEntropyLoss()\n",
139 |     "optimizer = optim.Adam(net.parameters(), lr=0.1)"
140 |    ]
141 |   },
142 |   {
143 |    "cell_type": "markdown",
144 |    "metadata": {},
145 |    "source": [
146 |     "We define `train`, which performs a forward/back propagation pass on our dataset per epoch. Similarly, `test` performs evaluation on the test set."
147 |    ]
148 |   },
149 |   {
150 |    "cell_type": "code",
151 |    "execution_count": null,
152 |    "metadata": {},
153 |    "outputs": [],
154 |    "source": [
155 |     "def train(epoch):\n",
156 |     "    \"\"\"\n",
157 |     "    Trains our net on data from the trainloader for a single epoch\n",
158 |     "    \"\"\"\n",
159 |     "    net.train()\n",
160 |     "    train_loss = 0\n",
161 |     "    correct = 0\n",
162 |     "    total = 0\n",
163 |     "    with tqdm(total=len(trainloader.dataset)) as progress_bar:\n",
164 |     "        for batch_idx, (inputs, targets) in enumerate(trainloader):\n",
165 |     "            inputs, targets = inputs.to(device), targets.to(device)\n",
166 |     "            \n",
167 |     "            optimizer.zero_grad()  # Clear any stored gradients for new step\n",
168 |     "            outputs = net(inputs)\n",
169 |     "            \n",
170 |     "            loss = loss_fn(outputs, targets)  # Calculate loss between prediction and label\n",
171 |     "            loss.backward()  # Backpropagate gradient updates through net based on loss\n",
172 |     "            optimizer.step()  # Update net weights based on gradients\n",
173 |     "\n",
174 |     "            train_loss += loss.item()\n",
175 |     "            _, predicted = outputs.max(1)\n",
176 |     "            total += targets.size(0)\n",
177 |     "            correct += predicted.eq(targets).sum().item()\n",
178 |     "            acc = (100. * correct / total)\n",
179 |     "            \n",
180 |     "            progress_bar.set_postfix(loss=train_loss/(batch_idx+1), accuracy=f'{acc}%')\n",
181 |     "            progress_bar.update(inputs.size(0))\n",
182 |     "    \n",
183 |     "    \n",
184 |     "def test(epoch):\n",
185 |     "    \"\"\"\n",
186 |     "    Run net in inference mode on test data.\n",
187 |     "    \"\"\"\n",
188 |     "    global best_acc\n",
189 |     "    net.eval()\n",
190 |     "    test_loss = 0\n",
191 |     "    correct = 0\n",
192 |     "    total = 0\n",
193 |     "    # best_acc is initialized once in the setup cell so it persists across calls\n",
194 |     "    # Disable gradient tracking; the net will not update weights here\n",
195 |     "    with torch.no_grad():\n",
196 |     "        with tqdm(total=len(testloader.dataset)) as progress_bar:\n",
197 |     "            for batch_idx, (inputs, targets) in enumerate(testloader):\n",
198 |     "                inputs, targets = inputs.to(device), targets.to(device)\n",
199 |     "                outputs = net(inputs)\n",
200 |     "                loss = loss_fn(outputs, targets)\n",
201 |     "                \n",
202 |     "                test_loss += loss.item()\n",
203 |     "                _, predicted = outputs.max(1)\n",
204 |     "                total += targets.size(0)\n",
205 |     "                correct += predicted.eq(targets).sum().item()\n",
206 |     "                \n",
207 |     "                acc = (100. * correct / total)\n",
208 |     "                progress_bar.set_postfix(loss=test_loss/(batch_idx+1), accuracy=f'{acc}%')\n",
209 |     "                progress_bar.update(inputs.size(0))\n",
210 |     "    \n",
211 |     "    # Save best model\n",
212 |     "    if acc > best_acc:\n",
213 |     "        print(\"Saving...\")\n",
214 |     "        save_state(net, acc, epoch)\n",
215 |     "        best_acc = acc\n",
216 |     "\n",
217 |     "def save_state(net, acc, epoch):\n",
218 |     "    \"\"\"\n",
219 |     "    Save the current net state, accuracy and epoch\n",
220 |     "    \"\"\"\n",
221 |     "    state = {\n",
222 |     "        'net': net.state_dict(),\n",
223 |     "        'acc': acc,\n",
224 |     "        'epoch': epoch,\n",
225 |     "    }\n",
226 |     "    if not os.path.isdir('checkpoint'):\n",
227 |     "        os.mkdir('checkpoint')\n",
228 |     "    torch.save(state, './checkpoint/best_model.pth.tar')"
229 |    ]
230 |   },
231 |   {
232 |    "cell_type": "code",
233 |    "execution_count": null,
234 |    "metadata": {},
235 |    "outputs": [],
236 |    "source": [
237 |     "test_freq = 5  # Frequency to run model on validation data\n",
238 |     "\n",
239 |     "for epoch in range(0, 100):\n",
240 |     "    train(epoch)\n",
241 |     "    if epoch % test_freq == 0:\n",
242 |     "        test(epoch)"
243 |    ]
244 |   },
245 |   {
246 |    "cell_type": "markdown",
247 |    "metadata": {},
248 |    "source": [
249 |     "## Inference\n",
250 |     "Now that we have trained a model, we can use it for inference on new data! With each run, we save the best model weights at:\n",
251 |     "`./checkpoint/best_model.pth.tar`"
252 |    ]
253 |   },
254 |   {
255 |    "cell_type": "code",
256 |    "execution_count": null,
257 |    "metadata": {},
258 |    "outputs": [],
259 |    "source": [
260 |     "def classify_image(x):\n",
261 |     "    \"\"\"\n",
262 |     "    Return model classification for an image\n",
263 |     "    \"\"\"\n",
264 |     "    outputs = net(x)\n",
265 |     "    _, predicted = outputs.max(1)\n",
266 |     "    return predicted"
267 |    ]
268 |   },
269 |   {
270 |    "cell_type": "code",
271 |    "execution_count": null,
272 |    "metadata": {},
273 |    "outputs": [],
274 |    "source": [
275 |     "# Load the checkpoint weights and set the model weights\n",
276 |     "cp = torch.load('./checkpoint/best_model.pth.tar')\n",
277 |     "net.load_state_dict(cp['net'])\n",
278 |     "net.eval()\n",
279 |     "\n",
280 |     "# Permute from (C,H,W) to (H,W,C) for plotting, then add a batch dimension for the model\n",
281 |     "sample = trainset[1][0]\n",
282 |     "plt.imshow(sample.permute(1,2,0))\n",
283 |     "sample = sample.unsqueeze(0)\n",
284 |     "\n",
285 |     "y = classify_image(sample)[0]\n",
286 |     "print(f'Predicted class: {y}')"
287 |    ]
288 |   },
289 |   {
290 |    "cell_type": "markdown",
291 |    "metadata": {},
292 |    "source": [
293 |     "(If you are running on CIFAR, you can get the associated class labels at https://www.cs.toronto.edu/~kriz/cifar.html.\n",
294 |     "0 = airplane, 1 = automobile, etc.)"
295 |    ]
296 |   },
297 |   {
298 |    "cell_type": "markdown",
299 |    "metadata": {},
300 |    "source": [
301 |     "## Export (API)\n",
302 |     "We want to use our models inside our apps. One way to do this is to wrap our calls in an API. We included an `app.py` that takes the model saved at `models/densenet_web/checkpoint/best_model.pth.tar` and wraps it in a simple Flask server that receives images and returns the classification. Check out our web API hackpack for more information on how this works! (https://github.com/TreeHacks/hackpack-web-api)\n",
303 |     "\n",
304 |     "To post an image to your server, you can use Postman to send a request. Here is a screenshot of an example post to get a response classification for some image:\n",
305 |     "\n",
306 |     "![Post example](densenet_post.png)\n",
307 |     "\n",
308 |     "Congratulations! You have just trained and launched your own ML model."
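If you would rather exercise the endpoint from code than from Postman, here is a minimal sketch using the `requests` library. Assumptions: the server is running locally on Flask's default port (5000), and `some_image.png` is a hypothetical image file on disk; the `img` field name comes from `app.py` above.

```
import requests

with open('some_image.png', 'rb') as f:
    resp = requests.post('http://localhost:5000/predict', files={'img': f})
print('Predicted class index:', resp.text)  # app.py returns the class index as a string
```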
309 |    ]
310 |   }
311 |  ],
312 |  "metadata": {
313 |   "kernelspec": {
314 |    "display_name": "Python 3",
315 |    "language": "python",
316 |    "name": "python3"
317 |   },
318 |   "language_info": {
319 |    "codemirror_mode": {
320 |     "name": "ipython",
321 |     "version": 3
322 |    },
323 |    "file_extension": ".py",
324 |    "mimetype": "text/x-python",
325 |    "name": "python",
326 |    "nbconvert_exporter": "python",
327 |    "pygments_lexer": "ipython3",
328 |    "version": "3.6.8"
329 |   }
330 |  },
331 |  "nbformat": 4,
332 |  "nbformat_minor": 2
333 | }
334 | 
--------------------------------------------------------------------------------
/models/densenet_web/densenet_post.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TreeHacks/hackpack-ml/c5fed6a047d6082c2f10592ef35ac999a4e4e393/models/densenet_web/densenet_post.png
--------------------------------------------------------------------------------
/models/densenet_web/model.py:
--------------------------------------------------------------------------------
1 | """
2 | DenseNet models adapted from:
3 | https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
4 | """
5 | import re
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.utils.model_zoo as model_zoo
10 | from collections import OrderedDict
11 | # Pretrained checkpoint URLs as published by torchvision (referenced by the pretrained branches below)
12 | model_urls = {'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth'}
13 | class _DenseLayer(nn.Sequential):
14 |     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
15 |         super(_DenseLayer, self).__init__()
16 |         self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
17 |         self.add_module('relu1', nn.ReLU(inplace=True)),
18 |         self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
19 |                                            growth_rate, kernel_size=1, stride=1,
20 |                                            bias=False)),
21 |         self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
22 |         self.add_module('relu2', nn.ReLU(inplace=True)),
23 |         self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
24 |                                            kernel_size=3, stride=1, padding=1,
25 |                                            bias=False)),
26 |         self.drop_rate = drop_rate
27 | 
28 |     def forward(self, x):
29 |         new_features = super(_DenseLayer, self).forward(x)
30 |         if self.drop_rate > 0:
31 |             new_features = F.dropout(new_features, p=self.drop_rate,
32 |                                      training=self.training)
33 |         return torch.cat([x, new_features], 1)
34 | 
35 | 
36 | class _DenseBlock(nn.Sequential):
37 |     def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
38 |                  drop_rate):
39 |         super(_DenseBlock, self).__init__()
40 |         for i in range(num_layers):
41 |             layer = _DenseLayer(num_input_features + i * growth_rate,
42 |                                 growth_rate, bn_size, drop_rate)
43 |             self.add_module('denselayer%d' % (i + 1), layer)
44 | 
45 | 
46 | class _Transition(nn.Sequential):
47 |     def __init__(self, num_input_features, num_output_features):
48 |         super(_Transition, self).__init__()
49 |         self.add_module('norm', nn.BatchNorm2d(num_input_features))
50 |         self.add_module('relu', nn.ReLU(inplace=True))
51 |         self.add_module('conv',
52 |                         nn.Conv2d(num_input_features, num_output_features,
53 |                                   kernel_size=1, stride=1, bias=False))
54 |         self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
55 | 
56 | 
57 | class DenseNet(nn.Module):
58 |     r"""Densenet-BC model class, based on
59 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
60 | 
61 |     Args:
62 |         growth_rate (int) - how many filters to add each layer (`k` in paper)
63 |         block_config (list of 4 ints) - how many layers in each pooling block
64 |         num_init_features (int) - the number of filters to learn in the first convolution layer
65 |         bn_size (int) - multiplicative factor for number of bottle neck layers
66 |             (i.e. bn_size * k features in the bottleneck layer)
67 |         drop_rate (float) - dropout rate after each dense layer
68 |         num_classes (int) - number of classification classes
69 |     """
70 | 
71 |     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
72 |                  num_init_features=64, bn_size=4, drop_rate=0,
73 |                  num_classes=1000):
74 | 
75 |         super(DenseNet, self).__init__()
76 | 
77 |         # First convolution
78 |         self.features = nn.Sequential(OrderedDict([
79 |             ('conv0',
80 |              nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3,
81 |                        bias=False)),
82 |             ('norm0', nn.BatchNorm2d(num_init_features)),
83 |             ('relu0', nn.ReLU(inplace=True)),
84 |             ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
85 |         ]))
86 | 
87 |         # Each denseblock
88 |         num_features = num_init_features
89 |         for i, num_layers in enumerate(block_config):
90 |             block = _DenseBlock(num_layers=num_layers,
91 |                                 num_input_features=num_features,
92 |                                 bn_size=bn_size, growth_rate=growth_rate,
93 |                                 drop_rate=drop_rate)
94 |             self.features.add_module('denseblock%d' % (i + 1), block)
95 |             num_features = num_features + num_layers * growth_rate
96 |             if i != len(block_config) - 1:
97 |                 trans = _Transition(num_input_features=num_features,
98 |                                     num_output_features=num_features // 2)
99 |                 self.features.add_module('transition%d' % (i + 1), trans)
100 |                 num_features = num_features // 2
101 | 
102 |         # Final batch norm
103 |         self.features.add_module('norm5', nn.BatchNorm2d(num_features))
104 | 
105 |         # Linear layer
106 |         self.classifier = nn.Linear(num_features, num_classes)
107 | 
108 |         # Official init from torch repo.
109 |         for m in self.modules():
110 |             if isinstance(m, nn.Conv2d):
111 |                 nn.init.kaiming_normal_(m.weight)
112 |             elif isinstance(m, nn.BatchNorm2d):
113 |                 nn.init.constant_(m.weight, 1)
114 |                 nn.init.constant_(m.bias, 0)
115 |             elif isinstance(m, nn.Linear):
116 |                 nn.init.constant_(m.bias, 0)
117 | 
118 |     def forward(self, x):
119 |         features = self.features(x)
120 |         out = F.relu(features, inplace=True)
121 |         out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
122 |         out = self.classifier(out)
123 |         return out
124 | 
125 | 
126 | def densenet121(pretrained=False, **kwargs):
127 |     r"""Densenet-121 model from
128 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
129 | 
130 |     Args:
131 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
132 |     """
133 |     model = DenseNet(num_init_features=64, growth_rate=32,
134 |                      block_config=(6, 12, 24, 16),
135 |                      **kwargs)
136 |     if pretrained:
137 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
138 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
139 |         # They are also in the checkpoints in model_urls. This pattern is used
140 |         # to find such keys.
141 |         pattern = re.compile(
142 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
143 |         state_dict = model_zoo.load_url(model_urls['densenet121'])
144 |         for key in list(state_dict.keys()):
145 |             res = pattern.match(key)
146 |             if res:
147 |                 new_key = res.group(1) + res.group(2)
148 |                 state_dict[new_key] = state_dict[key]
149 |                 del state_dict[key]
150 |         model.load_state_dict(state_dict)
151 |     return model
152 | 
153 | 
154 | def densenet169(pretrained=False, **kwargs):
155 |     r"""Densenet-169 model from
156 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
157 | 
158 |     Args:
159 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
160 |     """
161 |     model = DenseNet(num_init_features=64, growth_rate=32,
162 |                      block_config=(6, 12, 32, 32),
163 |                      **kwargs)
164 |     if pretrained:
165 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
166 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
167 |         # They are also in the checkpoints in model_urls. This pattern is used
168 |         # to find such keys.
169 |         pattern = re.compile(
170 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
171 |         state_dict = model_zoo.load_url(model_urls['densenet169'])
172 |         for key in list(state_dict.keys()):
173 |             res = pattern.match(key)
174 |             if res:
175 |                 new_key = res.group(1) + res.group(2)
176 |                 state_dict[new_key] = state_dict[key]
177 |                 del state_dict[key]
178 |         model.load_state_dict(state_dict)
179 |     return model
180 | 
181 | 
182 | def densenet201(pretrained=False, **kwargs):
183 |     r"""Densenet-201 model from
184 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
185 | 
186 |     Args:
187 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
188 |     """
189 |     model = DenseNet(num_init_features=64, growth_rate=32,
190 |                      block_config=(6, 12, 48, 32),
191 |                      **kwargs)
192 |     if pretrained:
193 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
194 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
195 |         # They are also in the checkpoints in model_urls. This pattern is used
196 |         # to find such keys.
197 |         pattern = re.compile(
198 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
199 |         state_dict = model_zoo.load_url(model_urls['densenet201'])
200 |         for key in list(state_dict.keys()):
201 |             res = pattern.match(key)
202 |             if res:
203 |                 new_key = res.group(1) + res.group(2)
204 |                 state_dict[new_key] = state_dict[key]
205 |                 del state_dict[key]
206 |         model.load_state_dict(state_dict)
207 |     return model
208 | 
209 | 
210 | def densenet161(pretrained=False, **kwargs):
211 |     r"""Densenet-161 model from
212 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
213 | 
214 |     Args:
215 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
216 |     """
217 |     model = DenseNet(num_init_features=96, growth_rate=48,
218 |                      block_config=(6, 12, 36, 24),
219 |                      **kwargs)
220 |     if pretrained:
221 |         # '.'s are no longer allowed in module names, but previous _DenseLayer
222 |         # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
223 |         # They are also in the checkpoints in model_urls. This pattern is used
224 |         # to find such keys.
225 |         pattern = re.compile(
226 |             r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
227 |         state_dict = model_zoo.load_url(model_urls['densenet161'])
228 |         for key in list(state_dict.keys()):
229 |             res = pattern.match(key)
230 |             if res:
231 |                 new_key = res.group(1) + res.group(2)
232 |                 state_dict[new_key] = state_dict[key]
233 |                 del state_dict[key]
234 |         model.load_state_dict(state_dict)
235 | 
236 |     return model
237 | 
--------------------------------------------------------------------------------
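As a quick sanity check of `model.py`, note that the factory functions forward `**kwargs` to `DenseNet`, so you can size the classifier head for the 10-class datasets used in the notebook. A minimal sketch (not part of the repo):

```
import torch
from model import densenet121

net = densenet121(num_classes=10)  # 10 classes matches CIFAR10/SVHN; the default is 1000
x = torch.randn(2, 3, 32, 32)      # a dummy batch of CIFAR-sized images
logits = net(x)
print(logits.shape)                # expected: torch.Size([2, 10])
```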