├── .gitignore ├── INSTALL.md ├── LICENSE ├── README.md ├── SERVER.md ├── config.py ├── imgs ├── cifar100_image.png ├── cifar10_image.png ├── img_356.lua ├── pytorch.png └── svhn_image.png ├── main.py ├── networks ├── __init__.py ├── lenet.py ├── resnet.py ├── vggnet.py └── wide_resnet.py └── scripts ├── cifar100_train.sh ├── cifar10_train.sh └── resnet_cifar100_train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # Torch 104 | *.t7 105 | 106 | # Data 107 | data/ 108 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # INSTALLATION GUIDE 2 | This is the Installation guide for the overall repository. 3 | 4 | ## Install NVIDIA-driver 5 | GIGABYTE based BIOS setting 6 | - First, DO NOT PLUG IN your GPU until the driver is set up. 7 | - Internal Graphic : Auto > Enable 8 | - Display Priority : PCIe > Internal 9 | 10 | NVIDIA driver download .run file : [click here](http://www.nvidia.co.kr/Download/index.aspx) 11 | 12 | If you click download from the above site, you will get a .run file format for installing drivers. 13 | 14 | ### 1. Stop display manager 15 | 16 | Before you run the .run file, you first need to stop your Xserver display manager. 
17 | 18 | Press [Ctrl] + [Alt] + [F1], enter the script below 19 | 20 | ```bash 21 | $ service --status-all | grep dm 22 | 23 | (Result) [+] [:dm] 24 | ``` 25 | 26 | The part described as [:dm] is your display manager. 27 | 28 | Substitute the [:dm] part below with the result of the script above. 29 | 30 | ```bash 31 | $ sudo service [:dm] stop 32 | 33 | (Result) * Stopping Light Display Manager [:dm] 34 | ``` 35 | 36 | ### 2. Run the nvidia-driver installer 37 | 38 | Run the code below. Press 'Yes' for every option they ask. 39 | 40 | ```bash 41 | $ sh /NVIDIA-Linux_x86_64-375.20.run 42 | ``` 43 | 44 | After you have successfully installed, you shall see the same results when typing the code below. 45 | 46 | ```bash 47 | $ nvidia-smi 48 | ``` 49 | 50 | ```bash 51 | +-----------------------------------------------------------------------------+ 52 | | NVIDIA-SMI 375.20 Driver Version: 375.20 | 53 | |-------------------------------+----------------------+----------------------+ 54 | | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | 55 | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. 
| 56 | |===============================+======================+======================| 57 | | 0 TITAN X (Pascal) Off | 0000:4B:00.0 Off | N/A | 58 | | 67% 86C P2 249W / 250W | 5026MiB / 12221MiB | 82% Default | 59 | +-------------------------------+----------------------+----------------------+ 60 | | 1 TITAN X (Pascal) Off | 0000:4C:00.0 On | N/A | 61 | | 88% 90C P2 225W / 250W | 3842MiB / 12213MiB | 78% Default | 62 | +-------------------------------+----------------------+----------------------+ 63 | 64 | +-----------------------------------------------------------------------------+ 65 | | Processes: GPU Memory | 66 | | GPU PID Type Process name Usage | 67 | |=============================================================================| 68 | | No running processes found | 69 | +-----------------------------------------------------------------------------+ 70 | 71 | ``` 72 | 73 | ### 3. Reboot the system 74 | 75 | ```bash 76 | $ sudo reboot 77 | ``` 78 | 79 | ## Install CUDA toolkit 80 | 81 | You can skip the step above and automatically install the driver within CUDA installation. 82 | 83 | Check the details below. 84 | 85 | ### * Download the CUDA download .run file from the CUDA download page 86 | 87 | CUDA download page : [click here](https://developer.nvidia.com/cuda-downloads) 88 | 89 | Before executing the file, stop the display manager by following the description above. 90 | 91 | ```bash 92 | $ sudo sh /cuda_8.0.44_linux.run 93 | ``` 94 | 95 | ### * Link your CUDA in .bashrc 96 | 97 | ```bash 98 | $ sudo apt-get install vim 99 | 100 | $ git clone https://github.com/amix/vimrc.git ~/.vim_runtime 101 | 102 | # Awesome version 103 | $ sh ~/.vim_runtime/install_awesome_vimrc.sh 104 | 105 | # Basic version 106 | $ sh ~/.vim_runtime/install_basic_vimrc.sh 107 | ``` 108 | 109 | Open your ~/.bashrc file. 
110 | 111 | ```bash 112 | vi ~/.bashrc 113 | 114 | # Press Shift + G, Add the lines on the bottom 115 | 116 | export PATH=$PATH:/usr/local/cuda/bin 117 | export LD_LIBRARY_PATH=/usr/local/cuda/lib64 118 | export CUDA_HOME=/usr/local/cuda 119 | ``` 120 | 121 | To check if the CUDA toolkit is successfully installed, type the line below. 122 | 123 | ```bash 124 | $ nvcc --version 125 | 126 | * (Result) 127 | nvcc: NVIDIA (R) Cuda compiler driver 128 | Copyright (c) 2005-2016 NVIDIA Corporation 129 | Built on Sun_Sep__4_22:14:01_CDT_2016 130 | Cuda compilation tools, release 8.0, V8.0.44 131 | ``` 132 | 133 | ## Install cuDNN library kit 134 | 135 | cuDNN download page : [click here](https://developer.nvidia.com/rdp/cudnn-download) 136 | 137 | (Membership is required, just sign in!) 138 | 139 | Download the newest cuDNN v5.1. 140 | 141 | ```bash 142 | $ cd 143 | $ tar -zxvf ./cudnn-8.0-linux-x64-v5.1.tgz 144 | $ sudo cp cuda/include/cudnn.h /usr/local/cuda/include/ 145 | $ sudo cp cuda/lib64/libcudnn* /usr/local/cuda/lib64/ 146 | $ sudo chmod a+r /usr/local/cuda/include/cudnn.h 147 | $ sudo chmod a+r /usr/local/cuda/lib64/libcudnn* 148 | ``` 149 | 150 | ## Install Tensorflow 151 | 152 | Tensorflow install page : [click here](https://www.tensorflow.org/get_started/os_setup) 153 | 154 | ```bash 155 | $ sudo apt-get install python-pip python-dev 156 | $ pip install --upgrade pip 157 | $ pip install tensorflow-gpu 158 | ``` 159 | 160 | ## Install Torch 161 | 162 | Torch install page : [click here](http://torch.ch/docs/getting-started.html) 163 | 164 | ```bash 165 | $ git clone https://github.com/torch/distro.git ~/torch --recursive 166 | $ cd ~/torch; bash install-deps; 167 | $ ./install.sh 168 | $ source ~/.bashrc 169 | ``` 170 | 171 | ## Now, Enjoy! 
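As a complement to the `nvcc --version` check above, the following is a minimal Python 3 sketch (standard library only) that reports whether the `~/.bashrc` exports from this guide are visible to the current shell. The function name `check_cuda_environment` is hypothetical and not part of this repository; the paths are the ones exported above and may differ on your machine.

```python
import os
import shutil


def check_cuda_environment(env=None):
    """Report whether the CUDA settings exported in ~/.bashrc are visible.

    `env` is a mapping of environment variables; it defaults to os.environ.
    """
    env = os.environ if env is None else env
    path_entries = env.get("PATH", "").split(os.pathsep)
    lib_entries = env.get("LD_LIBRARY_PATH", "").split(os.pathsep)
    return {
        # Paths below are the ones this guide adds to ~/.bashrc (assumed).
        "cuda_bin_on_path": "/usr/local/cuda/bin" in path_entries,
        "cuda_lib_on_ld_path": "/usr/local/cuda/lib64" in lib_entries,
        "cuda_home": env.get("CUDA_HOME"),
        # Full path to nvcc if it can be found on PATH, otherwise None.
        "nvcc": shutil.which("nvcc", path=env.get("PATH", "")),
    }


if __name__ == "__main__":
    for key, value in check_cuda_environment().items():
        print("%-22s %s" % (key, value))
```

If `cuda_bin_on_path` is `False` right after editing `~/.bashrc`, the file has probably not been re-read yet; run `source ~/.bashrc` or open a new shell and try again.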
172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Bumsoo Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | Best CIFAR-10, CIFAR-100 results with wide-residual networks using PyTorch 4 | 5 | PyTorch implementation of Sergey Zagoruyko's [Wide Residual Networks](https://arxiv.org/pdf/1605.07146v2.pdf) 6 | 7 | For the Torch implementation, see [here](https://github.com/meliketoy/wide-residual-network). 8 | 9 | ## Requirements 10 | See the [installation instructions](INSTALL.md) for a step-by-step installation guide. 11 | See the [server instructions](SERVER.md) for server setup. 12 | - Install [cuda-8.0](https://developer.nvidia.com/cuda-downloads) 13 | - Install [cudnn v5.1](https://developer.nvidia.com/cudnn) 14 | - Download [PyTorch for Python 2.7](https://pytorch.org) and clone the repository. 15 | ```bash 16 | pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp27-none-linux_x86_64.whl 17 | pip install torchvision 18 | git clone https://github.com/meliketoy/wide-resnet.pytorch 19 | ``` 20 | 21 | ## How to run 22 | After you have cloned the repository, you can train on either cifar10 or cifar100 by running the script below. Add the `--resume` flag to continue from a saved checkpoint. 23 | ```bash 24 | python main.py --lr 0.1 --net_type [lenet/vggnet/resnet/wide-resnet] --depth 28 --widen_factor 10 --dropout 0.3 --dataset [cifar10/cifar100] 25 | ``` 26 | 27 | ## Implementation Details 28 | 29 | | epoch | learning rate | weight decay | Optimizer | Momentum | Nesterov | 30 | |:---------:|:-------------:|:-------------:|:---------:|:--------:|:--------:| 31 | | 0 ~ 60 | 0.1 | 0.0005 | Momentum | 0.9 | true | 32 | | 61 ~ 120 | 0.02 | 0.0005 | Momentum | 0.9 | true | 33 | | 121 ~ 160 | 0.004 | 0.0005 | Momentum | 0.9 | true | 34 | | 161 ~ 200 | 0.0008 | 0.0005 | Momentum | 0.9 | true | 35 | 36 | ## CIFAR-10 Results 37 | 38 | ![alt tag](imgs/cifar10_image.png) 39 | 40 | Below is the test set accuracy for **CIFAR-10 dataset** training. 
41 | 42 | **Accuracy is the average of 5 runs** 43 | 44 | | network | dropout | preprocess | GPU:0 | GPU:1 | per epoch | accuracy(%) | 45 | |:-----------------:|:-------:|:----------:|:-----:|:-----:|:------------:|:-----------:| 46 | | wide-resnet 28x10 | 0 | ZCA | 5.90G | - | 2 min 03 sec | 95.83 | 47 | | wide-resnet 28x10 | 0 | meanstd | 5.90G | - | 2 min 03 sec | 96.21 | 48 | | wide-resnet 28x10 | 0.3 | meanstd | 5.90G | - | 2 min 03 sec | 96.27 | 49 | | wide-resnet 28x20 | 0.3 | meanstd | 8.13G | 6.93G | 4 min 10 sec | **96.55** | 50 | | wide-resnet 40x10 | 0.3 | meanstd | 8.08G | - | 3 min 13 sec | 96.31 | 51 | | wide-resnet 40x14 | 0.3 | meanstd | 7.37G | 6.46G | 3 min 23 sec | 96.34 | 52 | 53 | ## CIFAR-100 Results 54 | 55 | ![alt tag](imgs/cifar100_image.png) 56 | 57 | Below is the result of the test set accuracy for **CIFAR-100 dataset** training. 58 | 59 | **Accuracy is the average of 5 runs** 60 | 61 | | network | dropout | preprocess | GPU:0 | GPU:1 | per epoch | Top1 acc(%)| Top5 acc(%) | 62 | |:-----------------:|:-------:|:-----------:|:-----:|:-----:|:------------:|:----------:|:-----------:| 63 | | wide-resnet 28x10 | 0 | ZCA | 5.90G | - | 2 min 03 sec | 80.07 | 95.02 | 64 | | wide-resnet 28x10 | 0 | meanstd | 5.90G | - | 2 min 03 sec | 81.02 | 95.41 | 65 | | wide-resnet 28x10 | 0.3 | meanstd | 5.90G | - | 2 min 03 sec | 81.49 | 95.62 | 66 | | wide-resnet 28x20 | 0.3 | meanstd | 8.13G | 6.93G | 4 min 05 sec | **82.45** | **96.11** | 67 | | wide-resnet 40x10 | 0.3 | meanstd | 8.93G | - | 3 min 06 sec | 81.42 | 95.63 | 68 | | wide-resnet 40x14 | 0.3 | meanstd | 7.39G | 6.46G | 3 min 23 sec | 81.87 | 95.51 | 69 | -------------------------------------------------------------------------------- /SERVER.md: -------------------------------------------------------------------------------- 1 | # SERVER MANAGEMENT 2 | This is the management guide for server installation. 
3 | 4 | ## Welcome message 5 | Install figlet 6 | ```bash 7 | $ sudo apt-get install figlet 8 | ``` 9 | 10 | ```bash 11 | $ sudo vi /etc/bash.bashrc 12 | 13 | # Press [Shift] + [G] and add the code at the bottom 14 | clear 15 | printf "Welcome to Ubuntu 16.04.5 LTS (GNU/Linux-Mint-18 x86_64)\n" 16 | printf "This is the server for the wide-residual-network implementation.\n\n" 17 | printf " * Documentation: https://github.com/meliketoy/wide-residual-network\n\n" 18 | printf "##############################################################\n" 19 | figlet -f slant "Bumsoo Kim" 20 | printf "\n\n" 21 | printf " Data Mining & Information System Lab\n" 22 | printf " GPU Computing machine : bumsoo@163.152.163.10\n\n" 23 | printf " Administrator : Bumsoo Kim\n" 24 | printf " Please read the document\n" 25 | printf " https://github.com/meliketoy/wide-residual-network/README.md\n" 26 | printf "##############################################################\n\n" 27 | ``` 28 | 29 | ## Remote Server control 30 | 31 | ### 1. SCP call 32 | 33 | ```bash 34 | 35 | # Upload a local file to the server 36 | $ scp -P 22 /file [:username]@[:server_ip]: 37 | 38 | # Download a file from the server to your local machine 39 | $ scp -P 22 [:username]@[:server_ip]:/file 40 | 41 | # Upload a local directory to the server 42 | $ scp -P 22 -r /file [:username]@[:server_ip]: 43 | 44 | # Download a directory from the server to your local machine 45 | $ scp -P 22 -r [:username]@[:server_ip]:/file 46 | 47 | ``` 48 | 49 | ### 2. Save sessions by name 50 | ```bash 51 | $ sudo vi ~/.bashrc 52 | 53 | # Press [Shift] + [G] and add the function at the bottom. 54 | 55 | function server_func() { 56 | echo -n "[Enter the name of server]: " 57 | read server_name 58 | 59 | # echo "Logging in to server $server_name ..." 
60 | if [ "$server_name" == "[:servername]" ]; then 61 | ssh [:usr]@[:ip].[:ip].[:ip].[:ip] 62 | fi 63 | } 64 | alias dmis_remote=server_func 65 | ``` 66 | 67 | ## GitHub control 68 | 69 | ```bash 70 | $ sudo vi ~/.netrc 71 | 72 | machine github.com 73 | login [:username] 74 | password [:password] 75 | ``` 76 | 77 | ## Jupyter notebook configuration 78 | 79 | To configure Jupyter Notebook, run the commands below. 80 | ```bash 81 | $ jupyter notebook --generate-config 82 | 83 | * Result : 84 | Writing default config to: /.jupyter/jupyter_notebook_config.py 85 | 86 | $ vi ~/.jupyter/jupyter_notebook_config.py 87 | ``` 88 | 89 | Press [Esc], then enter /ip to find the IP configuration. You will find the lines below. 90 | ```bash 91 | ## The IP address the notebook server will listen on. 92 | #c.NotebookApp.ip = 'localhost' 93 | ``` 94 | 95 | Erase the '#' and change the value to your server's IP address. 96 | ```bash 97 | c.NotebookApp.ip = '163.152.163.112' # the ip address for your server 98 | ``` 99 | 100 | Press [Esc], then enter /port to find the port number. You will find the lines below. 101 | ```bash 102 | ## The port the notebook server will listen on. 103 | #c.NotebookApp.port = 8888 104 | ``` 105 | 106 | Erase the '#' and enter whatever port number you want. 107 | ```bash 108 | c.NotebookApp.port = 9999 109 | 110 | ``` 111 | 112 | Now, Enjoy! 
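The two manual edits to `jupyter_notebook_config.py` described above (uncomment an option, then assign a new value) can also be scripted. The following is only a sketch in Python 3 using the standard library: `set_option` is a hypothetical helper, not part of Jupyter or of this repository, and the IP address and port are simply the example values from this guide.

```python
import re


def set_option(config_text, key, value):
    """Uncomment `key` in the config text and assign it `value`.

    If the option is not present at all, it is appended at the end.
    """
    # Matches "#c.NotebookApp.ip = ..." as well as an already active line.
    pattern = re.compile(
        r"^#?[ \t]*" + re.escape(key) + r"[ \t]*=.*$", re.MULTILINE
    )
    replacement = "%s = %r" % (key, value)
    new_text, count = pattern.subn(lambda _match: replacement, config_text, count=1)
    if count == 0:
        new_text = config_text.rstrip("\n") + "\n" + replacement + "\n"
    return new_text


if __name__ == "__main__":
    sample = (
        "## The IP address the notebook server will listen on.\n"
        "#c.NotebookApp.ip = 'localhost'\n"
        "## The port the notebook server will listen on.\n"
        "#c.NotebookApp.port = 8888\n"
    )
    sample = set_option(sample, "c.NotebookApp.ip", "163.152.163.112")
    sample = set_option(sample, "c.NotebookApp.port", 9999)
    print(sample)
```

To apply it to a real file, read `~/.jupyter/jupyter_notebook_config.py`, pass its contents through `set_option`, and write the result back. Keeping a backup copy first is a sensible precaution.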
113 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | ############### Pytorch CIFAR configuration file ############### 2 | import math 3 | 4 | start_epoch = 1 5 | num_epochs = 200 6 | batch_size = 128 7 | optim_type = 'SGD' 8 | 9 | mean = { 10 | 'cifar10': (0.4914, 0.4822, 0.4465), 11 | 'cifar100': (0.5071, 0.4867, 0.4408), 12 | } 13 | 14 | std = { 15 | 'cifar10': (0.2023, 0.1994, 0.2010), 16 | 'cifar100': (0.2675, 0.2565, 0.2761), 17 | } 18 | 19 | # Only for cifar-10 20 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 21 | 22 | def learning_rate(init, epoch): 23 | optim_factor = 0 24 | if(epoch > 160): 25 | optim_factor = 3 26 | elif(epoch > 120): 27 | optim_factor = 2 28 | elif(epoch > 60): 29 | optim_factor = 1 30 | 31 | return init*math.pow(0.2, optim_factor) 32 | 33 | def get_hms(seconds): 34 | m, s = divmod(seconds, 60) 35 | h, m = divmod(m, 60) 36 | 37 | return h, m, s 38 | -------------------------------------------------------------------------------- /imgs/cifar100_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/cifar100_image.png -------------------------------------------------------------------------------- /imgs/cifar10_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/cifar10_image.png -------------------------------------------------------------------------------- /imgs/img_356.lua: -------------------------------------------------------------------------------- 1 | require 'image' 2 | 3 | new_path = 'svhn_image.png' 4 | 5 | img = image.load('cifar10_image.png') 6 | new_img = 
image.load(new_path) 7 | 8 | print(img:size()) 9 | print(new_img:size()) 10 | 11 | trans_img = image.scale(new_img, img:size(2), img:size(3), 'bicubic') 12 | image.save(new_path, trans_img) 13 | 14 | print("Image Convert Completed!") 15 | print(img:size()) 16 | print(trans_img:size()) 17 | -------------------------------------------------------------------------------- /imgs/pytorch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/pytorch.png -------------------------------------------------------------------------------- /imgs/svhn_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/svhn_image.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | import torch.backends.cudnn as cudnn 8 | import config as cf 9 | 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | 13 | import os 14 | import sys 15 | import time 16 | import argparse 17 | import datetime 18 | 19 | from networks import * 20 | from torch.autograd import Variable 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training') 23 | parser.add_argument('--lr', default=0.1, type=float, help='learning_rate') 24 | parser.add_argument('--net_type', default='wide-resnet', type=str, help='model') 25 | parser.add_argument('--depth', default=28, type=int, help='depth of model') 26 | parser.add_argument('--widen_factor', default=10, type=int, help='width of model') 27 | parser.add_argument('--dropout', 
default=0.3, type=float, help='dropout_rate') 28 | parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [cifar10/cifar100]') 29 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 30 | parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model') 31 | args = parser.parse_args() 32 | 33 | # Hyper Parameter settings 34 | use_cuda = torch.cuda.is_available() 35 | best_acc = 0 36 | start_epoch, num_epochs, batch_size, optim_type = cf.start_epoch, cf.num_epochs, cf.batch_size, cf.optim_type 37 | 38 | # Data upload 39 | print('\n[Phase 1] : Data Preparation') 40 | transform_train = transforms.Compose([ 41 | transforms.RandomCrop(32, padding=4), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]), 45 | ]) # meanstd transformation 46 | 47 | transform_test = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]), 50 | ]) 51 | 52 | if(args.dataset == 'cifar10'): 53 | print("| Preparing CIFAR-10 dataset...") 54 | sys.stdout.write("| ") 55 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 56 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test) 57 | num_classes = 10 58 | elif(args.dataset == 'cifar100'): 59 | print("| Preparing CIFAR-100 dataset...") 60 | sys.stdout.write("| ") 61 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 62 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=False, transform=transform_test) 63 | num_classes = 100 64 | 65 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) 66 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, 
shuffle=False, num_workers=2) 67 | 68 | # Return network & file name 69 | def getNetwork(args): 70 | if (args.net_type == 'lenet'): 71 | net = LeNet(num_classes) 72 | file_name = 'lenet' 73 | elif (args.net_type == 'vggnet'): 74 | net = VGG(args.depth, num_classes) 75 | file_name = 'vgg-'+str(args.depth) 76 | elif (args.net_type == 'resnet'): 77 | net = ResNet(args.depth, num_classes) 78 | file_name = 'resnet-'+str(args.depth) 79 | elif (args.net_type == 'wide-resnet'): 80 | net = Wide_ResNet(args.depth, args.widen_factor, args.dropout, num_classes) 81 | file_name = 'wide-resnet-'+str(args.depth)+'x'+str(args.widen_factor) 82 | else: 83 | print('Error : Network should be either [LeNet / VGGNet / ResNet / Wide_ResNet]') 84 | sys.exit(1) # non-zero exit code signals the invalid --net_type 85 | 86 | return net, file_name 87 | 88 | # Test only option 89 | if (args.testOnly): 90 | print('\n[Test Phase] : Model setup') 91 | assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!' 92 | _, file_name = getNetwork(args) 93 | checkpoint = torch.load('./checkpoint/'+args.dataset+os.sep+file_name+'.t7') 94 | net = checkpoint['net'] 95 | 96 | if use_cuda: 97 | net.cuda() 98 | net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 99 | cudnn.benchmark = True 100 | 101 | net.eval() 102 | net.training = False 103 | test_loss = 0 104 | correct = 0 105 | total = 0 106 | 107 | with torch.no_grad(): 108 | for batch_idx, (inputs, targets) in enumerate(testloader): 109 | if use_cuda: 110 | inputs, targets = inputs.cuda(), targets.cuda() 111 | inputs, targets = Variable(inputs), Variable(targets) 112 | outputs = net(inputs) 113 | 114 | _, predicted = torch.max(outputs.data, 1) 115 | total += targets.size(0) 116 | correct += predicted.eq(targets.data).cpu().sum() 117 | 118 | acc = 100.*correct/total 119 | print("| Test Result\tAcc@1: %.2f%%" %(acc)) 120 | 121 | sys.exit(0) 122 | 123 | # Model 124 | print('\n[Phase 2] : Model setup') 125 | if args.resume: 126 | # Load checkpoint 127 | print('| 
Resuming from checkpoint...') 128 | assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!' 129 | _, file_name = getNetwork(args) 130 | checkpoint = torch.load('./checkpoint/'+args.dataset+os.sep+file_name+'.t7') 131 | net = checkpoint['net'] 132 | best_acc = checkpoint['acc'] 133 | start_epoch = checkpoint['epoch'] 134 | else: 135 | print('| Building net type [' + args.net_type + ']...') 136 | net, file_name = getNetwork(args) 137 | net.apply(conv_init) 138 | 139 | if use_cuda: 140 | net.cuda() 141 | net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count())) 142 | cudnn.benchmark = True 143 | 144 | criterion = nn.CrossEntropyLoss() 145 | 146 | # Training 147 | def train(epoch): 148 | net.train() 149 | net.training = True 150 | train_loss = 0 151 | correct = 0 152 | total = 0 153 | optimizer = optim.SGD(net.parameters(), lr=cf.learning_rate(args.lr, epoch), momentum=0.9, weight_decay=5e-4) 154 | 155 | print('\n=> Training Epoch #%d, LR=%.4f' %(epoch, cf.learning_rate(args.lr, epoch))) 156 | for batch_idx, (inputs, targets) in enumerate(trainloader): 157 | if use_cuda: 158 | inputs, targets = inputs.cuda(), targets.cuda() # GPU settings 159 | optimizer.zero_grad() 160 | inputs, targets = Variable(inputs), Variable(targets) 161 | outputs = net(inputs) # Forward Propagation 162 | loss = criterion(outputs, targets) # Loss 163 | loss.backward() # Backward Propagation 164 | optimizer.step() # Optimizer update 165 | 166 | train_loss += loss.item() 167 | _, predicted = torch.max(outputs.data, 1) 168 | total += targets.size(0) 169 | correct += predicted.eq(targets.data).cpu().sum() 170 | 171 | sys.stdout.write('\r') 172 | sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%' 173 | %(epoch, num_epochs, batch_idx+1, 174 | (len(trainset)//batch_size)+1, loss.item(), 100.*correct/total)) 175 | sys.stdout.flush() 176 | 177 | def test(epoch): 178 | global best_acc 179 | net.eval() 180 | net.training = False 181 | 
test_loss = 0 182 | correct = 0 183 | total = 0 184 | with torch.no_grad(): 185 | for batch_idx, (inputs, targets) in enumerate(testloader): 186 | if use_cuda: 187 | inputs, targets = inputs.cuda(), targets.cuda() 188 | inputs, targets = Variable(inputs), Variable(targets) 189 | outputs = net(inputs) 190 | loss = criterion(outputs, targets) 191 | 192 | test_loss += loss.item() 193 | _, predicted = torch.max(outputs.data, 1) 194 | total += targets.size(0) 195 | correct += predicted.eq(targets.data).cpu().sum() 196 | 197 | # Save checkpoint when best model 198 | acc = 100.*correct/total 199 | print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %(epoch, loss.item(), acc)) 200 | 201 | if acc > best_acc: 202 | print('| Saving Best model...\t\t\tTop1 = %.2f%%' %(acc)) 203 | state = { 204 | 'net':net.module if use_cuda else net, 205 | 'acc':acc, 206 | 'epoch':epoch, 207 | } 208 | if not os.path.isdir('checkpoint'): 209 | os.mkdir('checkpoint') 210 | save_point = './checkpoint/'+args.dataset+os.sep 211 | if not os.path.isdir(save_point): 212 | os.mkdir(save_point) 213 | torch.save(state, save_point+file_name+'.t7') 214 | best_acc = acc 215 | 216 | print('\n[Phase 3] : Training model') 217 | print('| Training Epochs = ' + str(num_epochs)) 218 | print('| Initial Learning Rate = ' + str(args.lr)) 219 | print('| Optimizer = ' + str(optim_type)) 220 | 221 | elapsed_time = 0 222 | for epoch in range(start_epoch, start_epoch+num_epochs): 223 | start_time = time.time() 224 | 225 | train(epoch) 226 | test(epoch) 227 | 228 | epoch_time = time.time() - start_time 229 | elapsed_time += epoch_time 230 | print('| Elapsed time : %d:%02d:%02d' %(cf.get_hms(elapsed_time))) 231 | 232 | print('\n[Phase 4] : Testing model') 233 | print('* Test results : Acc@1 = %.2f%%' %(best_acc)) 234 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.lenet import * 2 | from .vggnet import * 3 | from .resnet import * 4 | from .wide_resnet import * 5 | -------------------------------------------------------------------------------- /networks/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.nn.init as init 4 | def conv_init(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Conv') != -1: 7 | init.xavier_uniform(m.weight, gain=init.calculate_gain('relu')) # gain = sqrt(2) 8 | init.constant(m.bias, 0) 9 | 10 | class LeNet(nn.Module): 11 | def __init__(self, num_classes): 12 | super(LeNet, self).__init__() 13 | self.conv1 = nn.Conv2d(3, 6, 5) 14 | self.conv2 = nn.Conv2d(6, 16, 5) 15 | self.fc1 = nn.Linear(16*5*5, 120) 16 | self.fc2 = nn.Linear(120, 84) 17 | self.fc3 = nn.Linear(84, num_classes) 18 | 19 | def forward(self, x): 20 | out = F.relu(self.conv1(x)) 21 | out = F.max_pool2d(out, 2) 22 | out = F.relu(self.conv2(out)) 23 | out = F.max_pool2d(out, 2) 24 | out = out.view(out.size(0), -1) 25 | out = F.relu(self.fc1(out)) 26 | out = F.relu(self.fc2(out)) 27 | out = self.fc3(out) 28 | 29 | return out 30 | -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.autograd import Variable 6 | import sys 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 10 | 11 | def conv_init(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Conv') != -1: 14 | init.xavier_uniform(m.weight, gain=init.calculate_gain('relu')) # gain = sqrt(2) 15 | init.constant(m.bias, 0) 16 | 17 | def cfg(depth): 18 | depth_lst = [18, 34, 50, 101, 152] 19 | assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152" 20 | cf_dict = { 21 | '18': 
(BasicBlock, [2,2,2,2]), 22 | '34': (BasicBlock, [3,4,6,3]), 23 | '50': (Bottleneck, [3,4,6,3]), 24 | '101':(Bottleneck, [3,4,23,3]), 25 | '152':(Bottleneck, [3,8,36,3]), 26 | } 27 | 28 | return cf_dict[str(depth)] 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, in_planes, planes, stride=1): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(in_planes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | 40 | self.shortcut = nn.Sequential() 41 | if stride != 1 or in_planes != self.expansion * planes: 42 | self.shortcut = nn.Sequential( 43 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True), 44 | nn.BatchNorm2d(self.expansion*planes) 45 | ) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | out += self.shortcut(x) 51 | out = F.relu(out) 52 | 53 | return out 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, in_planes, planes, stride=1): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=True) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True) 65 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 66 | 67 | self.shortcut = nn.Sequential() 68 | if stride != 1 or in_planes != self.expansion*planes: 69 | self.shortcut = nn.Sequential( 70 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True), 71 | nn.BatchNorm2d(self.expansion*planes) 72 | ) 73 | 74 | def forward(self, x): 75 | out = F.relu(self.bn1(self.conv1(x))) 76 | out = F.relu(self.bn2(self.conv2(out))) 77 | out = self.bn3(self.conv3(out)) 78 | out += 
self.shortcut(x) 79 | out = F.relu(out) 80 | 81 | return out 82 | 83 | class ResNet(nn.Module): 84 | def __init__(self, depth, num_classes): 85 | super(ResNet, self).__init__() 86 | self.in_planes = 16 87 | 88 | block, num_blocks = cfg(depth) 89 | 90 | self.conv1 = conv3x3(3,16) 91 | self.bn1 = nn.BatchNorm2d(16) 92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 95 | self.linear = nn.Linear(64*block.expansion, num_classes) 96 | 97 | def _make_layer(self, block, planes, num_blocks, stride): 98 | strides = [stride] + [1]*(num_blocks-1) 99 | layers = [] 100 | 101 | for stride in strides: 102 | layers.append(block(self.in_planes, planes, stride)) 103 | self.in_planes = planes * block.expansion 104 | 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | out = F.relu(self.bn1(self.conv1(x))) 109 | out = self.layer1(out) 110 | out = self.layer2(out) 111 | out = self.layer3(out) 112 | out = F.avg_pool2d(out, 8) 113 | out = out.view(out.size(0), -1) 114 | out = self.linear(out) 115 | 116 | return out 117 | 118 | if __name__ == '__main__': 119 | net=ResNet(50, 10) 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | -------------------------------------------------------------------------------- /networks/vggnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.init as init 5 | def conv_init(m): 6 | classname = m.__class__.__name__ 7 | if classname.find('Conv') != -1: 8 | init.xavier_uniform_(m.weight, gain=2 ** 0.5) 9 | init.constant_(m.bias, 0) 10 | 11 | def cfg(depth): 12 | depth_lst = [11, 13, 16, 19] 13 | assert (depth in depth_lst), "Error : VGGnet depth should be either 11, 13, 16, 19" 14 | cf_dict = { 15 | '11': [ 16 | 64, 'mp', 17 | 128, 'mp', 18 | 
256, 256, 'mp', 19 | 512, 512, 'mp', 20 | 512, 512, 'mp'], 21 | '13': [ 22 | 64, 64, 'mp', 23 | 128, 128, 'mp', 24 | 256, 256, 'mp', 25 | 512, 512, 'mp', 26 | 512, 512, 'mp' 27 | ], 28 | '16': [ 29 | 64, 64, 'mp', 30 | 128, 128, 'mp', 31 | 256, 256, 256, 'mp', 32 | 512, 512, 512, 'mp', 33 | 512, 512, 512, 'mp' 34 | ], 35 | '19': [ 36 | 64, 64, 'mp', 37 | 128, 128, 'mp', 38 | 256, 256, 256, 256, 'mp', 39 | 512, 512, 512, 512, 'mp', 40 | 512, 512, 512, 512, 'mp' 41 | ], 42 | } 43 | 44 | return cf_dict[str(depth)] 45 | 46 | def conv3x3(in_planes, out_planes, stride=1): 47 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 48 | 49 | class VGG(nn.Module): 50 | def __init__(self, depth, num_classes): 51 | super(VGG, self).__init__() 52 | self.features = self._make_layers(cfg(depth)) 53 | self.linear = nn.Linear(512, num_classes) 54 | 55 | def forward(self, x): 56 | out = self.features(x) 57 | out = out.view(out.size(0), -1) 58 | out = self.linear(out) 59 | 60 | return out 61 | 62 | def _make_layers(self, cfg): 63 | layers = [] 64 | in_planes = 3 65 | 66 | for x in cfg: 67 | if x == 'mp': 68 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 69 | else: 70 | layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)] 71 | in_planes = x 72 | 73 | # After cfg convolution 74 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 75 | return nn.Sequential(*layers) 76 | 77 | if __name__ == "__main__": 78 | net = VGG(16, 10) 79 | y = net(Variable(torch.randn(1,3,32,32))) 80 | print(y.size()) 81 | -------------------------------------------------------------------------------- /networks/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import sys 8 | import numpy as np 9 | 10 | def conv3x3(in_planes, out_planes, 
stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find('BatchNorm') != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | class wide_basic(nn.Module): 23 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 24 | super(wide_basic, self).__init__() 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 27 | self.dropout = nn.Dropout(p=dropout_rate) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 35 | ) 36 | 37 | def forward(self, x): 38 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += self.shortcut(x) 41 | 42 | return out 43 | 44 | class Wide_ResNet(nn.Module): 45 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 46 | super(Wide_ResNet, self).__init__() 47 | self.in_planes = 16 48 | 49 | assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 50 | n = (depth-4)//6 51 | k = widen_factor 52 | 53 | print('| Wide-Resnet %dx%d' %(depth, k)) 54 | nStages = [16, 16*k, 32*k, 64*k] 55 | 56 | self.conv1 = conv3x3(3,nStages[0]) 57 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 58 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 59 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 60 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 61 | 
self.linear = nn.Linear(nStages[3], num_classes) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1]*(int(num_blocks)-1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = F.avg_pool2d(out, 8) 80 | out = out.view(out.size(0), -1) 81 | out = self.linear(out) 82 | 83 | return out 84 | 85 | if __name__ == '__main__': 86 | net=Wide_ResNet(28, 10, 0.3, 10) 87 | y = net(Variable(torch.randn(1,3,32,32))) 88 | 89 | print(y.size()) 90 | -------------------------------------------------------------------------------- /scripts/cifar100_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export netType='wide-resnet' 3 | export depth=28 4 | export width=10 5 | export dataset='cifar100' 6 | 7 | python main.py \ 8 | --lr 0.1 \ 9 | --net_type ${netType} \ 10 | --depth ${depth} \ 11 | --widen_factor ${width} \ 12 | --dropout 0 \ 13 | --dataset ${dataset} 14 | -------------------------------------------------------------------------------- /scripts/cifar10_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export netType='wide-resnet' 3 | export depth=28 4 | export width=10 5 | export dataset='cifar10' 6 | 7 | python main.py \ 8 | --lr 0.1 \ 9 | --net_type ${netType} \ 10 | --depth ${depth} \ 11 | --widen_factor ${width} \ 12 | --dropout 0 \ 13 | --dataset ${dataset} 14 | -------------------------------------------------------------------------------- /scripts/resnet_cifar100_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 
export netType='resnet' 3 | export depth=18 4 | export dataset='cifar100' 5 | 6 | python main.py \ 7 | --lr 0.1 \ 8 | --net_type ${netType} \ 9 | --depth ${depth} \ 10 | --dropout 0 \ 11 | --dataset ${dataset} 12 | --------------------------------------------------------------------------------