├── workspace ├── models │ ├── __init__.py │ ├── .gitignore │ ├── layers.py │ └── wideresnet.py ├── cifar10_fastai_adamw.ipynb └── cifar10_fastai_dawnbench.ipynb ├── run_container.sh ├── data.py ├── README.md └── Dockerfile /workspace/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /workspace/models/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.tar 3 | checkpoint* 4 | log* 5 | wgts/ 6 | -------------------------------------------------------------------------------- /run_container.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | IMAGE_NAME=$1 3 | 4 | if [[ $# -eq 0 ]] ; then 5 | echo 'ERROR: No argument passed for image name.' 6 | exit 0 7 | fi 8 | 9 | CONTAINER="docker run -id --runtime=nvidia -e NVIDIA_DRIVER_CAPABILITIES=compute,utility -e NVIDIA_VISIBLE_DEVICES=all \ 10 | --ipc=host --net=host -v $PWD/workspace/:/root/workspace $IMAGE_NAME" 11 | echo 'Starting container with commmand: '$CONTAINER 12 | eval $CONTAINER 13 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os, shutil, re 2 | from glob import glob 3 | from subprocess import run 4 | 5 | path = 'data/' 6 | for ds in ['train', 'test']: 7 | paths = glob(f'{path}cifar/{ds}/*') 8 | for cls in ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'): 9 | run(f'mkdir -p {path}cifar10/{ds}/{cls}'.split()) 10 | for fpath in paths: 11 | cls = re.search('_(.*)\.png$', fpath).group(1) 12 | fname = re.search('\w*.png$', fpath).group(0) 13 | shutil.copy(fpath, f'{path}cifar10/{ds}/{cls}/{fname}') 14 | -------------------------------------------------------------------------------- 
/workspace/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class AdaptiveConcatPool2d(nn.Module): 5 | def __init__(self, sz=None): 6 | super().__init__() 7 | sz = sz or (1,1) 8 | self.ap = nn.AdaptiveAvgPool2d(sz) 9 | self.mp = nn.AdaptiveMaxPool2d(sz) 10 | def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) 11 | 12 | class Lambda(nn.Module): 13 | def __init__(self, f): super().__init__(); self.f=f 14 | def forward(self, x): return self.f(x) 15 | 16 | class Flatten(nn.Module): 17 | def __init__(self): super().__init__() 18 | def forward(self, x): return x.view(x.size(0), -1) 19 | 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## About 2 | 3 | This repository contains the [fastai](http://www.fast.ai) [DAWNbench](https://dawn.cs.stanford.edu/benchmark/#cifar10-train-time) result adapted to training on 1080ti. It is missing many of the optimizations that allowed the fastai team to achieve 94% accuracy in 2m 54s (no fp16, no data prefetching, etc) on an AWS p3.16xlarge instance with 8 V100 GPUs. On my box with a single 1080ti I am able to train to 94% accuracy (with TTA) in 13 minutes 30 seconds. 4 | 5 | The second notebook adapts recent work by fastai and trains with AdamW and the 1 cycle policy cutting down the number of required epochs to 18. You can read more about this approach on the [fastai blog](http://www.fast.ai/2018/07/02/adam-weight-decay/) or in the [official repositiory](https://github.com/sgugger/Adam-experiments). 6 | 7 | You will need to have [docker](https://docs.docker.com/install/linux/docker-ce/ubuntu/) and [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) installed in order to run this. 
8 | 
9 | Once you start the docker container, open https://localhost:8888 in your browser and enter `jupyter` as the password (the certificate is self-signed, so expect a browser warning). Open a notebook and run all cells.
10 | 
11 | For TensorFlow code, please check out the [tensorflow branch](https://github.com/radekosmulski/cifar10_docker/tree/tensorflow). The implementation there is very minimal but might still be useful as a starting point for experimenting.
12 | 
13 | ## Instructions for building and running the container
14 | 1. `cd` into the cloned repo
15 | 2. `docker build -t cifar .`
16 | 3. `./run_container.sh cifar`
17 | 
18 | 
19 | *SIDENOTE*: You might need to run the commands with sudo. I prefer to do the following:
20 | ```
21 | sudo groupadd docker
22 | sudo usermod -aG docker $USER
23 | ```
24 | (log out and back in for the group change to take effect; note that membership in the docker group is effectively root access, so this is not more secure than running docker with sudo)
25 | 
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-base
2 | LABEL maintainer="Radek Osmulski "
3 | 
4 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
5 | ENV PATH /opt/conda/bin:$PATH
6 | 
7 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \
8 |     libglib2.0-0 libxext6 libsm6 libxrender1 git
9 | 
10 | RUN wget --quiet https://repo.continuum.io/archive/Anaconda3-5.2.0-Linux-x86_64.sh -O ~/anaconda.sh && \
11 |     /bin/bash ~/anaconda.sh -b -p /opt/conda && \
12 |     rm ~/anaconda.sh && \
13 |     ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
14 |     echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
15 |     echo "conda activate fastai" >> ~/.bashrc
16 | 
17 | WORKDIR /root
18 | # NOTE(review): this clones fastai master unpinned, but the notebooks import `fastai.conv_learner`; consider checking out a commit that still provides it.
19 | RUN git clone https://github.com/fastai/fastai.git && cd fastai && conda env create
20 | 
21 | # configure jupyter
22 | RUN jupyter notebook --generate-config
23 | 
24 | # This will set the password on the notebook to 'jupyter'. To generate a hash corresponding to a password
25 | # of your choice, run the code below inside a Python interpreter
26 | # from notebook.auth import passwd; print(passwd())
27 | ARG jupass=sha1:85ff16c0f1a9:c296112bf7b82121f5ec73ef4c1b9305b9e538af
28 | 
29 | RUN echo "c.NotebookApp.password = u'"$jupass"'" >> $HOME/.jupyter/jupyter_notebook_config.py
30 | RUN printf "c.NotebookApp.ip = '*'\nc.NotebookApp.open_browser = False\n" >> $HOME/.jupyter/jupyter_notebook_config.py
31 | 
32 | # create a self-signed ssl cert for jupyter notebook
33 | RUN openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout $HOME/mykey.key -out $HOME/mycert.pem -subj "/C=IE"
34 | # Pin tini rather than scraping the "latest" release page (fragile); curl -f fails on HTTP errors instead of saving an error page.
35 | ARG TINI_VERSION=0.18.0
36 | RUN apt-get update && apt-get install -y curl && \
37 |     curl -fL "https://github.com/krallin/tini/releases/download/v${TINI_VERSION}/tini_${TINI_VERSION}.deb" > tini.deb && \
38 |     dpkg -i tini.deb && \
39 |     rm tini.deb && \
40 |     apt-get clean
41 | 
42 | COPY data.py .
43 | RUN mkdir data 44 | RUN wget http://pjreddie.com/media/files/cifar.tgz -P data/ 45 | RUN tar -xf data/cifar.tgz -C data/ 46 | RUN python data.py 47 | RUN rm -rf data/cifar.tgz data/cifar 48 | 49 | VOLUME workspace /root/workspace 50 | 51 | EXPOSE 8888 52 | 53 | RUN cd fastai && /bin/bash -c "source activate fastai && python setup.py install" 54 | 55 | ENTRYPOINT [ "/usr/bin/tini", "--" ] 56 | SHELL [ "/bin/bash", "-c" ] 57 | CMD source activate fastai && jupyter notebook --certfile=mycert.pem --keyfile mykey.key --allow-root --notebook-dir='workspace' 58 | -------------------------------------------------------------------------------- /workspace/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/uoguelph-mlrg/Cutout 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from .layers import * 8 | 9 | def conv_2d(ni, nf, ks, stride): return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2, bias=False) 10 | 11 | def bn(ni, init_zero=False): 12 | m = nn.BatchNorm2d(ni) 13 | m.weight.data.fill_(0 if init_zero else 1) 14 | m.bias.data.zero_() 15 | return m 16 | 17 | def bn_relu_conv(ni, nf, ks, stride, init_zero=False): 18 | bn_initzero = bn(ni, init_zero=init_zero) 19 | return nn.Sequential(bn_initzero, nn.ReLU(inplace=True), conv_2d(ni, nf, ks, stride)) 20 | 21 | def noop(x): return x 22 | 23 | class BasicBlock(nn.Module): 24 | def __init__(self, ni, nf, stride, drop_p=0.0): 25 | super().__init__() 26 | self.bn = nn.BatchNorm2d(ni) 27 | self.conv1 = conv_2d(ni, nf, 3, stride) 28 | self.conv2 = bn_relu_conv(nf, nf, 3, 1) 29 | self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None 30 | self.shortcut = conv_2d(ni, nf, 1, stride) if ni != nf else noop 31 | 32 | def forward(self, x): 33 | x2 = F.relu(self.bn(x), inplace=True) 34 | r = self.shortcut(x2) 35 | x = self.conv1(x2) 36 | if self.drop: x = self.drop(x) 37 | x = 
self.conv2(x) ## * 0.2 38 | return x.add_(r) 39 | 40 | 41 | def _make_group(N, ni, nf, block, stride, drop_p): 42 | return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)] 43 | 44 | class WideResNet(nn.Module): 45 | def __init__(self, num_groups, N, num_classes, k=1, drop_p=0.0, start_nf=16): 46 | super().__init__() 47 | n_channels = [start_nf] 48 | for i in range(num_groups): n_channels.append(start_nf*(2**i)*k) 49 | 50 | layers = [conv_2d(3, n_channels[0], 3, 1)] # conv1 51 | for i in range(num_groups): 52 | layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock, (1 if i==0 else 2), drop_p) 53 | 54 | layers += [nn.AdaptiveAvgPool2d(1), bn_relu_conv(n_channels[-1], num_classes, 1, 1), Flatten()] 55 | self.features = nn.Sequential(*layers) 56 | 57 | def forward(self, x): return self.features(x) 58 | 59 | 60 | def wrn_22(): return WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.) 61 | def wrn_22_k8(): return WideResNet(num_groups=3, N=3, num_classes=10, k=8, drop_p=0.) 62 | def wrn_22_k10(): return WideResNet(num_groups=3, N=3, num_classes=10, k=10, drop_p=0.) 63 | def wrn_22_k8_p2(): return WideResNet(num_groups=3, N=3, num_classes=10, k=8, drop_p=0.) 64 | def wrn_28(): return WideResNet(num_groups=3, N=4, num_classes=10, k=6, drop_p=0.) 65 | def wrn_28_k8(): return WideResNet(num_groups=3, N=4, num_classes=10, k=8, drop_p=0.) 
66 | def wrn_28_k8_p2(): return WideResNet(num_groups=3, N=4, num_classes=10, k=8, drop_p=0.2) 67 | def wrn_28_p2(): return WideResNet(num_groups=3, N=4, num_classes=10, k=6, drop_p=0.2) 68 | 69 | -------------------------------------------------------------------------------- /workspace/cifar10_fastai_adamw.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "https://github.com/sgugger/Adam-experiments" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from fastai.conv_learner import *\n", 17 | "from fastai.models.cifar10.wideresnet import wrn_22\n", 18 | "from torchvision import transforms, datasets\n", 19 | "\n", 20 | "torch.backends.cudnn.benchmark = True\n", 21 | "PATH = Path(\"../data/cifar10\")" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", 31 | "stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "[8.325000000000001, 8.325000000000001, 1.3499999999999999]" 43 | ] 44 | }, 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "sz = 32\n", 52 | "bs = 128\n", 53 | "\n", 54 | "m = wrn_22()\n", 55 | "base_lr = 3e-3\n", 56 | "lr_div = 10\n", 57 | "wd = 0.1\n", 58 | "cyc_len = 18 # lenght of the cycle expressed in epochs\n", 59 | "ann_len = 0.075 # length of the annealing phase expressed as a fraction of cycle_len\n", 60 | "\n", 61 | "moms = (0.95,0.85)\n", 62 | "beta2=0.99\n", 63 | "\n", 64 | 
"phase_lengths = [cyc_len * (1-ann_len) / 2, cyc_len * (1-ann_len) / 2, cyc_len * ann_len]; phase_lengths" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomCrop(sz), RandomFlip()], pad=sz//8)\n", 74 | "data = ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "learn = ConvLearner.from_model_data(m, data)\n", 84 | "learn.crit = nn.CrossEntropyLoss()\n", 85 | "learn.metrics = [accuracy]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def adam(params): return optim.Adam(params, betas=(moms[0], beta2))\n", 95 | "learn.opt_fn = adam" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 7, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "training_phases = [\n", 105 | " TrainingPhase(phase_lengths[0], adam, lr=(base_lr/lr_div, base_lr), lr_decay=DecayType.LINEAR,\n", 106 | " momentum=moms, momentum_decay=DecayType.LINEAR, wds=wd, wd_loss=False),\n", 107 | " TrainingPhase(phase_lengths[1], adam, lr=(base_lr, base_lr/lr_div), lr_decay=DecayType.LINEAR,\n", 108 | " momentum=(moms[1], moms[0]), momentum_decay=DecayType.LINEAR, wds=wd, wd_loss=False),\n", 109 | " TrainingPhase(phase_lengths[2], adam, lr=(base_lr/lr_div, base_lr/(lr_div*100)), lr_decay=DecayType.LINEAR,\n", 110 | " momentum=moms[0], wds=wd, wd_loss=False)\n", 111 | "]" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 8, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "application/vnd.jupyter.widget-view+json": { 122 | "model_id": "c7799a4443764a56933f1a03d49b963e", 123 | "version_major": 2, 124 | "version_minor": 0 125 | }, 126 | 
"text/plain": [ 127 | "HBox(children=(IntProgress(value=0, description='Epoch', max=19), HTML(value='')))" 128 | ] 129 | }, 130 | "metadata": {}, 131 | "output_type": "display_data" 132 | }, 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "epoch trn_loss val_loss accuracy \n", 138 | " 0 1.075619 1.015231 0.6435 \n", 139 | " 1 0.82483 1.034227 0.6552 \n", 140 | " 2 0.712239 0.651429 0.7754 \n", 141 | " 3 0.631517 0.575336 0.8018 \n", 142 | " 4 0.576791 0.651911 0.7742 \n", 143 | " 5 0.504203 0.59719 0.8013 \n", 144 | " 6 0.480576 0.686073 0.783 \n", 145 | " 7 0.443939 0.546567 0.8209 \n", 146 | " 8 0.414961 0.459819 0.8399 \n", 147 | " 9 0.348353 0.403022 0.8628 \n", 148 | " 10 0.31422 0.407295 0.8618 \n", 149 | " 11 0.257251 0.336791 0.8856 \n", 150 | " 12 0.237165 0.294243 0.9005 \n", 151 | " 13 0.183506 0.286957 0.9054 \n", 152 | " 14 0.14385 0.243615 0.9206 \n", 153 | " 15 0.108258 0.231928 0.9284 \n", 154 | " 16 0.074345 0.217072 0.9322 \n", 155 | " 17 0.050736 0.212466 0.936 \n", 156 | "CPU times: user 14min 39s, sys: 7min 35s, total: 22min 15s\n", 157 | "Wall time: 15min 59s\n" 158 | ] 159 | }, 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "[array([0.21247]), 0.936]" 164 | ] 165 | }, 166 | "execution_count": 8, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "%time learn.fit_opt_sched(training_phases)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 9, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | " \r" 185 | ] 186 | }, 187 | { 188 | "data": { 189 | "text/plain": [ 190 | "'Final loss: 0.17752112448215485, Final accuracy: 0.9394999742507935'" 191 | ] 192 | }, 193 | "execution_count": 9, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "preds, targs = learn.TTA()\n", 200 | "probs = 
np.exp(preds)/np.exp(preds).sum(2)[:,:,None]\n", 201 | "probs = np.mean(probs,0)\n", 202 | "acc = learn.metrics[0](V(probs), V(targs)).data[0]\n", 203 | "loss = learn.crit(V(np.log(probs)), V(targs)).data[0]\n", 204 | "f'Final loss: {loss}, Final accuracy: {acc}'" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 10, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "image/png": "\n", 215 | "text/plain": [ 216 | "
" 217 | ] 218 | }, 219 | "metadata": {}, 220 | "output_type": "display_data" 221 | } 222 | ], 223 | "source": [ 224 | "learn.sched.plot_loss()" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 11, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "data": { 234 | "image/png": "\n", 235 | "text/plain": [ 236 | "
" 237 | ] 238 | }, 239 | "metadata": {}, 240 | "output_type": "display_data" 241 | } 242 | ], 243 | "source": [ 244 | "learn.sched.plot_lr()" 245 | ] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "Python 3", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.6.5" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 2 269 | } 270 | -------------------------------------------------------------------------------- /workspace/cifar10_fastai_dawnbench.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "https://github.com/fastai/imagenet-fast/tree/master/cifar10/dawn_submission" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from fastai.conv_learner import *\n", 17 | "from models.wideresnet import wrn_22 # this is the models directory from the fastai/imagenet-fast repo\n", 18 | "from torchvision import transforms, datasets\n", 19 | "\n", 20 | "torch.backends.cudnn.benchmark = True\n", 21 | "PATH = Path(\"../data/cifar10\")" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", 31 | "stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "We construct the data object manually from low level components in a way 
that can be used with the fastai library." 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "def get_loaders(bs, num_workers):\n", 48 | " traindir = str(PATH/'train')\n", 49 | " valdir = str(PATH/'test')\n", 50 | " tfms = [transforms.ToTensor(),\n", 51 | " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]\n", 52 | "\n", 53 | " aug_tfms =transforms.Compose([\n", 54 | " transforms.RandomCrop(32, padding=4),\n", 55 | " transforms.RandomHorizontalFlip(),\n", 56 | " ] + tfms)\n", 57 | " \n", 58 | " train_dataset = datasets.ImageFolder(\n", 59 | " traindir,\n", 60 | " aug_tfms)\n", 61 | "\n", 62 | " train_loader = torch.utils.data.DataLoader(\n", 63 | " train_dataset, batch_size=bs, shuffle=True, num_workers=num_workers, pin_memory=True)\n", 64 | "\n", 65 | " val_dataset = datasets.ImageFolder(valdir, transforms.Compose(tfms))\n", 66 | "\n", 67 | " val_loader = torch.utils.data.DataLoader(\n", 68 | " val_dataset, batch_size=bs, shuffle=False, num_workers=num_workers, pin_memory=True)\n", 69 | " \n", 70 | " aug_dataset = datasets.ImageFolder(valdir, aug_tfms)\n", 71 | "\n", 72 | " aug_loader = torch.utils.data.DataLoader(\n", 73 | " aug_dataset, batch_size=bs, shuffle=False, num_workers=num_workers, pin_memory=True)\n", 74 | " \n", 75 | " return train_loader, val_loader, aug_loader" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "def get_data(bs, num_workers):\n", 85 | " trn_dl, val_dl, aug_dl = get_loaders(bs, num_workers)\n", 86 | " data = ModelData(PATH, trn_dl, val_dl)\n", 87 | " data.aug_dl = aug_dl\n", 88 | " data.sz=32\n", 89 | " return data" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 5, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "def get_learner(arch, bs):\n", 99 | " learn = ConvLearner.from_model_data(arch.cuda(), 
get_data(bs, num_cpus()))\n", 100 | " learn.crit = nn.CrossEntropyLoss()\n", 101 | " learn.metrics = [accuracy]\n", 102 | " return learn" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "def get_TTA_accuracy(learn):\n", 112 | " preds, targs = learn.TTA()\n", 113 | " # combining the predictions across augmented and non augmented inputs\n", 114 | " preds = 0.6 * preds[0] + 0.4 * preds[1:].sum(0)\n", 115 | " return accuracy_np(preds, targs)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "## fastai DAWN bench submission " 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "This I believe is the original FastAI DAWN bench submission in terms of the architecture and the training parameters." 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 7, 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "application/vnd.jupyter.widget-view+json": { 140 | "model_id": "f309ae504202482296bac8751d999e9b", 141 | "version_major": 2, 142 | "version_minor": 0 143 | }, 144 | "text/plain": [ 145 | "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" 146 | ] 147 | }, 148 | "metadata": {}, 149 | "output_type": "display_data" 150 | }, 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "epoch trn_loss val_loss accuracy \n", 156 | " 0 2.406675 6239621.644 0.1 \n", 157 | "\n" 158 | ] 159 | }, 160 | { 161 | "data": { 162 | "image/png": "\n", 163 | "text/plain": [ 164 | "
" 165 | ] 166 | }, 167 | "metadata": {}, 168 | "output_type": "display_data" 169 | } 170 | ], 171 | "source": [ 172 | "learn = get_learner(wrn_22(), 512)\n", 173 | "learn.lr_find(wds=1e-4)\n", 174 | "learn.clip = 1e-1\n", 175 | "learn.sched.plot(n_skip_end=1)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 8, 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "data": { 185 | "application/vnd.jupyter.widget-view+json": { 186 | "model_id": "5e57e2ab22e747ff85df0ac78e090a0b", 187 | "version_major": 2, 188 | "version_minor": 0 189 | }, 190 | "text/plain": [ 191 | "HBox(children=(IntProgress(value=0, description='Epoch', max=30), HTML(value='')))" 192 | ] 193 | }, 194 | "metadata": {}, 195 | "output_type": "display_data" 196 | }, 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "epoch trn_loss val_loss accuracy \n", 202 | " 0 1.600993 1.441529 0.4754 \n", 203 | " 1 1.188086 1.41085 0.5334 \n", 204 | " 2 0.910884 1.36183 0.5558 \n", 205 | " 4 0.611622 0.841227 0.7221 \n", 206 | " 5 0.552191 0.764714 0.7504 \n", 207 | " 6 0.503751 0.737051 0.7511 \n", 208 | " 7 0.460971 0.653494 0.7825 \n", 209 | " 8 0.429952 1.264751 0.6541 \n", 210 | " 9 0.397564 0.723189 0.7511 \n", 211 | " 10 0.37606 0.683567 0.7782 \n", 212 | " 11 0.360873 0.577565 0.7981 \n", 213 | " 12 0.341871 0.695879 0.7759 \n", 214 | " 13 0.323827 0.59713 0.798 \n", 215 | " 14 0.304483 0.473728 0.8451 \n", 216 | " 15 0.28388 0.555944 0.816 \n", 217 | " 16 0.257587 0.87107 0.7456 \n", 218 | " 17 0.242033 0.422021 0.8581 \n", 219 | " 18 0.228757 0.405375 0.8638 \n", 220 | " 19 0.21638 0.411791 0.8625 \n", 221 | " 20 0.197131 0.412363 0.8628 \n", 222 | " 21 0.179344 0.389317 0.8746 \n", 223 | " 22 0.165843 0.392866 0.8762 \n", 224 | " 23 0.144036 0.327264 0.8986 \n", 225 | " 24 0.120298 0.339279 0.8942 \n", 226 | " 25 0.096919 0.281939 0.9124 \n", 227 | " 26 0.069822 0.235627 0.9279 \n", 228 | " 27 0.046569 0.21313 0.9368 \n", 229 | " 28 
0.029652 0.209425 0.939 \n", 230 | " 29 0.019703 0.201688 0.9419 \n", 231 | "\n", 232 | "CPU times: user 10min 14s, sys: 8min 33s, total: 18min 47s\n", 233 | "Wall time: 18min 54s\n" 234 | ] 235 | }, 236 | { 237 | "data": { 238 | "text/plain": [ 239 | "[array([0.20169]), 0.9419000004768372]" 240 | ] 241 | }, 242 | "execution_count": 8, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "%time learn.fit(1.5, 1, wds=1e-4, cycle_len=30, use_clr_beta=(15, 10, 0.95, 0.85))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 9, 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | " \r" 261 | ] 262 | }, 263 | { 264 | "data": { 265 | "text/plain": [ 266 | "0.948" 267 | ] 268 | }, 269 | "execution_count": 9, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "get_TTA_accuracy(learn)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "## With tweaks for training locally on a 1080ti " 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "I run the training 3 times just to make sure we hit 94% accuracy with some degree of reliability." 
290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 10, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "data": { 299 | "application/vnd.jupyter.widget-view+json": { 300 | "model_id": "55a183b9bbe341709b7e12f31492b923", 301 | "version_major": 2, 302 | "version_minor": 0 303 | }, 304 | "text/plain": [ 305 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 306 | ] 307 | }, 308 | "metadata": {}, 309 | "output_type": "display_data" 310 | }, 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "epoch trn_loss val_loss accuracy \n", 316 | " 0 1.169543 1.263701 0.5503 \n", 317 | " 1 0.908752 1.154146 0.6011 \n", 318 | " 2 0.771108 0.912175 0.6869 \n", 319 | " 3 0.676908 0.793713 0.7236 \n", 320 | " 4 0.628075 0.736298 0.7481 \n", 321 | " 5 0.585728 0.812324 0.7084 \n", 322 | " 6 0.556467 0.958268 0.682 \n", 323 | " 7 0.553127 0.802786 0.7342 \n", 324 | " 8 0.537662 0.661663 0.7748 \n", 325 | " 9 0.503087 0.917536 0.7149 \n", 326 | " 10 0.482562 0.680255 0.7792 \n", 327 | " 11 0.448644 0.824123 0.7396 \n", 328 | " 12 0.420774 0.750203 0.7668 \n", 329 | " 13 0.393915 0.505999 0.8259 \n", 330 | " 14 0.365846 0.544087 0.8179 \n", 331 | " 15 0.328399 0.416165 0.8565 \n", 332 | " 16 0.238594 0.296919 0.8999 \n", 333 | " 17 0.1871 0.245863 0.9157 \n", 334 | " 18 0.134778 0.218889 0.9252 \n", 335 | " 19 0.104386 0.188387 0.9381 \n", 336 | "\n", 337 | "CPU times: user 8min 2s, sys: 5min 24s, total: 13min 27s\n", 338 | "Wall time: 13min 16s\n" 339 | ] 340 | } 341 | ], 342 | "source": [ 343 | "%%time\n", 344 | "learn = get_learner(wrn_22(), 128)\n", 345 | "learn.clip = 1e-1\n", 346 | "learn.fit(1.5, 1, wds=1e-4, cycle_len=20, use_clr_beta=(12, 15, 0.95, 0.85))" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 11, 352 | "metadata": {}, 353 | "outputs": [ 354 | { 355 | "name": "stdout", 356 | "output_type": "stream", 357 | "text": [ 358 | " \r" 359 | ] 360 | 
}, 361 | { 362 | "data": { 363 | "text/plain": [ 364 | "0.9431" 365 | ] 366 | }, 367 | "execution_count": 11, 368 | "metadata": {}, 369 | "output_type": "execute_result" 370 | } 371 | ], 372 | "source": [ 373 | "get_TTA_accuracy(learn)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 12, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | "application/vnd.jupyter.widget-view+json": { 384 | "model_id": "84d7e3ce03484d649dbf13a7503fc79b", 385 | "version_major": 2, 386 | "version_minor": 0 387 | }, 388 | "text/plain": [ 389 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 390 | ] 391 | }, 392 | "metadata": {}, 393 | "output_type": "display_data" 394 | }, 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "epoch trn_loss val_loss accuracy \n", 400 | " 0 1.208857 1.557635 0.5105 \n", 401 | " 1 0.928309 1.364705 0.5856 \n", 402 | " 2 0.7586 0.78556 0.7276 \n", 403 | " 3 0.671309 0.965542 0.6779 \n", 404 | " 4 0.629488 0.748223 0.7472 \n", 405 | " 5 0.567147 0.801228 0.7312 \n", 406 | " 6 0.555923 1.104574 0.651 \n", 407 | " 7 0.55037 0.768465 0.7412 \n", 408 | " 8 0.52509 1.087103 0.668 \n", 409 | " 9 0.510132 0.814182 0.7428 \n", 410 | " 10 0.474 1.115937 0.6912 \n", 411 | " 11 0.464195 0.84923 0.7256 \n", 412 | " 12 0.439663 0.473122 0.8399 \n", 413 | " 13 0.405162 0.6486 0.7843 \n", 414 | " 14 0.365501 0.469936 0.841 \n", 415 | " 15 0.322079 0.389312 0.8637 \n", 416 | " 16 0.250046 0.293756 0.8955 \n", 417 | " 17 0.189786 0.25931 0.9132 \n", 418 | " 18 0.139936 0.211995 0.9271 \n", 419 | " 19 0.094604 0.196974 0.9338 \n", 420 | "\n", 421 | "CPU times: user 8min 1s, sys: 5min 25s, total: 13min 27s\n", 422 | "Wall time: 13min 15s\n" 423 | ] 424 | } 425 | ], 426 | "source": [ 427 | "%%time\n", 428 | "learn = get_learner(wrn_22(), 128)\n", 429 | "learn.clip = 1e-1\n", 430 | "learn.fit(1.5, 1, wds=1e-4, cycle_len=20, use_clr_beta=(12, 15, 0.95, 0.85))" 
431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 13, 436 | "metadata": {}, 437 | "outputs": [ 438 | { 439 | "name": "stdout", 440 | "output_type": "stream", 441 | "text": [ 442 | " \r" 443 | ] 444 | }, 445 | { 446 | "data": { 447 | "text/plain": [ 448 | "0.9402" 449 | ] 450 | }, 451 | "execution_count": 13, 452 | "metadata": {}, 453 | "output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "get_TTA_accuracy(learn)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 14, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "data": { 467 | "application/vnd.jupyter.widget-view+json": { 468 | "model_id": "c89dc22f1fc74dbd956cadd18f13d82e", 469 | "version_major": 2, 470 | "version_minor": 0 471 | }, 472 | "text/plain": [ 473 | "HBox(children=(IntProgress(value=0, description='Epoch', max=20), HTML(value='')))" 474 | ] 475 | }, 476 | "metadata": {}, 477 | "output_type": "display_data" 478 | }, 479 | { 480 | "name": "stdout", 481 | "output_type": "stream", 482 | "text": [ 483 | "epoch trn_loss val_loss accuracy \n", 484 | " 0 1.215059 1.240503 0.5669 \n", 485 | " 1 0.899173 1.022645 0.6619 \n", 486 | " 2 0.75076 0.80045 0.7308 \n", 487 | " 3 0.676683 0.960973 0.6804 \n", 488 | " 4 0.604371 0.844541 0.7308 \n", 489 | " 5 0.579993 0.923703 0.6962 \n", 490 | " 6 0.549273 0.74115 0.7459 \n", 491 | " 7 0.538123 0.842362 0.718 \n", 492 | " 8 0.52504 1.040526 0.6702 \n", 493 | " 9 0.500089 0.84092 0.731 \n", 494 | " 10 0.464431 0.546925 0.819 \n", 495 | " 11 0.444645 0.758121 0.7587 \n", 496 | " 12 0.422681 0.573507 0.8074 \n", 497 | " 13 0.396939 0.584204 0.8102 \n", 498 | " 14 0.377264 0.599833 0.8012 \n", 499 | " 15 0.312021 0.511016 0.8347 \n", 500 | " 16 0.231998 0.320775 0.8931 \n", 501 | " 17 0.184563 0.250778 0.9159 \n", 502 | " 18 0.13895 0.214737 0.9272 \n", 503 | " 19 0.100409 0.196783 0.9358 \n", 504 | "\n", 505 | "CPU times: user 8min 1s, sys: 5min 26s, total: 13min 27s\n", 506 | "Wall time: 
13min 15s\n" 507 | ] 508 | } 509 | ], 510 | "source": [ 511 | "%%time\n", 512 | "learn = get_learner(wrn_22(), 128)\n", 513 | "learn.clip = 1e-1\n", 514 | "learn.fit(1.5, 1, wds=1e-4, cycle_len=20, use_clr_beta=(12, 15, 0.95, 0.85))" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 15, 520 | "metadata": {}, 521 | "outputs": [ 522 | { 523 | "name": "stdout", 524 | "output_type": "stream", 525 | "text": [ 526 | " \r" 527 | ] 528 | }, 529 | { 530 | "data": { 531 | "text/plain": [ 532 | "0.9416" 533 | ] 534 | }, 535 | "execution_count": 15, 536 | "metadata": {}, 537 | "output_type": "execute_result" 538 | } 539 | ], 540 | "source": [ 541 | "get_TTA_accuracy(learn)" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 16, 547 | "metadata": {}, 548 | "outputs": [ 549 | { 550 | "data": { 551 | "image/png": "\n", 552 | "text/plain": [ 553 | "
" 554 | ] 555 | }, 556 | "metadata": {}, 557 | "output_type": "display_data" 558 | } 559 | ], 560 | "source": [ 561 | "learn.sched.plot_loss()" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 17, 567 | "metadata": {}, 568 | "outputs": [ 569 | { 570 | "data": { 571 | "image/png": "\n", 572 | "text/plain": [ 573 | "
" 574 | ] 575 | }, 576 | "metadata": {}, 577 | "output_type": "display_data" 578 | } 579 | ], 580 | "source": [ 581 | "learn.sched.plot_lr()" 582 | ] 583 | } 584 | ], 585 | "metadata": { 586 | "kernelspec": { 587 | "display_name": "Python 3", 588 | "language": "python", 589 | "name": "python3" 590 | }, 591 | "language_info": { 592 | "codemirror_mode": { 593 | "name": "ipython", 594 | "version": 3 595 | }, 596 | "file_extension": ".py", 597 | "mimetype": "text/x-python", 598 | "name": "python", 599 | "nbconvert_exporter": "python", 600 | "pygments_lexer": "ipython3", 601 | "version": "3.6.5" 602 | } 603 | }, 604 | "nbformat": 4, 605 | "nbformat_minor": 2 606 | } 607 | --------------------------------------------------------------------------------