├── .gitignore
├── README.md
├── evaluate_imagenet.py
├── evaluate_timing.py
├── models
│   ├── __init__.py
│   └── selecsls.py
├── util
│   ├── __init__.py
│   └── imagenet_data_loader.py
└── weights
    └── readme.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SelecSLS Convolutional Net Pytorch Implementation
2 | Reference ImageNet implementation of SelecSLS Convolutional Neural Network architecture proposed in [XNect: Real-time Multi-Person 3D Motion Capture
3 | with a Single RGB Camera](http://gvv.mpi-inf.mpg.de/projects/XNect/) (SIGGRAPH 2020).
4 |
5 | The network architecture is 1.3-1.5x faster than ResNet-50, particularly for larger image sizes, with the same level of accuracy on different tasks!
6 | Further, it takes substantially less memory while training, so it can be trained with larger batch sizes!
7 |
8 | ### Update (28 Dec 2019)
9 | Better and more accurate models / snapshots are now available. See the additional ImageNet table below.
10 |
11 | ### Update (14 Oct 2019)
12 | Code for pruning the model based on [Implicit Filter Level Sparsity](http://openaccess.thecvf.com/content_CVPR_2019/html/Mehta_On_Implicit_Filter_Level_Sparsity_in_Convolutional_Neural_Networks_CVPR_2019_paper.html) is also a part of the [SelecSLS model](https://github.com/mehtadushy/SelecSLS-Pytorch/blob/master/models/selecsls.py#L280) now. The sparsity is a natural consequence of training with adaptive gradient descent approaches and L2 regularization. It gives a further speedup of **10-30%** on the pretrained models with no loss in accuracy. See usage and results below.
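
Concretely, the implicit sparsity shows up as BatchNorm scale parameters (gamma) that train to near zero; channels whose |gamma| falls below a small threshold can then be removed outright. A minimal sketch of inspecting this on a loaded model (assumes the SelecSLS-60 weights have been downloaded to `./weights` as in the usage section below; the threshold value is illustrative):

```python
import torch
import torch.nn as nn
from models.selecsls import Net

net = Net(nClasses=1000, config='SelecSLS60')
net.load_state_dict(torch.load('./weights/SelecSLS60_statedict.pth', map_location='cpu'))

# Count channels whose BN scale was driven (near) zero during training.
gamma_thresh = 1e-3
for name, m in net.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        prunable = int((m.weight.abs() < gamma_thresh).sum())
        print('{}: {}/{} channels prunable'.format(name, prunable, m.weight.numel()))
```

The pruning itself is applied with `net.prune_and_fuse(gamma_thresh)`; see the usage section below.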
13 |
14 | ## ImageNet results
15 |
16 | The inference time for the models in the table below is measured on a TITAN X GPU using the accompanying scripts. The accuracy results for ResNet-50 are from torchvision, and the accuracy results for VoVNet-39 are from [VoVNet](https://github.com/stigma0617/VoVNet.pytorch).
17 |
18 |
Forward pass time (ms) at different image resolutions and batch sizes, and ImageNet error (%):

| Model | 512x512, Batch 1 | 512x512, Batch 16 | 400x400, Batch 1 | 400x400, Batch 16 | 224x224, Batch 1 | 224x224, Batch 16 | Top-1 | Top-5 |
|---|---|---|---|---|---|---|---|---|
| ResNet-50 | 15.0 | 175.0 | 11.0 | 114.0 | 7.2 | 39.0 | 23.9 | 7.1 |
| VoVNet-39 | 13.0 | 197.0 | 10.8 | 130.0 | 6.0 | 41.0 | 23.2 | 6.6 |
| SelecSLS-60 | 11.0 | 115.0 | 9.5 | 85.0 | 7.3 | 29.0 | 23.8 | 7.0 |
| SelecSLS-60 (P) | 10.2 | 102.0 | 8.2 | 71.0 | 6.1 | 25.0 | 23.8 | 7.0 |
| SelecSLS-84 | 16.1 | 175.0 | 13.7 | 124.0 | 9.9 | 42.3 | 23.3 | 6.9 |
| SelecSLS-84 (P) | 11.9 | 119.0 | 10.1 | 82.0 | 7.6 | 28.6 | 23.3 | 6.9 |

107 | * (P) indicates that the model has batch norm fusion and pruning applied
108 |
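Batch norm fusion folds each BN layer into the preceding convolution: with BN scale γ, shift β, running statistics μ and σ², and stabilizer ε, the fused convolution uses W' = W · γ/√(σ² + ε) and b' = β + (b − μ) · γ/√(σ² + ε). This is what `bn_fuse` in `models/selecsls.py` computes, after which channels with near-zero γ are dropped.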
109 |
110 |
111 | The following models are trained using Cosine LR, Random Erasing, EMA, *Bicubic* Interpolation, and Color Jitter using [rwightman/pytorch-image-models](https://github.com/rwightman/pytorch-image-models). The inference time for the models here is measured on a TITAN Xp GPU using the accompanying scripts. The script for evaluating ImageNet performance uses *Bilinear* interpolation, hence the results reported here are marginally worse than they would be with Bicubic interpolation at inference.
112 |
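These configurations are also registered in [rwightman/pytorch-image-models](https://github.com/rwightman/pytorch-image-models) itself; a short sketch, assuming a `timm` version that includes the SelecSLS models (model names such as `selecsls60b` follow timm's registry):

```python
import timm

net = timm.create_model('selecsls60b', pretrained=True)
net.eval()
```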
113 |
Forward pass time (ms) at different image resolutions and batch sizes, and ImageNet error (%):

| Model | 512x512, Batch 1 | 512x512, Batch 16 | 400x400, Batch 1 | 400x400, Batch 16 | 224x224, Batch 1 | 224x224, Batch 16 | Top-1 | Top-5 |
|---|---|---|---|---|---|---|---|---|
| SelecSLS-42_B | 6.4 | 60.8 | 5.8 | 42.1 | 5.7 | 14.7 | 22.9 | 6.6 |
| SelecSLS-60 | 7.4 | 69.4 | 7.3 | 47.6 | 7.1 | 16.8 | 22.1 | 6.1 |
| SelecSLS-60_B | 7.5 | 70.5 | 7.3 | 49.3 | 7.2 | 17.0 | 21.6 | 5.8 |
175 |
176 |
177 | # SelecSLS (Selective Short and Long Range Skip Connections)
178 | The key feature of the proposed architecture is that, unlike the full dense connectivity of DenseNets, SelecSLS uses a much sparser skip-connectivity pattern with both long- and short-range concatenative skip connections. Additionally, the network architecture is more amenable to [filter/channel pruning](http://openaccess.thecvf.com/content_CVPR_2019/html/Mehta_On_Implicit_Filter_Level_Sparsity_in_Convolutional_Neural_Networks_CVPR_2019_paper.html) than ResNets.
179 | You can find more details about the architecture in the following [paper](https://arxiv.org/abs/1907.00837), and details about implicit pruning in the [CVPR 2019 paper](http://openaccess.thecvf.com/content_CVPR_2019/html/Mehta_On_Implicit_Filter_Level_Sparsity_in_Convolutional_Neural_Networks_CVPR_2019_paper.html).
180 |
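In code, each SelecSLS block passes along a list `[features, long_range_skip]`: the first block at each spatial resolution starts a fresh long-range skip, and subsequent blocks at that resolution concatenate it back in. A condensed view of `SelecSLSBlock.forward` from `models/selecsls.py`:

```python
def forward(self, x):                    # x is [features] or [features, long_range_skip]
    d1 = self.conv1(x[0])                # 3x3 conv, strided in the first block of a stage
    d2 = self.conv3(self.conv2(d1))      # short-range features
    d3 = self.conv5(self.conv4(d2))
    if self.isFirst:                     # start a new long-range skip
        out = self.conv6(torch.cat([d1, d2, d3], 1))
        return [out, out]                # the output doubles as the long-range skip
    # otherwise, concatenate the long-range skip and pass it through unchanged
    return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]]
```
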
181 | Another recent paper proposed the VoVNet architecture, which shares some design similarities with our architecture. However, as shown in the above table, our architecture is significantly faster than both VoVNet-39 and ResNet-50 for larger batch sizes as well as larger image sizes.
182 |
183 | ## Usage
184 | This repo provides the model definition in Pytorch, trained weights for ImageNet, and code for evaluating the forward pass time
185 | and the accuracy of the trained model on the ImageNet validation set.
186 | In the paper, the model was used for human pose estimation, but it can also be applied to a myriad of other problems as a drop-in replacement for ResNet-50.
187 |
188 | ```
189 | wget http://gvv.mpi-inf.mpg.de/projects/XNectDemoV2/content/SelecSLS60_statedict.pth -O ./weights/SelecSLS60_statedict.pth
190 | python evaluate_timing.py --num_iter 100 --model_class selecsls --model_config SelecSLS60 --model_weights ./weights/SelecSLS60_statedict.pth --input_size 512 --gpu_id <gpu_id>
191 | python evaluate_imagenet.py --model_class selecsls --model_config SelecSLS60 --model_weights ./weights/SelecSLS60_statedict.pth --gpu_id <gpu_id> --imagenet_base_path <path_to_imagenet>
192 |
193 | # For pruning the model and evaluating the pruned model (using SelecSLS84 or other pretrained models)
194 | python evaluate_timing.py --num_iter 100 --model_class selecsls --model_config SelecSLS84 --model_weights ./weights/SelecSLS84_statedict.pth --input_size 512 --pruned_and_fused True --gamma_thresh 0.001 --gpu_id <gpu_id>
195 | python evaluate_imagenet.py --model_class selecsls --model_config SelecSLS84 --model_weights ./weights/SelecSLS84_statedict.pth --pruned_and_fused True --gamma_thresh 0.001 --gpu_id <gpu_id> --imagenet_base_path <path_to_imagenet>
196 | ```
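
The models can also be driven directly from Python; a minimal sketch using this repo's own entry points (the dummy input is just for illustration):

```python
import torch
from models.selecsls import Net

# Build the network and load the pretrained ImageNet weights downloaded above.
net = Net(nClasses=1000, config='SelecSLS60')
net.load_state_dict(torch.load('./weights/SelecSLS60_statedict.pth', map_location='cpu'))

# Optional: fuse BN and drop implicitly-sparse channels for a 10-30% inference speedup.
net.prune_and_fuse(gamma_thresh=0.001, verbose=True)
net.eval()

with torch.no_grad():
    logits = net(torch.randn(1, 3, 224, 224))  # shape: [1, 1000]
```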
197 |
198 | ## Older Pretrained Models
199 | - [SelecSLS-60](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict.pth)
200 | - [SelecSLS-84](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS84_statedict.pth)
201 |
202 | ## Newer Pretrained Models (More Accurate)
203 | - [SelecSLS-42_B](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B_statedict.pth)
204 | - [SelecSLS-60](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict_better.pth)
205 | - [SelecSLS-60_B](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B_statedict.pth)
206 |
207 | ## Requirements
208 | - Python 3.5
209 | - Pytorch >= 1.1
210 |
211 | ## License
212 | The contents of this repository, and the pretrained models are made available under CC BY 4.0. Please read the [license terms](https://creativecommons.org/licenses/by/4.0/legalcode).
213 |
214 | ### Citing
215 | If you use the model or the implicit-sparsity-based pruning in your work, please cite:
216 |
217 | ```
218 | @article{XNect_SIGGRAPH2020,
219 | author = {Mehta, Dushyant and Sotnychenko, Oleksandr and Mueller, Franziska and Xu, Weipeng and Elgharib, Mohamed and Fua, Pascal and Seidel, Hans-Peter and Rhodin, Helge and Pons-Moll, Gerard and Theobalt, Christian},
220 | title = {{XNect}: Real-time Multi-Person {3D} Motion Capture with a Single {RGB} Camera},
221 | journal = {ACM Transactions on Graphics},
222 | url = {http://gvv.mpi-inf.mpg.de/projects/XNect/},
223 | numpages = {17},
224 | volume={39},
225 | number={4},
226 | month = {July},
227 | year = {2020},
228 | doi={10.1145/3386569.3392410}
229 | }
230 |
231 | @InProceedings{Mehta_2019_CVPR,
232 | author = {Mehta, Dushyant and Kim, Kwang In and Theobalt, Christian},
233 | title = {On Implicit Filter Level Sparsity in Convolutional Neural Networks},
234 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
235 | month = {June},
236 | year = {2019}
237 | }
238 | ```
239 |
240 |
241 |
242 |
--------------------------------------------------------------------------------
/evaluate_imagenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Script for evaluating accuracy on Imagenet Validation Set.
5 | '''
6 | import os
7 | import logging
8 | import sys
9 | import time
10 | from argparse import ArgumentParser
11 | import importlib
12 |
13 | import numpy as np
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | torch.backends.cudnn.benchmark = True
19 | from util.imagenet_data_loader import get_data_loader
20 |
21 |
22 |
23 | def opts_parser():
24 |     usage = 'Evaluate the accuracy of a trained model on the ImageNet validation set'
25 | parser = ArgumentParser(description=usage)
26 | parser.add_argument(
27 | '--model_class', type=str, default='selecsls', metavar='FILE',
28 | help='Select model type to use (DenseNet, SelecSLS, ResNet etc.)')
29 | parser.add_argument(
30 | '--model_config', type=str, default='SelecSLS60', metavar='NET_CONFIG',
31 | help='Select the model configuration')
32 | parser.add_argument(
33 | '--model_weights', type=str, default='./weights/SelecSLS60_statedict.pth', metavar='FILE',
34 | help='Path to model weights')
35 | parser.add_argument(
36 | '--imagenet_base_path', type=str, default='', metavar='FILE',
37 | help='Path to ImageNet dataset')
38 | parser.add_argument(
39 | '--gpu_id', type=int, default=0,
40 | help='Which GPU to use.')
41 |     parser.add_argument(  # argparse's type=bool treats any non-empty string (even 'False') as True, so parse explicitly
42 |         '--simulate_pruning', type=lambda s: s.lower() in ('true', '1'), default=False,
43 |         help='Whether to zero out features with gamma below a certain threshold')
44 |     parser.add_argument(
45 |         '--pruned_and_fused', type=lambda s: s.lower() in ('true', '1'), default=False,
46 |         help='Whether to prune based on gamma below a certain threshold and fuse BN')
47 | parser.add_argument(
48 | '--gamma_thresh', type=float, default=1e-4,
49 | help='gamma threshold to use for simulating pruning')
50 | return parser
51 |
52 |
53 | def accuracy(output, target, topk=(1,)):
54 | """Computes the precision@k for the specified values of k"""
55 | maxk = max(topk)
56 | batch_size = target.size(0)
57 |
58 | _, pred = output.topk(maxk, 1, True, True)
59 | pred = pred.t()
60 | correct = pred.eq(target.view(1, -1).expand_as(pred))
61 |
62 | res = []
63 | for k in topk:
64 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape handles non-contiguous slices in newer PyTorch
65 | res.append(correct_k.mul_(100.0 / batch_size))
66 | return res
67 |
68 |
69 | def evaluate_imagenet_validation_accuracy(model_class, model_config, model_weights, imagenet_base_path, gpu_id, simulate_pruning, pruned_and_fused, gamma_thresh):
70 | model_module = importlib.import_module('models.'+model_class)
71 | net = model_module.Net(nClasses=1000, config=model_config)
72 | net.load_state_dict(torch.load(model_weights, map_location= lambda storage, loc: storage))
73 |
74 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
75 | net = net.to(device)
76 | if pruned_and_fused:
77 | print('Fusing BN and pruning channels based on gamma ' + str(gamma_thresh))
78 | net.prune_and_fuse(gamma_thresh)
79 |
80 | if simulate_pruning:
81 | print('Simulating pruning by zeroing all features with gamma less than '+str(gamma_thresh))
82 | with torch.no_grad():
83 | for n, m in net.named_modules():
84 | if isinstance(m, nn.BatchNorm2d):
85 |                     mask = abs(m.weight) < gamma_thresh  # compute the mask once, before modifying the weights
86 |                     m.weight[mask] = 0; m.bias[mask] = 0
87 |
88 | net.eval()
89 | _,test_loader = get_data_loader(augment=False, batch_size=100, base_path=imagenet_base_path)
90 | with torch.no_grad():
91 | val1_err = []
92 | val5_err = []
93 | for x, y in test_loader:
94 |             pred = F.log_softmax(net(x.to(device)), dim=1)
95 | top1, top5 = accuracy(pred, y.to(device), topk=(1, 5))
96 |             val1_err.append(100 - top1.item())  # .item() pulls the scalar off the device
97 |             val5_err.append(100 - top5.item())
98 |         avg1_err = float(np.sum(val1_err)) / len(val1_err)
99 |         avg5_err = float(np.sum(val5_err)) / len(val5_err)
100 | print('Top-1 Error: {} Top-5 Error {}'.format(avg1_err, avg5_err))
101 |
102 |
103 | def main():
104 | # parse command line
105 | torch.manual_seed(1234)
106 | parser = opts_parser()
107 | args = parser.parse_args()
108 |
109 | # run
110 | evaluate_imagenet_validation_accuracy(**vars(args))
111 |
112 | if __name__ == '__main__':
113 | main()
114 |
--------------------------------------------------------------------------------
/evaluate_timing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Script for timing models in eval mode and torchscript eval modes.
5 | '''
6 |
7 | import os
8 | import logging
9 | import sys
10 | import time
11 | from argparse import ArgumentParser
12 | import importlib
13 |
14 | import numpy as np
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | torch.backends.cudnn.benchmark = True
20 |
21 |
22 |
23 | def opts_parser():
24 |     usage = 'Time the forward pass of a model, eagerly and as TorchScript, on GPU and CPU'
25 | parser = ArgumentParser(description=usage)
26 | parser.add_argument(
27 | '--num_iter', type=int, default=50,
28 | help='Number of iterations to average over.')
29 | parser.add_argument(
30 | '--model_class', type=str, default='selecsls', metavar='FILE',
31 | help='Select model type to use (DenseNet, SelecSLS, ResNet etc.)')
32 | parser.add_argument(
33 | '--model_config', type=str, default='SelecSLS84', metavar='NET_CONFIG',
34 | help='Select the model configuration')
35 | parser.add_argument(
36 | '--model_weights', type=str, default='./weights/SelecSLS84_statedict.pth', metavar='FILE',
37 | help='Path to model weights')
38 | parser.add_argument(
39 | '--input_size', type=int, default=400,
40 | help='Input image size.')
41 | parser.add_argument(
42 | '--gpu_id', type=int, default=0,
43 | help='Which GPU to use.')
44 |     parser.add_argument(  # argparse's type=bool treats any non-empty string (even 'False') as True, so parse explicitly
45 |         '--pruned_and_fused', type=lambda s: s.lower() in ('true', '1'), default=False,
46 |         help='Whether to prune channels with gamma below a certain threshold and fuse BN')
47 |     parser.add_argument(
48 |         '--gamma_thresh', type=float, default=1e-3,
49 |         help='gamma threshold to use for pruning. Set this to -1 to only fuse BN without pruning')
50 | return parser
51 |
52 |
53 | def measure_cpu(model, x):
54 |     # measure forward pass time on the CPU (no device synchronization needed)
55 | model.eval()
56 | with torch.no_grad():
57 | t0 = time.time()
58 | y_pred = model(x)
59 | elapsed_fp_nograd = time.time()-t0
60 | return elapsed_fp_nograd
61 |
62 | def measure_gpu(model, x):
63 | # synchronize gpu time and measure fp
64 | model.eval()
65 | with torch.no_grad():
66 | torch.cuda.synchronize()
67 | t0 = time.time()
68 | y_pred = model(x)
69 | torch.cuda.synchronize()
70 | elapsed_fp_nograd = time.time()-t0
71 | return elapsed_fp_nograd
72 |
73 |
74 | def benchmark(model_class, model_config, gpu_id, num_iter, model_weights, input_size, pruned_and_fused, gamma_thresh):
75 | # Import the model module
76 | model_module = importlib.import_module('models.'+model_class)
77 | net = model_module.Net(nClasses=1000, config=model_config)
78 | net.load_state_dict(torch.load(model_weights, map_location= lambda storage, loc: storage))
79 |
80 | if pruned_and_fused:
81 | print('Pruning and fusing the model')
82 | net.prune_and_fuse(gamma_thresh, True)
83 |
84 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu")
85 | net = net.to(device)
86 |     print('\nEvaluating on {}'.format(device))
87 |
88 | print('\nGPU, Batch Size: 1')
89 | x = torch.randn(1, 3, input_size, input_size)
90 | #Warm up
91 | for i in range(10):
92 | _ = measure_gpu(net, x.to(device))
93 | fp = []
94 | for i in range(num_iter):
95 | t = measure_gpu(net, x.to(device))
96 | fp.append(t)
97 | print('Model FP: '+str(np.mean(np.asarray(fp)*1000))+'ms')
98 |
99 | jit_net = torch.jit.trace(net, x.to(device))
100 | for i in range(10):
101 | _ = measure_gpu(jit_net, x.to(device))
102 | fp = []
103 | for i in range(num_iter):
104 | t = measure_gpu(jit_net, x.to(device))
105 | fp.append(t)
106 | print('JIT FP: '+str(np.mean(np.asarray(fp)*1000))+'ms')
107 |
108 |
109 | print('\nGPU, Batch Size: 16')
110 | x = torch.randn(16, 3, input_size, input_size)
111 | #Warm up
112 | for i in range(10):
113 | _ = measure_gpu(net, x.to(device))
114 | fp = []
115 | for i in range(num_iter):
116 | t = measure_gpu(net, x.to(device))
117 | fp.append(t)
118 | print('Model FP: '+str(np.mean(np.asarray(fp)*1000))+'ms')
119 |
120 | jit_net = torch.jit.trace(net, x.to(device))
121 | for i in range(10):
122 | _ = measure_gpu(jit_net, x.to(device))
123 | fp = []
124 | for i in range(num_iter):
125 | t = measure_gpu(jit_net, x.to(device))
126 | fp.append(t)
127 | print('JIT FP: '+str(np.mean(np.asarray(fp)*1000))+'ms')
128 |
129 | device = torch.device("cpu")
130 | print('\nEvaluating on {}'.format(device))
131 | net = net.to(device)
132 |
133 | print('\nCPU, Batch Size: 1')
134 | x = torch.randn(1, 3, input_size, input_size)
135 | #Warm up
136 | for i in range(10):
137 | _ = measure_cpu(net, x.to(device))
138 | fp = []
139 | for i in range(num_iter):
140 | t = measure_cpu(net, x.to(device))
141 | fp.append(t)
142 | print('Model FP: '+str(np.mean(np.asarray(fp)*1000))+'ms')
143 |
144 | jit_net = torch.jit.trace(net, x.to(device))
145 | for i in range(10):
146 | _ = measure_cpu(jit_net, x.to(device))
147 | fp = []
148 | for i in range(num_iter):
149 | t = measure_cpu(jit_net, x.to(device))
150 | fp.append(t)
151 | print('JIT FP: '+str(np.mean(np.asarray(fp)*1000))+'ms')
152 |
153 |
154 |
155 | def main():
156 | # parse command line
157 | torch.manual_seed(1234)
158 | parser = opts_parser()
159 | args = parser.parse_args()
160 |
161 | # run
162 | benchmark(**vars(args))
163 |
164 | if __name__ == '__main__':
165 | main()
166 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mehtadushy/SelecSLS-Pytorch/3852734af392b8fa69834984c76856dbd39179a3/models/__init__.py
--------------------------------------------------------------------------------
/models/selecsls.py:
--------------------------------------------------------------------------------
1 | '''
2 | Pytorch implementation of SelecSLS Network architecture as described in
3 | 'XNect: Real-time Multi-Person 3D Motion Capture with a Single RGB Camera', Mehta et al., SIGGRAPH 2020.
4 | The network architecture performs comparable to ResNet-50 while being 1.4-1.8x faster,
5 | particularly with larger image sizes. The network architecture has a much smaller memory
6 | footprint, and can be used as a drop in replacement for ResNet-50 in various tasks.
7 | This Pytorch implementation establishes an official baseline of the model on ImageNet
8 |
9 | This model also provides functionality to prune channels based on implicit sparsity, as
10 | described in 'On Implicit Filter Level Sparsity in Convolutional Neural Networks, Mehta et al. CVPR 2019'.
11 | This gives a 10-15% speedup depending on the model used.
12 |
13 | Author: Dushyant Mehta (dmehta[at]mpi-inf.mpg.de)
14 |
15 | This code is made available under CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/legalcode)
16 | '''
17 | from __future__ import absolute_import
18 | import torch
19 | import torch.nn as nn
20 | import torch.optim as optim
21 | import torch.nn.functional as F
22 | import math
23 | import fractions
24 |
25 |
26 | def conv_bn(inp, oup, stride):
27 | return nn.Sequential(
28 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
29 | nn.BatchNorm2d(oup),
30 | nn.ReLU(inplace=True)
31 | )
32 |
33 | def conv_1x1_bn(inp, oup):
34 | return nn.Sequential(
35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36 | nn.BatchNorm2d(oup),
37 | nn.ReLU(inplace=True)
38 | )
39 |
40 | def bn_fuse(c, b):
41 | ''' BN fusion code adapted from my Caffe BN fusion code and code from @MIPT-Oulu. This function assumes everything is on the cpu'''
42 | with torch.no_grad():
43 | # BatchNorm params
44 | eps = b.eps
45 | mu = b.running_mean
46 | var = b.running_var
47 | gamma = b.weight
48 |
49 | if 'bias' in b.state_dict():
50 | beta = b.bias
51 | else:
52 | #beta = torch.zeros(gamma.size(0)).float().to(gamma.device)
53 | beta = torch.zeros(gamma.size(0)).float()
54 |
55 | # Conv params
56 | W = c.weight
57 | if 'bias' in c.state_dict():
58 | bias = c.bias
59 | else:
60 | bias = torch.zeros(W.size(0)).float()
61 |
62 |         denom = torch.sqrt(var + eps)
63 |         b = beta - gamma.mul(mu).div(denom)  # BN's additive term: beta - gamma*mu/sqrt(var+eps)
64 |         A = gamma.div(denom)                 # per-channel scale to fold into the conv weights
65 |         bias *= A                            # scale the conv's own bias by the same factor
66 | A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)
67 |
68 | W.mul_(A)
69 | bias.add_(b)
70 |
71 | return W.clone().detach(), bias.clone().detach()
72 |
73 | class SelecSLSBlock(nn.Module):
74 | def __init__(self, inp, skip, k, oup, isFirst, stride):
75 | super(SelecSLSBlock, self).__init__()
76 | self.stride = stride
77 | self.isFirst = isFirst
78 | assert stride in [1, 2]
79 |
80 |         #Process the input through a chain of 3x3 and 1x1 conv blocks, keeping the intermediate features for concatenation
81 | self.conv1 = nn.Sequential(
82 | nn.Conv2d(inp, k, 3, stride, 1,groups= 1, bias=False, dilation=1),
83 | nn.BatchNorm2d(k),
84 | nn.ReLU(inplace=True)
85 | )
86 | self.conv2 = nn.Sequential(
87 | nn.Conv2d(k, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
88 | nn.BatchNorm2d(k),
89 | nn.ReLU(inplace=True)
90 | )
91 | self.conv3 = nn.Sequential(
92 | nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
93 | nn.BatchNorm2d(k//2),
94 | nn.ReLU(inplace=True)
95 | )
96 | self.conv4 = nn.Sequential(
97 | nn.Conv2d(k//2, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
98 | nn.BatchNorm2d(k),
99 | nn.ReLU(inplace=True)
100 | )
101 | self.conv5 = nn.Sequential(
102 | nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
103 | nn.BatchNorm2d(k//2),
104 | nn.ReLU(inplace=True)
105 | )
106 | self.conv6 = nn.Sequential(
107 | nn.Conv2d(2*k + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=False, dilation=1),
108 | nn.BatchNorm2d(oup),
109 | nn.ReLU(inplace=True)
110 | )
111 |
112 | def forward(self, x):
113 | assert isinstance(x,list)
114 | assert len(x) in [1,2]
115 |
116 | d1 = self.conv1(x[0])
117 | d2 = self.conv3(self.conv2(d1))
118 | d3 = self.conv5(self.conv4(d2))
119 | if self.isFirst:
120 | out = self.conv6(torch.cat([d1, d2, d3], 1))
121 | return [out, out]
122 | else:
123 | return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]]
124 |
125 | class SelecSLSBlockFused(nn.Module):
126 | def __init__(self, inp, skip, a,b,c,d,e, oup, isFirst, stride):
127 | super(SelecSLSBlockFused, self).__init__()
128 | self.stride = stride
129 | self.isFirst = isFirst
130 | assert stride in [1, 2]
131 |
132 |         #Same chain of conv blocks as SelecSLSBlock, but with BN fused into the convolutions (hence bias=True)
133 | self.conv1 = nn.Sequential(
134 | nn.Conv2d(inp, a, 3, stride, 1,groups= 1, bias=True, dilation=1),
135 | nn.ReLU(inplace=True)
136 | )
137 | self.conv2 = nn.Sequential(
138 | nn.Conv2d(a, b, 1, 1, 0,groups= 1, bias=True, dilation=1),
139 | nn.ReLU(inplace=True)
140 | )
141 | self.conv3 = nn.Sequential(
142 | nn.Conv2d(b, c, 3, 1, 1,groups= 1, bias=True, dilation=1),
143 | nn.ReLU(inplace=True)
144 | )
145 | self.conv4 = nn.Sequential(
146 | nn.Conv2d(c, d, 1, 1, 0,groups= 1, bias=True, dilation=1),
147 | nn.ReLU(inplace=True)
148 | )
149 | self.conv5 = nn.Sequential(
150 | nn.Conv2d(d, e, 3, 1, 1,groups= 1, bias=True, dilation=1),
151 | nn.ReLU(inplace=True)
152 | )
153 | self.conv6 = nn.Sequential(
154 | nn.Conv2d(a+c+e + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=True, dilation=1),
155 | nn.ReLU(inplace=True)
156 | )
157 |
158 | def forward(self, x):
159 | assert isinstance(x,list)
160 | assert len(x) in [1,2]
161 |
162 | d1 = self.conv1(x[0])
163 | d2 = self.conv3(self.conv2(d1))
164 | d3 = self.conv5(self.conv4(d2))
165 | if self.isFirst:
166 | out = self.conv6(torch.cat([d1, d2, d3], 1))
167 | return [out, out]
168 | else:
169 | return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]]
170 |
171 | class Net(nn.Module):
172 | def __init__(self, nClasses=1000, config='SelecSLS60'):
173 | super(Net, self).__init__()
174 |
175 | #Stem
176 | self.stem = conv_bn(3, 32, 2)
177 |
178 | #Core Network
179 | self.features = []
180 | if config=='SelecSLS42':
181 | print('SelecSLS42')
182 | #Define configuration of the network after the initial neck
183 | self.selecSLS_config = [
184 | #inp,skip, k, oup, isFirst, stride
185 | [ 32, 0, 64, 64, True, 2],
186 | [ 64, 64, 64, 128, False, 1],
187 | [128, 0, 144, 144, True, 2],
188 | [144, 144, 144, 288, False, 1],
189 | [288, 0, 304, 304, True, 2],
190 | [304, 304, 304, 480, False, 1],
191 | ]
192 | #Head can be replaced with alternative configurations depending on the problem
193 | self.head = nn.Sequential(
194 | conv_bn(480, 960, 2),
195 | conv_bn(960, 1024, 1),
196 | conv_bn(1024, 1024, 2),
197 | conv_1x1_bn(1024, 1280),
198 | )
199 | self.num_features = 1280
200 | elif config=='SelecSLS42_B':
201 | print('SelecSLS42_B')
202 | #Define configuration of the network after the initial neck
203 | self.selecSLS_config = [
204 | #inp,skip, k, oup, isFirst, stride
205 | [ 32, 0, 64, 64, True, 2],
206 | [ 64, 64, 64, 128, False, 1],
207 | [128, 0, 144, 144, True, 2],
208 | [144, 144, 144, 288, False, 1],
209 | [288, 0, 304, 304, True, 2],
210 | [304, 304, 304, 480, False, 1],
211 | ]
212 | #Head can be replaced with alternative configurations depending on the problem
213 | self.head = nn.Sequential(
214 | conv_bn(480, 960, 2),
215 | conv_bn(960, 1024, 1),
216 | conv_bn(1024, 1280, 2),
217 | conv_1x1_bn(1280, 1024),
218 | )
219 | self.num_features = 1024
220 | elif config=='SelecSLS60':
221 | print('SelecSLS60')
222 | #Define configuration of the network after the initial neck
223 | self.selecSLS_config = [
224 | #inp,skip, k, oup, isFirst, stride
225 | [ 32, 0, 64, 64, True, 2],
226 | [ 64, 64, 64, 128, False, 1],
227 | [128, 0, 128, 128, True, 2],
228 | [128, 128, 128, 128, False, 1],
229 | [128, 128, 128, 288, False, 1],
230 | [288, 0, 288, 288, True, 2],
231 | [288, 288, 288, 288, False, 1],
232 | [288, 288, 288, 288, False, 1],
233 | [288, 288, 288, 416, False, 1],
234 | ]
235 | #Head can be replaced with alternative configurations depending on the problem
236 | self.head = nn.Sequential(
237 | conv_bn(416, 756, 2),
238 | conv_bn(756, 1024, 1),
239 | conv_bn(1024, 1024, 2),
240 | conv_1x1_bn(1024, 1280),
241 | )
242 | self.num_features = 1280
243 | elif config=='SelecSLS60_B':
244 | print('SelecSLS60_B')
245 | #Define configuration of the network after the initial neck
246 | self.selecSLS_config = [
247 | #inp,skip, k, oup, isFirst, stride
248 | [ 32, 0, 64, 64, True, 2],
249 | [ 64, 64, 64, 128, False, 1],
250 | [128, 0, 128, 128, True, 2],
251 | [128, 128, 128, 128, False, 1],
252 | [128, 128, 128, 288, False, 1],
253 | [288, 0, 288, 288, True, 2],
254 | [288, 288, 288, 288, False, 1],
255 | [288, 288, 288, 288, False, 1],
256 | [288, 288, 288, 416, False, 1],
257 | ]
258 | #Head can be replaced with alternative configurations depending on the problem
259 | self.head = nn.Sequential(
260 | conv_bn(416, 756, 2),
261 | conv_bn(756, 1024, 1),
262 | conv_bn(1024, 1280, 2),
263 | conv_1x1_bn(1280, 1024),
264 | )
265 | self.num_features = 1024
266 | elif config=='SelecSLS84':
267 | print('SelecSLS84')
268 | #Define configuration of the network after the initial neck
269 | self.selecSLS_config = [
270 | #inp,skip, k, oup, isFirst, stride
271 | [ 32, 0, 64, 64, True, 2],
272 | [ 64, 64, 64, 144, False, 1],
273 | [144, 0, 144, 144, True, 2],
274 | [144, 144, 144, 144, False, 1],
275 | [144, 144, 144, 144, False, 1],
276 | [144, 144, 144, 144, False, 1],
277 | [144, 144, 144, 304, False, 1],
278 | [304, 0, 304, 304, True, 2],
279 | [304, 304, 304, 304, False, 1],
280 | [304, 304, 304, 304, False, 1],
281 | [304, 304, 304, 304, False, 1],
282 | [304, 304, 304, 304, False, 1],
283 | [304, 304, 304, 512, False, 1],
284 | ]
285 | #Head can be replaced with alternative configurations depending on the problem
286 | self.head = nn.Sequential(
287 | conv_bn(512, 960, 2),
288 | conv_bn(960, 1024, 1),
289 | conv_bn(1024, 1024, 2),
290 | conv_1x1_bn(1024, 1280),
291 | )
292 | self.num_features = 1280
293 | elif config=='SelecSLS102':
294 | print('SelecSLS102')
295 | #Define configuration of the network after the initial neck
296 | self.selecSLS_config = [
297 | #inp,skip, k, oup, isFirst, stride
298 | [ 32, 0, 64, 64, True, 2],
299 | [ 64, 64, 64, 64, False, 1],
300 | [ 64, 64, 64, 64, False, 1],
301 | [ 64, 64, 64, 128, False, 1],
302 | [128, 0, 128, 128, True, 2],
303 | [128, 128, 128, 128, False, 1],
304 | [128, 128, 128, 128, False, 1],
305 | [128, 128, 128, 128, False, 1],
306 | [128, 128, 128, 288, False, 1],
307 | [288, 0, 288, 288, True, 2],
308 | [288, 288, 288, 288, False, 1],
309 | [288, 288, 288, 288, False, 1],
310 | [288, 288, 288, 288, False, 1],
311 | [288, 288, 288, 288, False, 1],
312 | [288, 288, 288, 288, False, 1],
313 | [288, 288, 288, 480, False, 1],
314 | ]
315 | #Head can be replaced with alternative configurations depending on the problem
316 | self.head = nn.Sequential(
317 | conv_bn(480, 960, 2),
318 | conv_bn(960, 1024, 1),
319 | conv_bn(1024, 1024, 2),
320 | conv_1x1_bn(1024, 1280),
321 | )
322 | self.num_features = 1280
323 | else:
324 | raise ValueError('Invalid net configuration '+config+' !!!')
325 |
326 | #Build SelecSLS Core
327 | for inp, skip, k, oup, isFirst, stride in self.selecSLS_config:
328 | self.features.append(SelecSLSBlock(inp, skip, k, oup, isFirst, stride))
329 | self.features = nn.Sequential(*self.features)
330 |
331 | #Classifier To Produce Inputs to Softmax
332 | self.classifier = nn.Sequential(
333 | nn.Linear(self.num_features, nClasses),
334 | )
335 |
336 |
337 | def forward(self, x):
338 | x = self.stem(x)
339 | x = self.features([x])
340 | x = self.head(x[0])
341 | x = x.mean(3).mean(2)
342 | x = self.classifier(x)
343 | #x = F.log_softmax(x)
344 | return x
345 |
346 |
347 |
348 | def prune_and_fuse(self, gamma_thresh, verbose=False):
349 | ''' Function that iterates over the modules in the model and prunes different parts by name. Sparsity emerges implicitly due to the use of
350 | adaptive gradient descent approaches such as Adam, in conjunction with L2 or WD regularization on the parameters. The filters
351 | that are implicitly zeroed out can be explicitly pruned without any impact on the model accuracy (and might even improve in some cases).
352 | '''
353 | #This function assumes a specific structure. If the structure of stem or head is changed, this code would need to be changed too
354 |         #Also, this is ugly. Needs to be written better, but is at least functional
355 | #Perhaps one need not worry about the layers made redundant, they can be removed from storage by tracing with the JIT module??
356 |
357 | #We bring everything to the CPU, then later restore the device
358 | device = next(self.parameters()).device
359 | self.to("cpu")
360 | with torch.no_grad():
361 | #Assumes that stem is flat and has conv,bn,relu in order. Can handle one or more of these if one wants to deepen the stem.
362 | new_stem = []
363 | input_validity = torch.ones(3)
364 | for i in range(0,len(self.stem),3):
365 | input_size = sum(input_validity.int()).item()
366 | #Calculate the extent of sparsity
367 | out_validity = abs(self.stem[i+1].weight) > gamma_thresh
368 | out_size = sum(out_validity.int()).item()
369 | W, b = bn_fuse(self.stem[i],self.stem[i+1])
370 | new_stem.append(nn.Conv2d(input_size, out_size, kernel_size = self.stem[i].kernel_size, stride=self.stem[i].stride, padding = self.stem[i].padding))
371 | new_stem.append(nn.ReLU(inplace=True))
372 | new_stem[-2].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(input_validity).squeeze()), 0, torch.nonzero(out_validity).squeeze()))
373 | new_stem[-2].bias.copy_(b[out_validity])
374 | input_validity = out_validity.clone().detach()
375 | if verbose:
376 |                     print('Stem '+str(len(new_stem)//2 - 1)+': Pruned '+str(len(out_validity) - out_size) + ' from '+str(len(out_validity)))
377 | self.stem = nn.Sequential(*new_stem)
378 |
379 | new_features = []
380 | skip_validity = 0
381 | for i in range(len(self.features)):
382 | inp = int(sum(input_validity.int()).item())
383 | if self.features[i].isFirst:
384 | skip = 0
385 | a_validity = abs(self.features[i].conv1[1].weight) > gamma_thresh
386 | b_validity = abs(self.features[i].conv2[1].weight) > gamma_thresh
387 | c_validity = abs(self.features[i].conv3[1].weight) > gamma_thresh
388 | d_validity = abs(self.features[i].conv4[1].weight) > gamma_thresh
389 | e_validity = abs(self.features[i].conv5[1].weight) > gamma_thresh
390 | out_validity = abs(self.features[i].conv6[1].weight) > gamma_thresh
391 |
392 | new_features.append(SelecSLSBlockFused(inp, skip, int(sum(a_validity.int()).item()),int(sum(b_validity.int()).item()),int(sum(c_validity.int()).item()),int(sum(d_validity.int()).item()),int(sum(e_validity.int()).item()), int(sum(out_validity.int()).item()), self.features[i].isFirst, self.features[i].stride))
393 |
394 | #Conv1
395 | i_validity = input_validity.clone().detach()
396 | o_validity = a_validity.clone().detach()
397 | W, bias = bn_fuse(self.features[i].conv1[0], self.features[i].conv1[1])
398 | new_features[i].conv1[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
399 | new_features[i].conv1[0].bias.copy_(bias[o_validity])
400 | if verbose:
401 | print('features.'+str(i)+'.conv1: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
402 | #Conv2
403 | i_validity = o_validity.clone().detach()
404 | o_validity = b_validity.clone().detach()
405 | W, bias = bn_fuse(self.features[i].conv2[0], self.features[i].conv2[1])
406 | new_features[i].conv2[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
407 | new_features[i].conv2[0].bias.copy_(bias[o_validity])
408 | if verbose:
409 | print('features.'+str(i)+'.conv2: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
410 | #Conv3
411 | i_validity = o_validity.clone().detach()
412 | o_validity = c_validity.clone().detach()
413 | W, bias = bn_fuse(self.features[i].conv3[0], self.features[i].conv3[1])
414 | new_features[i].conv3[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
415 | new_features[i].conv3[0].bias.copy_(bias[o_validity])
416 | if verbose:
417 | print('features.'+str(i)+'.conv3: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
418 | #Conv4
419 | i_validity = o_validity.clone().detach()
420 | o_validity = d_validity.clone().detach()
421 | W, bias = bn_fuse(self.features[i].conv4[0], self.features[i].conv4[1])
422 | new_features[i].conv4[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
423 | new_features[i].conv4[0].bias.copy_(bias[o_validity])
424 | if verbose:
425 | print('features.'+str(i)+'.conv4: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
426 | #Conv5
427 | i_validity = o_validity.clone().detach()
428 | o_validity = e_validity.clone().detach()
429 | W, bias = bn_fuse(self.features[i].conv5[0], self.features[i].conv5[1])
430 | new_features[i].conv5[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
431 | new_features[i].conv5[0].bias.copy_(bias[o_validity])
432 | if verbose:
433 | print('features.'+str(i)+'.conv5: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
434 | #Conv6
435 | i_validity = torch.cat([a_validity.clone().detach(), c_validity.clone().detach(), e_validity.clone().detach()], 0)
436 | if self.features[i].isFirst:
437 | skip = int(sum(out_validity.int()).item())
438 | skip_validity = out_validity.clone().detach()
439 | else:
440 | i_validity = torch.cat([i_validity, skip_validity], 0)
441 | o_validity = out_validity.clone().detach()
442 | W, bias = bn_fuse(self.features[i].conv6[0], self.features[i].conv6[1])
443 | new_features[i].conv6[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
444 | new_features[i].conv6[0].bias.copy_(bias[o_validity])
445 | if verbose:
446 | print('features.'+str(i)+'.conv6: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
447 |
448 | input_validity = out_validity.clone().detach()
449 | self.features = nn.Sequential(*new_features)
450 |
451 | new_head = []
452 | for i in range(len(self.head)):
453 | input_size = int(sum(input_validity.int()).item())
454 | #Calculate the extent of sparsity
455 | out_validity = abs(self.head[i][1].weight) > gamma_thresh
456 | out_size = int(sum(out_validity.int()).item())
457 | W, b = bn_fuse(self.head[i][0],self.head[i][1])
458 | new_head.append(nn.Conv2d(input_size, out_size, kernel_size = self.head[i][0].kernel_size, stride=self.head[i][0].stride, padding = self.head[i][0].padding))
459 | new_head.append(nn.ReLU(inplace=True))
460 | new_head[-2].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(input_validity).squeeze()), 0, torch.nonzero(out_validity).squeeze()))
461 | new_head[-2].bias.copy_(b[out_validity])
462 | input_validity = out_validity.clone().detach()
463 | if verbose:
464 |                     print('Head '+str(len(new_head)//2 - 1)+': Pruned '+str(len(out_validity) - out_size) + ' from '+str(len(out_validity)))
465 | self.head = nn.Sequential(*new_head)
466 |
467 | new_classifier = []
468 | new_classifier.append(nn.Linear(int(sum(input_validity.int()).item()), self.classifier[0].weight.shape[0]))
469 | new_classifier[0].weight.copy_(torch.index_select(self.classifier[0].weight, 1, torch.nonzero(input_validity).squeeze()))
470 | new_classifier[0].bias.copy_(self.classifier[0].bias)
471 | self.classifier = nn.Sequential(*new_classifier)
472 |
473 | self.to(device)
474 |
475 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mehtadushy/SelecSLS-Pytorch/3852734af392b8fa69834984c76856dbd39179a3/util/__init__.py
--------------------------------------------------------------------------------
/util/imagenet_data_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import time
5 | import math
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torchvision.datasets as dset
12 | import torchvision.transforms as transforms
13 | from torch.utils.data import DataLoader
14 |
15 | def get_data_loader(augment=False, batch_size=50, base_path="path_to_ImageNet"):
16 |
17 | print('Loading ImageNet in all its glory...')
18 | dataset = dset.ImageFolder
19 |
20 | # Prepare transforms and data augmentation
21 | norm_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
22 | train_transform = transforms.Compose([
23 | transforms.RandomResizedCrop(224),
24 | transforms.RandomHorizontalFlip(),
25 | transforms.ToTensor(),
26 | norm_transform
27 | ])
28 | test_transform = transforms.Compose([
29 | transforms.Resize(256),
30 | transforms.CenterCrop(224),
31 | transforms.ToTensor(),
32 | norm_transform
33 | ])
34 | kwargs = {'num_workers': 8, 'pin_memory': True}
35 |
36 | train_set = dataset(
37 | root=base_path+'/train/',
38 | transform=train_transform if augment else test_transform)
39 | test_set = dataset(base_path+'/val/',
40 | transform=test_transform)
41 |
42 | # Prepare data loaders
43 | train_loader = DataLoader(train_set, batch_size=batch_size,
44 | shuffle=True, **kwargs)
45 | test_loader = DataLoader(test_set, batch_size=batch_size,
46 | shuffle=False, **kwargs)
47 |
48 | return train_loader, test_loader
49 |
--------------------------------------------------------------------------------
/weights/readme.txt:
--------------------------------------------------------------------------------
1 | Get pretrained ImageNet models from:
2 |
3 | Old Training Results
4 | SelecSLS60: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict.pth
5 | SelecSLS84: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS84_statedict.pth
6 | Newfangled Training Results
7 | SelecSLS60: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict_better.pth
8 | SelecSLS42_B: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B_statedict.pth
9 | SelecSLS60_B: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B_statedict.pth
10 |
--------------------------------------------------------------------------------