├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── SERVER.md
├── config.py
├── imgs
├── cifar100_image.png
├── cifar10_image.png
├── img_356.lua
├── pytorch.png
└── svhn_image.png
├── main.py
├── networks
├── __init__.py
├── lenet.py
├── resnet.py
├── vggnet.py
└── wide_resnet.py
└── scripts
├── cifar100_train.sh
├── cifar10_train.sh
└── resnet_cifar100_train.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # Torch
104 | *.t7
105 |
106 | # Data
107 | data/
108 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # INSTALLATION GUIDE
2 | This is the Installation guide for the overall repository.
3 |
4 | ## Install NVIDIA-driver
5 | GIGABYTE based BIOS setting
6 | - First, DO NOT PLUG IN your GPU until the driver is set up.
7 | - Internal Graphic : Auto > Enable
8 | - Display Priority : PCIe > Internal
9 |
10 | NVIDIA driver download .run file : [click here](http://www.nvidia.co.kr/Download/index.aspx)
11 |
12 | If you click download from the above site, you will get a .run file format for installing drivers.
13 |
14 | ### 1. Stop display manager
15 |
16 | Before you run the .run file, you first need to stop your Xserver display manager.
17 |
18 | Press [Ctrl] + [Alt] + [F1] to switch to a text console, then run the command below.
19 |
20 | ```bash
21 | $ service --status-all | grep dm
22 |
23 | (Result) [+] [:dm]
24 | ```
25 |
26 | The part described as [:dm] is your display manager.
27 |
28 | Replace [:dm] in the command below with the display manager name printed by the command above (e.g. lightdm).
29 |
30 | ```bash
31 | $ sudo service [:dm] stop
32 |
33 | (Result) * Stopping Light Display Manager [:dm]
34 | ```
35 |
36 | ### 2. Run the nvidia-driver installer
37 |
38 | Run the command below and answer 'Yes' to every prompt the installer shows.
39 |
40 | ```bash
41 | $ sudo sh ./NVIDIA-Linux-x86_64-375.20.run
42 | ```
43 |
44 | After a successful installation, running the command below should print output similar to the following.
45 |
46 | ```bash
47 | $ nvidia-smi
48 | ```
49 |
50 | ```bash
51 | +-----------------------------------------------------------------------------+
52 | | NVIDIA-SMI 375.20 Driver Version: 375.20 |
53 | |-------------------------------+----------------------+----------------------+
54 | | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
55 | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
56 | |===============================+======================+======================|
57 | | 0 TITAN X (Pascal) Off | 0000:4B:00.0 Off | N/A |
58 | | 67% 86C P2 249W / 250W | 5026MiB / 12221MiB | 82% Default |
59 | +-------------------------------+----------------------+----------------------+
60 | | 1 TITAN X (Pascal) Off | 0000:4C:00.0 On | N/A |
61 | | 88% 90C P2 225W / 250W | 3842MiB / 12213MiB | 78% Default |
62 | +-------------------------------+----------------------+----------------------+
63 |
64 | +-----------------------------------------------------------------------------+
65 | | Processes: GPU Memory |
66 | | GPU PID Type Process name Usage |
67 | |=============================================================================|
68 | | No running processes found |
69 | +-----------------------------------------------------------------------------+
70 |
71 | ```
72 |
73 | ### 3. Reboot the system
74 |
75 | ```bash
76 | $ sudo reboot
77 | ```
78 |
79 | ## Install CUDA toolkit
80 |
81 | You can skip the steps above and let the CUDA installer install the driver for you.
82 |
83 | Check the details below.
84 |
85 | ### * Download the CUDA download .run file from the CUDA download page
86 |
87 | CUDA download page : [click here](https://developer.nvidia.com/cuda-downloads)
88 |
89 | Before executing the file, stop the display manager by following the description above.
90 |
91 | ```bash
92 | $ sudo sh ./cuda_8.0.44_linux.run
93 | ```
94 |
95 | ### * Link your CUDA in .bashrc
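96 | (Optional) If you do not have a text editor set up yet, install vim and a vimrc configuration first: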
97 | ```bash
98 | $ sudo apt-get install vim
99 |
100 | $ git clone https://github.com/amix/vimrc.git ~/.vim_runtime
101 |
102 | # Awesome version
103 | $ sh ~/.vim_runtime/install_awesome_vimrc.sh
104 |
105 | # Basic version
106 | $ sh ~/.vim_runtime/install_basic_vimrc.sh
107 | ```
108 |
109 | Open your ~/.bashrc file.
110 |
111 | ```bash
112 | vi ~/.bashrc
113 |
114 | # Press Shift + G, Add the lines on the bottom
115 |
116 | export PATH=$PATH:/usr/local/cuda/bin
117 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
118 | export CUDA_HOME=/usr/local/cuda
119 | ```
120 |
121 | To check if the CUDA toolkit is successfully installed, type the line below.
122 |
123 | ```bash
124 | $ nvcc --version
125 |
126 | * (Result)
127 | nvcc: NVIDIA (R) Cuda compiler driver
128 | Copyright (c) 2005-2016 NVIDIA Corporation
129 | Built on Sun_Sep__4_22:14:01_CDT_2016
130 | Cuda compilation tools, release 8.0, V8.0.44
131 | ```
132 |
133 | ## Install cuDNN library kit
134 |
135 | cuDNN download page : [click here](https://developer.nvidia.com/rdp/cudnn-download)
136 |
137 | (Membership is required, just sign in!)
138 |
139 | Download cuDNN v5.1 for CUDA 8.0 (the archive used below).
140 |
141 | ```bash
142 | $ cd
143 | $ tar -zxvf ./cudnn-8.0-linux-x64-v5.1.tgz
144 | $ sudo cp cuda/include/cudnn.h /usr/local/cuda/include/
145 | $ sudo cp cuda/lib64/libcudnn* /usr/local/cuda/lib64/
146 | $ sudo chmod a+r /usr/local/cuda/include/cudnn.h
147 | $ sudo chmod a+r /usr/local/cuda/lib64/libcudnn*
148 | ```
149 |
150 | ## Install Tensorflow
151 |
152 | Tensorflow install page : [click here](https://www.tensorflow.org/get_started/os_setup)
153 |
154 | ```bash
155 | $ sudo apt-get install python-pip python-dev
156 | $ pip install --upgrade pip
157 | $ pip install tensorflow-gpu
158 | ```
159 |
160 | ## Install Torch
161 |
162 | Torch install page : [click here](http://torch.ch/docs/getting-started.html)
163 |
164 | ```bash
165 | $ git clone https://github.com/torch/distro.git ~/torch --recursive
166 | $ cd ~/torch; bash install-deps;
167 | $ ./install.sh
168 | $ source ~/.bashrc
169 | ```
170 |
171 | ## Now, Enjoy!
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Bumsoo Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | Best CIFAR-10, CIFAR-100 results with wide-residual networks using PyTorch
4 |
5 | Pytorch Implementation of Sergey Zagoruyko's [Wide Residual Networks](https://arxiv.org/pdf/1605.07146v2.pdf)
6 |
7 | For Torch implementations, see [here](https://github.com/meliketoy/wide-residual-network).
8 |
9 | ## Requirements
10 | See the [installation instruction](INSTALL.md) for a step-by-step installation guide.
11 | See the [server instruction](SERVER.md) for server setup.
12 | - Install [cuda-8.0](https://developer.nvidia.com/cuda-downloads)
13 | - Install [cudnn v5.1](https://developer.nvidia.com/cudnn)
14 | - Install [PyTorch](https://pytorch.org) 0.4 or later (the code uses `torch.no_grad()` and `Tensor.item()`) and clone the repository.
15 | ```bash
16 | pip install torch
17 | pip install torchvision
18 | git clone https://github.com/meliketoy/wide-resnet.pytorch
19 | ```
20 |
21 | ## How to run
22 | After you have cloned the repository, you can train on either CIFAR-10 or CIFAR-100 by running the command below. Add `--resume` to continue from the saved checkpoint, or `--testOnly` to only evaluate it.
23 | ```bash
24 | python main.py --lr 0.1 --net_type [lenet/vggnet/resnet/wide-resnet] --depth 28 --widen_factor 10 --dropout 0.3 --dataset [cifar10/cifar100]
25 | ```
26 |
27 | ## Implementation Details
28 |
29 | | epoch | learning rate | weight decay | Optimizer | Momentum | Nesterov |
30 | |:---------:|:-------------:|:-------------:|:---------:|:--------:|:--------:|
31 | | 0 ~ 60 | 0.1 | 0.0005 | Momentum | 0.9 | true |
32 | | 61 ~ 120 | 0.02 | 0.0005 | Momentum | 0.9 | true |
33 | | 121 ~ 160 | 0.004 | 0.0005 | Momentum | 0.9 | true |
34 | | 161 ~ 200 | 0.0008 | 0.0005 | Momentum | 0.9 | true |
35 |
36 | ## CIFAR-10 Results
37 |
38 | 
39 |
40 | Below is the result of the test set accuracy for **CIFAR-10 dataset** training.
41 |
42 | **Accuracy is the average of 5 runs**
43 |
44 | | network | dropout | preprocess | GPU:0 | GPU:1 | per epoch | accuracy(%) |
45 | |:-----------------:|:-------:|:----------:|:-----:|:-----:|:------------:|:-----------:|
46 | | wide-resnet 28x10 | 0 | ZCA | 5.90G | - | 2 min 03 sec | 95.83 |
47 | | wide-resnet 28x10 | 0 | meanstd | 5.90G | - | 2 min 03 sec | 96.21 |
48 | | wide-resnet 28x10 | 0.3 | meanstd | 5.90G | - | 2 min 03 sec | 96.27 |
49 | | wide-resnet 28x20 | 0.3 | meanstd | 8.13G | 6.93G | 4 min 10 sec | **96.55** |
50 | | wide-resnet 40x10 | 0.3 | meanstd | 8.08G | - | 3 min 13 sec | 96.31 |
51 | | wide-resnet 40x14 | 0.3 | meanstd | 7.37G | 6.46G | 3 min 23 sec | 96.34 |
52 |
53 | ## CIFAR-100 Results
54 |
55 | 
56 |
57 | Below is the result of the test set accuracy for **CIFAR-100 dataset** training.
58 |
59 | **Accuracy is the average of 5 runs**
60 |
61 | | network | dropout | preprocess | GPU:0 | GPU:1 | per epoch | Top1 acc(%)| Top5 acc(%) |
62 | |:-----------------:|:-------:|:-----------:|:-----:|:-----:|:------------:|:----------:|:-----------:|
63 | | wide-resnet 28x10 | 0 | ZCA | 5.90G | - | 2 min 03 sec | 80.07 | 95.02 |
64 | | wide-resnet 28x10 | 0 | meanstd | 5.90G | - | 2 min 03 sec | 81.02 | 95.41 |
65 | | wide-resnet 28x10 | 0.3 | meanstd | 5.90G | - | 2 min 03 sec | 81.49 | 95.62 |
66 | | wide-resnet 28x20 | 0.3 | meanstd | 8.13G | 6.93G | 4 min 05 sec | **82.45** | **96.11** |
67 | | wide-resnet 40x10 | 0.3 | meanstd | 8.93G | - | 3 min 06 sec | 81.42 | 95.63 |
68 | | wide-resnet 40x14 | 0.3 | meanstd | 7.39G | 6.46G | 3 min 23 sec | 81.87 | 95.51 |
69 |
--------------------------------------------------------------------------------
/SERVER.md:
--------------------------------------------------------------------------------
1 | # SERVER MANAGEMENT
2 | This is the management guide for server installation.
3 |
4 | ## Welcome message
5 | Install figlet
6 | ```bash
7 | $ sudo apt-get install figlet
8 | ```
9 |
10 | ```bash
11 | $ sudo vi /etc/bash.bashrc
12 |
13 | # Press [Shift] + [G] and write the code on the bottom
14 | clear
15 | printf "Welcome to Ubuntu 16.04.5 LTS (GNU/Linux-Mint-18 x86_64)\n"
16 | printf "This is the server for the wide-residual-network implementation.\n\n"
17 | printf " * Documentation: https://github.com/meliketoy/wide-residual-network\n\n"
18 | printf "##############################################################\n"
19 | figlet -f slant "Bumsoo Kim"
20 | printf "\n\n"
21 | printf " Data Mining & Information System Lab\n"
22 | printf " GPU Computing machine : bumsoo@163.152.163.10\n\n"
23 | printf " Administrator : Bumsoo Kim\n"
24 | printf " Please read the document\n"
25 | printf " https://github.com/meliketoy/wide-residual-network/README.md\n"
26 | printf "##############################################################\n\n"
27 | ```
28 |
29 | ## Remote Server control
30 |
31 | ### 1. SCP call
32 |
33 | ```bash
34 |
35 | # Upload your local file to server
36 | $ scp -P 22 /file [:username]@[:server_ip]:
37 |
38 | # Download a server file to your local machine
39 | $ scp -P 22 [:username]@[:server_ip]:/file .
40 |
41 | # Upload your local directory to server
42 | $ scp -P 22 -r /file [:username]@[:server_ip]:
43 |
44 | # Download a server directory to your local machine
45 | $ scp -P 22 -r [:username]@[:server_ip]:/file .
46 |
47 | ```
48 |
49 | ### 2. Save sessions by name
50 | ```bash
51 | $ vi ~/.bashrc
52 |
53 | # Press [Shift] + [G] and enter the function on the bottom.
54 |
55 | function server_func() {
56 | echo -n "[Enter the name of server]: "
57 | read server_name
58 |
59 | # echo "Logging in to server $server_name ..."
60 | if [ "$server_name" == "[:servername]" ]; then
61 | ssh [:usr]@[:ip].[:ip].[:ip].[:ip]
62 | fi
63 | }
64 | alias dmis_remote=server_func
65 | ```
66 |
67 | ## Github control
68 |
69 | ```bash
70 | $ vi ~/.netrc    # afterwards restrict access: chmod 600 ~/.netrc
71 |
72 | machine github.com
73 | login [:username]
74 | password [:personal_access_token]
75 | ```
76 |
77 | ### Jupyter notebook configuration
78 |
79 | For jupyter notebook configuration, type in the command line below.
80 | ```bash
81 | $ jupyter notebook --generate-config
82 |
83 | * Result :
84 | Writing default config to: /home/[:username]/.jupyter/jupyter_notebook_config.py
85 |
86 | $ vi ~/.jupyter/jupyter_notebook_config.py
87 | ```
88 |
89 | Press [Esc], then type /ip to search for the IP setting. You will find the lines below:
90 | ``` bash
91 | ## The IP address the notebook server will listen on.
92 | #c.NotebookApp.ip = 'localhost'
93 | ```
94 |
95 | Remove the leading '#' and set the address the notebook server should listen on:
96 | ```bash
97 | c.NotebookApp.ip = '163.152.163.112' # the ip address for your server
98 | ```
99 |
100 | Press [Esc], then type /port to search for the port setting. You will find the lines below:
101 | ```bash
102 | ## The port the notebook server will listen on.
103 | #c.NotebookApp.port = 8888
104 | ```
105 |
106 | Remove the leading '#' and set whatever port number you want:
107 | ```bash
108 | c.NotebookApp.port = 9999
109 |
110 | ```
111 |
112 | Now, Enjoy!
113 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | ############### Pytorch CIFAR configuration file ###############
2 | import math
3 |
# ---- Training schedule defaults (read by main.py) ----
start_epoch = 1        # first epoch index; epochs are numbered from 1
num_epochs = 200       # total number of training epochs
batch_size = 128       # mini-batch size of the training DataLoader
optim_type = 'SGD'     # optimizer name; main.py only prints it

# ---- Per-channel (R, G, B) normalisation statistics, keyed by dataset ----
mean = dict(
    cifar10=(0.4914, 0.4822, 0.4465),
    cifar100=(0.5071, 0.4867, 0.4408),
)

# NOTE(review): the CIFAR-10 std values below differ from the commonly quoted
# dataset-wide pixel std (~0.247, 0.243, 0.261); kept as-is so results and
# checkpoints stay reproducible -- confirm before changing.
std = dict(
    cifar10=(0.2023, 0.1994, 0.2010),
    cifar100=(0.2675, 0.2565, 0.2761),
)

# CIFAR-10 class names only (not applicable to CIFAR-100).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
21 |
def learning_rate(init, epoch, milestones=(60, 120, 160), gamma=0.2):
    """Return the step-decayed learning rate for a (1-based) epoch.

    The rate starts at ``init`` and is multiplied by ``gamma`` once for every
    milestone that ``epoch`` has strictly passed.  With the defaults this is
    ``init`` for epochs 1-60, ``init*0.2`` for 61-120, ``init*0.04`` for
    121-160 and ``init*0.008`` afterwards (0.1 / 0.02 / 0.004 / 0.0008 for
    ``init=0.1``, as listed in README.md).

    Args:
        init: initial learning rate.
        epoch: current epoch number.
        milestones: epochs after which the rate is decayed.
        gamma: multiplicative decay factor applied at each milestone.

    Returns:
        The learning rate for ``epoch`` as a float.
    """
    passed = sum(1 for milestone in milestones if epoch > milestone)
    return init * math.pow(gamma, passed)
32 |
def get_hms(seconds):
    """Split a duration in seconds into an (hours, minutes, seconds) tuple.

    Accepts ints or floats; any fractional part stays in the seconds field.
    """
    total_minutes, secs = divmod(seconds, 60)
    return divmod(total_minutes, 60) + (secs,)
38 |
--------------------------------------------------------------------------------
/imgs/cifar100_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/cifar100_image.png
--------------------------------------------------------------------------------
/imgs/cifar10_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/cifar10_image.png
--------------------------------------------------------------------------------
/imgs/img_356.lua:
--------------------------------------------------------------------------------
-- Resize svhn_image.png in place so that it has the same height and width
-- as cifar10_image.png (both are README illustrations).
require 'image'

local reference_path = 'cifar10_image.png'
local new_path = 'svhn_image.png'

local img = image.load(reference_path)
local new_img = image.load(new_path)

print(img:size())
print(new_img:size())

-- image.load returns a (channels x height x width) tensor, whereas
-- image.scale takes (src, width, height, mode): size(3) is the width and
-- size(2) is the height.
local trans_img = image.scale(new_img, img:size(3), img:size(2), 'bicubic')
image.save(new_path, trans_img)  -- overwrites the source image

print("Image Convert Completed!")
print(img:size())
print(trans_img:size())
17 |
--------------------------------------------------------------------------------
/imgs/pytorch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/pytorch.png
--------------------------------------------------------------------------------
/imgs/svhn_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bmsookim/wide-resnet.pytorch/292b3ede0651e349dd566f9c23408aa572f1bd92/imgs/svhn_image.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | import torch.backends.cudnn as cudnn
8 | import config as cf
9 |
10 | import torchvision
11 | import torchvision.transforms as transforms
12 |
13 | import os
14 | import sys
15 | import time
16 | import argparse
17 | import datetime
18 |
19 | from networks import *
20 | from torch.autograd import Variable
21 |
parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning_rate')
# choices= makes argparse reject typos immediately with a usage message,
# instead of failing later (KeyError in the normalisation lookup, or the
# unknown-network branch of getNetwork).
parser.add_argument('--net_type', default='wide-resnet', type=str,
                    choices=['lenet', 'vggnet', 'resnet', 'wide-resnet'], help='model')
parser.add_argument('--depth', default=28, type=int, help='depth of model')
parser.add_argument('--widen_factor', default=10, type=int, help='width of model')
parser.add_argument('--dropout', default=0.3, type=float, help='dropout_rate')
parser.add_argument('--dataset', default='cifar10', type=str,
                    choices=['cifar10', 'cifar100'], help='dataset = [cifar10/cifar100]')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
args = parser.parse_args()

# Hyper Parameter settings
use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy (%) seen so far; updated by test()
start_epoch, num_epochs, batch_size, optim_type = cf.start_epoch, cf.num_epochs, cf.batch_size, cf.optim_type
37 |
# ---------------------------------------------------------------------------
# Data upload: augmentation pipelines, CIFAR datasets and their loaders.
# ---------------------------------------------------------------------------
print('\n[Phase 1] : Data Preparation')

# Training pipeline: random 32x32 crop from a 4-pixel padded image, random
# horizontal flip, then per-channel mean/std normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
])

# Evaluation pipeline: no augmentation, same normalisation statistics.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
])

if args.dataset == 'cifar10':
    print("| Preparing CIFAR-10 dataset...")
    dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
elif args.dataset == 'cifar100':
    print("| Preparing CIFAR-100 dataset...")
    dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100

sys.stdout.write("| ")
# Only the training split triggers a download; the test split lives in the
# same archive under ./data.
trainset = dataset_cls(root='./data', train=True, download=True, transform=transform_train)
testset = dataset_cls(root='./data', train=False, download=False, transform=transform_test)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)
67 |
# Return network & file name
def getNetwork(args):
    """Build the model selected by ``args`` and the stem of its checkpoint file.

    Reads ``args.net_type`` ('lenet', 'vggnet', 'resnet' or 'wide-resnet')
    and, where the architecture uses them, ``args.depth``,
    ``args.widen_factor`` and ``args.dropout``.  The number of output classes
    comes from the module-level ``num_classes`` set during data preparation.

    Returns:
        (net, file_name): the freshly constructed model and the checkpoint
        stem, e.g. 'wide-resnet-28x10'.

    Exits the process with status 1 when ``args.net_type`` is unknown.
    """
    if args.net_type == 'lenet':
        net = LeNet(num_classes)
        file_name = 'lenet'
    elif args.net_type == 'vggnet':
        net = VGG(args.depth, num_classes)
        file_name = 'vgg-'+str(args.depth)
    elif args.net_type == 'resnet':
        net = ResNet(args.depth, num_classes)
        file_name = 'resnet-'+str(args.depth)
    elif args.net_type == 'wide-resnet':
        net = Wide_ResNet(args.depth, args.widen_factor, args.dropout, num_classes)
        file_name = 'wide-resnet-'+str(args.depth)+'x'+str(args.widen_factor)
    else:
        print('Error : Network should be either [lenet / vggnet / resnet / wide-resnet]')
        # Non-zero status so shell scripts can detect the failure.
        sys.exit(1)

    return net, file_name
87 |
# Test only option: evaluate the saved best checkpoint and exit.
if args.testOnly:
    print('\n[Test Phase] : Model setup')
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
    _, file_name = getNetwork(args)
    # The checkpoint pickles the whole nn.Module, so loading it runs arbitrary
    # unpickling code: only load checkpoints you trust.  map_location lets a
    # checkpoint written on a GPU machine be evaluated on a CPU-only one.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True),
    # which refuses to unpickle a full module; those versions need
    # weights_only=False here.
    checkpoint = torch.load('./checkpoint/'+args.dataset+os.sep+file_name+'.t7',
                            map_location=None if use_cuda else 'cpu')
    net = checkpoint['net']

    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True

    net.eval()  # also sets .training = False on every submodule
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)

            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100.*correct/total
    print("| Test Result\tAcc@1: %.2f%%" %(acc))

    sys.exit(0)
122 |
# Model
print('\n[Phase 2] : Model setup')
if args.resume:
    # Load checkpoint
    print('| Resuming from checkpoint...')
    assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
    _, file_name = getNetwork(args)
    # The checkpoint pickles the whole model: only load files you trust.
    # map_location allows resuming a GPU-saved checkpoint on a CPU-only machine.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True),
    # which refuses to unpickle a full nn.Module; such versions need
    # weights_only=False.
    checkpoint = torch.load('./checkpoint/'+args.dataset+os.sep+file_name+'.t7',
                            map_location=None if use_cuda else 'cpu')
    net = checkpoint['net']
    # float() also accepts checkpoints whose 'acc' was stored as a 0-dim tensor.
    best_acc = float(checkpoint['acc'])
    # The saved epoch was already completed; continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
else:
    print('| Building net type [' + args.net_type + ']...')
    net, file_name = getNetwork(args)
    net.apply(conv_init)

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
145 |
# Training
def train(epoch):
    """Run one training epoch over ``trainloader`` and print live progress.

    Builds an SGD optimizer (momentum 0.9, weight decay 5e-4) with the rate
    ``cf.learning_rate`` prescribes for ``epoch``, then performs one
    forward/backward/step per mini-batch, printing the batch loss and the
    running top-1 accuracy.
    """
    net.train()  # also sets .training = True on every submodule
    correct = 0
    total = 0
    lr = cf.learning_rate(args.lr, epoch)
    # NOTE(review): the optimizer is rebuilt every epoch, so SGD momentum
    # buffers reset at each epoch boundary; README.md also lists Nesterov
    # momentum, which is not enabled here.  Confirm whether either is intended.
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    num_batches = len(trainloader)  # exact count, including a partial last batch

    print('\n=> Training Epoch #%d, LR=%.4f' %(epoch, lr))
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda() # GPU settings
        optimizer.zero_grad()
        outputs = net(inputs)               # Forward Propagation
        loss = criterion(outputs, targets)  # Loss
        loss.backward()                     # Backward Propagation
        optimizer.step()                    # Optimizer update

        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum().item()

        sys.stdout.write('\r')
        sys.stdout.write('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%'
                %(epoch, num_epochs, batch_idx+1,
                    num_batches, loss.item(), 100.*correct/total))
        sys.stdout.flush()
176 |
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on a new best accuracy.

    Prints the cross-entropy averaged over all test batches together with the
    top-1 accuracy (%).  When that accuracy exceeds the module-level
    ``best_acc``, the model, the accuracy and ``epoch`` are saved to
    ./checkpoint/<dataset>/<file_name>.t7 and ``best_acc`` is updated.
    """
    global best_acc
    net.eval()  # also sets .training = False on every submodule
    test_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100.*correct/total
    # Report the mean loss over all test batches, not just the last batch.
    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %(epoch, test_loss/num_batches, acc))

    # Save checkpoint when best model
    if acc > best_acc:
        print('| Saving Best model...\t\t\tTop1 = %.2f%%' %(acc))
        state = {
            # Unwrap DataParallel (applied when use_cuda) before saving.
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        save_point = './checkpoint/'+args.dataset+os.sep
        if not os.path.isdir(save_point):
            os.mkdir(save_point)
        torch.save(state, save_point+file_name+'.t7')
        best_acc = acc
215 |
print('\n[Phase 3] : Training model')
print('| Training Epochs = ' + str(num_epochs))
print('| Initial Learning Rate = ' + str(args.lr))
print('| Optimizer = ' + str(optim_type))

elapsed_time = 0
# Epochs run from start_epoch up to num_epochs inclusive (1-based, matching
# the learning-rate schedule in config.py).  A fresh run covers 1..num_epochs;
# a resumed run finishes the same schedule instead of adding num_epochs more.
for epoch in range(start_epoch, num_epochs + 1):
    start_time = time.time()

    train(epoch)
    test(epoch)

    epoch_time = time.time() - start_time
    elapsed_time += epoch_time
    print('| Elapsed time : %d:%02d:%02d' %(cf.get_hms(elapsed_time)))

print('\n[Phase 4] : Testing model')
print('* Test results : Acc@1 = %.2f%%' %(best_acc))
234 |
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .lenet import *
2 | from .vggnet import *
3 | from .resnet import *
4 | from .wide_resnet import *
5 |
--------------------------------------------------------------------------------
/networks/lenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
def conv_init(m):
    """Xavier-initialise convolution weights and zero their biases.

    Meant for ``nn.Module.apply``: modules whose class name contains 'Conv'
    get ``xavier_uniform_`` weights with gain sqrt(2) and a zero bias (when
    they have one); all other modules are left untouched.
    """
    # Local imports: this module has no top-level torch.nn.init import.
    import math
    import torch.nn.init as init

    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=math.sqrt(2))
        if m.bias is not None:  # convs built with bias=False have no bias
            init.constant_(m.bias, 0)
9 |
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 conv + 2x2 max-pool stages, then three FC layers.

    ``fc1`` expects a 16x5x5 feature map, which is what 32x32 inputs produce
    (32 -> 28 -> 14 -> 10 -> 5).
    """
    def __init__(self, num_classes):
        super(LeNet, self).__init__()
        # Feature extractor: 3 -> 6 -> 16 channels, unpadded 5x5 kernels.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head.
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        features = x
        for conv in (self.conv1, self.conv2):
            features = F.max_pool2d(F.relu(conv(features)), 2)

        hidden = features.view(features.size(0), -1)
        for fc in (self.fc1, self.fc2):
            hidden = F.relu(fc(hidden))

        return self.fc3(hidden)
30 |
--------------------------------------------------------------------------------
/networks/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torch.autograd import Variable
6 | import sys
7 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and a bias term.

    With stride 1 the spatial size is preserved; with stride 2 it is halved.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
    return layer
10 |
def conv_init(m):
    """Xavier-initialise convolution weights and zero their biases.

    Meant for ``nn.Module.apply``: modules whose class name contains 'Conv'
    get ``xavier_uniform_`` weights with gain sqrt(2) and a zero bias (when
    they have one); all other modules are left untouched.
    """
    # Local imports: this module has no top-level torch.nn.init / math import.
    import math
    import torch.nn.init as init

    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=math.sqrt(2))
        if m.bias is not None:  # convs built with bias=False have no bias
            init.constant_(m.bias, 0)
16 |
def cfg(depth):
    """Map a ResNet depth to its (block class, blocks-per-stage list) layout.

    Supported depths are 18, 34, 50, 101 and 152.  ``ResNet`` in this module
    builds only three stages, so it reads the first three list entries.

    Raises:
        AssertionError: if ``depth`` is not one of the supported values.
    """
    supported_depths = (18, 34, 50, 101, 152)
    assert depth in supported_depths, "Error : Resnet depth should be either 18, 34, 50, 101, 152"

    layouts = {
        '18':  (BasicBlock, [2, 2, 2, 2]),
        '34':  (BasicBlock, [3, 4, 6, 3]),
        '50':  (Bottleneck, [3, 4, 6, 3]),
        '101': (Bottleneck, [3, 4, 23, 3]),
        '152': (Bottleneck, [3, 8, 36, 3]),
    }
    return layouts[str(depth)]
29 |
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BN layers (channel expansion 1).

    If the stride or the channel count changes, the shortcut is a 1x1 conv +
    BN projection; otherwise it is an empty ``nn.Sequential`` (identity).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        branch = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(branch + self.shortcut(x))
54 |
class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 (possibly strided) -> 1x1 expand residual block.

    The output has ``expansion * planes`` (= 4 * planes) channels.  The
    shortcut is a 1x1 conv + BN projection when the stride or the channel
    count changes, otherwise an empty ``nn.Sequential`` (identity).
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=True)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=True)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        branch = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            branch = F.relu(bn(conv(branch)))
        branch = self.bn3(self.conv3(branch))
        return F.relu(branch + self.shortcut(x))
82 |
class ResNet(nn.Module):
    """Three-stage residual network with 16/32/64 base planes per stage.

    Args:
        depth: key passed to ``cfg`` (its table lists 18, 34, 50, 101, 152),
            which selects the block class and the blocks-per-stage list.
            Only the first three entries of that list are used, because
            this network has three stages.
        num_classes: number of output logits.
    """

    def __init__(self, depth, num_classes):
        super(ResNet, self).__init__()
        # Channel count fed into the next block; advanced by _make_layer.
        self.in_planes = 16

        block, num_blocks = cfg(depth)

        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: advances ``self.in_planes`` to ``planes * block.expansion``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []

        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling to 1x1.  With 32x32 inputs layer3's map is
        # 8x8 (two stride-2 stages, assuming conv1 keeps the input size), so
        # this equals the former avg_pool2d(out, 8).  For other input sizes
        # it keeps the linear input at 64 * expansion features.  The fixed
        # 8x8 window would otherwise crash the linear layer (>= 64px) or
        # pool only the top-left window.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out
117 |
if __name__ == '__main__':
    # Shape smoke test: push one random 3x32x32 image through ResNet-50.
    net = ResNet(50, 10)
    # torch.autograd.Variable is deprecated; plain tensors work directly.
    with torch.no_grad():
        y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
122 |
--------------------------------------------------------------------------------
/networks/vggnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
def conv_init(m):
    """Initialise one module in place (suitable for ``nn.Module.apply``).

    Conv layers get a Xavier-uniform weight (gain sqrt(2)) and a zero bias.
    BatchNorm layers get weight 1 and bias 0.  Any other module is left
    unchanged.  Absent parameters (``bias=False`` convs, ``affine=False``
    BatchNorm) are skipped instead of raising.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # This module imports neither ``torch.nn.init`` nor ``numpy``, so use
        # ``nn.init`` (reachable via the imported ``nn``) and a plain float
        # gain; the in-place ``*_`` variants are the non-deprecated API.
        nn.init.xavier_uniform_(m.weight, gain=2 ** 0.5)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
10 |
def cfg(depth):
    """Return the VGG feature-extractor layout for ``depth``.

    The layout is a flat list.  Each int is the output-channel count of one
    3x3 conv block, and each ``'mp'`` marks a 2x2 max-pool.  Every depth has
    five stages of widths 64, 128, 256, 512 and 512, each closed by a
    max-pool.  Only the number of convs per stage differs.  A new list is
    built on every call.

    The depth must be 11, 13, 16 or 19; this is checked with ``assert``.
    """
    assert depth in (11, 13, 16, 19), "Error : VGGnet depth should be either 11, 13, 16, 19"

    convs_per_stage = {
        '11': (1, 1, 2, 2, 2),
        '13': (2, 2, 2, 2, 2),
        '16': (2, 2, 3, 3, 3),
        '19': (2, 2, 4, 4, 4),
    }[str(depth)]
    stage_widths = (64, 128, 256, 512, 512)

    layout = []
    for width, repeats in zip(stage_widths, convs_per_stage):
        layout.extend([width] * repeats)
        layout.append('mp')
    return layout
45 |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-enabled 3x3 ``nn.Conv2d`` with padding 1.

    Padding 1 keeps the spatial size at ``stride=1``.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
    return layer
48 |
class VGG(nn.Module):
    """VGG-style classifier: a conv/BN/ReLU feature stack plus one linear layer.

    Args:
        depth: VGG depth passed to ``cfg`` (11, 13, 16 or 19).
        num_classes: number of output logits.
    """

    def __init__(self, depth, num_classes):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg(depth))
        # Every cfg layout ends with a 512-channel stage.
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.features(x)           # (N, 512, 1, 1) after global pooling
        out = out.view(out.size(0), -1)  # (N, 512)
        out = self.linear(out)

        return out

    def _make_layers(self, cfg):
        """Turn a ``cfg`` layout list into an ``nn.Sequential`` feature extractor.

        Each int becomes conv3x3 -> BatchNorm -> ReLU with that many output
        channels.  Each ``'mp'`` becomes a 2x2, stride-2 max-pool.  The first
        conv expects 3 input channels.
        """
        layers = []
        in_planes = 3

        for x in cfg:
            if x == 'mp':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [conv3x3(in_planes, x), nn.BatchNorm2d(x), nn.ReLU(inplace=True)]
                in_planes = x

        # Global average pool to 1x1 so the 512-input linear layer works for
        # any input that survives the five max-pools (>= 32x32).  For 32x32
        # inputs the map is already 1x1, which the former
        # AvgPool2d(kernel_size=1, stride=1) (a no-op) relied on.
        layers += [nn.AdaptiveAvgPool2d(1)]
        return nn.Sequential(*layers)
76 |
if __name__ == "__main__":
    # Shape smoke test: push one random 3x32x32 image through VGG-16.
    net = VGG(16, 10)
    # torch.autograd.Variable is deprecated; plain tensors work directly.
    with torch.no_grad():
        y = net(torch.randn(1, 3, 32, 32))
    print(y.size())
81 |
--------------------------------------------------------------------------------
/networks/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | import sys
8 | import numpy as np
9 |
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-enabled 3x3 convolution with padding 1."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
12 |
def conv_init(m):
    """Initialise one module in place (suitable for ``nn.Module.apply``).

    Conv layers get a Xavier-uniform weight (gain sqrt(2)) and a zero bias.
    BatchNorm layers get weight 1 and bias 0.  Any other module is left
    unchanged.  Absent parameters (``bias=False`` convs, ``affine=False``
    BatchNorm) are skipped instead of passing ``None`` to ``init`` and
    raising.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:
            init.constant_(m.weight, 1)
        if m.bias is not None:
            init.constant_(m.bias, 0)
21 |
class wide_basic(nn.Module):
    """Wide residual block.

    Main path: BN -> ReLU -> conv3x3 -> dropout -> BN -> ReLU -> conv3x3
    (stride applied by the second conv), then the shortcut output is added.
    No activation follows the addition.  When the stride is not 1 or the
    channel count changes, the shortcut is a strided 1x1 conv; otherwise it
    is the identity.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        # Submodules are registered in this order so parameter order and
        # state_dict keys stay stable.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        pre1 = F.relu(self.bn1(x))
        h = self.dropout(self.conv1(pre1))
        pre2 = F.relu(self.bn2(h))
        h = self.conv2(pre2)
        return h + self.shortcut(x)
43 |
class Wide_ResNet(nn.Module):
    """Wide residual network WRN-``depth``-``widen_factor``.

    Args:
        depth: total depth; must be of the form 6n + 4, giving n
            ``wide_basic`` blocks in each of the three stages.
        widen_factor: k; the stage widths become 16k, 32k and 64k.
        dropout_rate: dropout probability used inside every block.
        num_classes: number of output logits.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
        super(Wide_ResNet, self).__init__()
        # Channel count fed into the next block; advanced by _wide_layer.
        self.in_planes = 16

        assert (depth - 4) % 6 == 0, 'Wide-resnet depth should be 6n+4'
        # Floor division keeps the per-stage block count integral.
        n = (depth - 4) // 6
        k = widen_factor

        print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(3, nStages[0])
        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
        # NOTE(review): in PyTorch, BatchNorm ``momentum`` is the weight given
        # to the *current* batch when updating running stats (default 0.1).
        # 0.9 here differs from the default used by every wide_basic
        # BatchNorm; confirm it is intentional.  It is left unchanged to
        # preserve existing behaviour and checkpoints.
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.linear = nn.Linear(nStages[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks of width ``planes``; only the first uses ``stride``.

        Side effect: sets ``self.in_planes`` to ``planes``.
        """
        strides = [stride] + [1] * (int(num_blocks) - 1)
        layers = []

        for s in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, s))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # Global average pooling to 1x1.  This equals the former
        # avg_pool2d(out, 8) when layer3's map is 8x8 (32x32 inputs).  For
        # other input sizes it keeps the linear input at nStages[3] features.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        return out
84 |
if __name__ == '__main__':
    # Shape smoke test: push one random 3x32x32 image through WRN-28-10.
    net = Wide_ResNet(28, 10, 0.3, 10)
    # torch.autograd.Variable is deprecated; plain tensors work directly.
    with torch.no_grad():
        y = net(torch.randn(1, 3, 32, 32))

    print(y.size())
90 |
--------------------------------------------------------------------------------
/scripts/cifar100_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train a wide-resnet (depth 28, widen factor 10) on CIFAR-100
# with learning rate 0.1 and dropout 0.
export netType='wide-resnet'
export depth=28
export width=10
export dataset='cifar100'

python main.py \
    --lr 0.1 \
    --net_type "${netType}" \
    --depth "${depth}" \
    --widen_factor "${width}" \
    --dropout 0 \
    --dataset "${dataset}"
14 |
--------------------------------------------------------------------------------
/scripts/cifar10_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train a wide-resnet (depth 28, widen factor 10) on CIFAR-10
# with learning rate 0.1 and dropout 0.
export netType='wide-resnet'
export depth=28
export width=10
export dataset='cifar10'

# The final argument has no trailing backslash: a dangling line
# continuation would splice whatever line follows into this command.
python main.py \
    --lr 0.1 \
    --net_type "${netType}" \
    --depth "${depth}" \
    --widen_factor "${width}" \
    --dropout 0 \
    --dataset "${dataset}"
14 |
--------------------------------------------------------------------------------
/scripts/resnet_cifar100_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train a depth-18 resnet on CIFAR-100 with learning rate 0.1 and dropout 0.
export netType='resnet'
export depth=18
export dataset='cifar100'

python main.py \
    --lr 0.1 \
    --net_type "${netType}" \
    --depth "${depth}" \
    --dropout 0 \
    --dataset "${dataset}"
12 |
--------------------------------------------------------------------------------