├── .gitignore ├── README.md ├── evaluate_imagenet.py ├── evaluate_timing.py ├── models ├── __init__.py └── selecsls.py ├── util ├── __init__.py └── imagenet_data_loader.py └── weights └── readme.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SelecSLS Convolutional Net Pytorch Implementation 2 | Reference ImageNet implementation of SelecSLS Convolutional Neural Network architecture proposed in [XNect: Real-time Multi-Person 3D Motion Capture 3 | with a Single RGB Camera](http://gvv.mpi-inf.mpg.de/projects/XNect/) (SIGGRAPH 2020). 4 | 5 | The network architecture is 1.3-1.5x faster than ResNet-50, particularly for larger image sizes, with the same level of accuracy on different tasks! 6 | Further, it takes substantially less memory while training, so it can be trained with larger batch sizes! 7 | 8 | ### Update (28 Dec 2019) 9 | Better and more accurate models / snapshots are now available. See the additional ImageNet table below. 10 | 11 | ### Update (14 Oct 2019) 12 | Code for pruning the model based on [Implicit Filter Level Sparsity](http://openaccess.thecvf.com/content_CVPR_2019/html/Mehta_On_Implicit_Filter_Level_Sparsity_in_Convolutional_Neural_Networks_CVPR_2019_paper.html) is also a part of the [SelecSLS model](https://github.com/mehtadushy/SelecSLS-Pytorch/blob/master/models/selecsls.py#L280) now. The sparsity is a natural consequence of training with adaptive gradient descent approaches and L2 regularization. It gives a further speedup of **10-30%** on the pretrained models with no loss in accuracy. See usage and results below. 13 | 14 | ## ImageNet results 15 | 16 | The inference time for the models in the table below is measured on a TITAN X GPU using the accompanying scripts. The accuracy results for ResNet-50 are from torchvision, and the accuracy results for VoVNet-39 are from [VoVNet](https://github.com/stigma0617/VoVNet.pytorch). 
Forward pass time (ms) for different image resolutions and batch sizes, and ImageNet error (%):

| Model | 512x512, batch 1 | 512x512, batch 16 | 400x400, batch 1 | 400x400, batch 16 | 224x224, batch 1 | 224x224, batch 16 | Top-1 error | Top-5 error |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet-50 | 15.0 | 175.0 | 11.0 | 114.0 | 7.2 | 39.0 | 23.9 | 7.1 |
| VoVNet-39 | 13.0 | 197.0 | 10.8 | 130.0 | 6 | 41.0 | 23.2 | 6.6 |
| SelecSLS-60 | 11.0 | 115.0 | 9.5 | 85.0 | 7.3 | 29.0 | 23.8 | 7.0 |
| SelecSLS-60 (P) | 10.2 | 102.0 | 8.2 | 71.0 | 6.1 | 25.0 | 23.8 | 7.0 |
| SelecSLS-84 | 16.1 | 175.0 | 13.7 | 124.0 | 9.9 | 42.3 | 23.3 | 6.9 |
| SelecSLS-84 (P) | 11.9 | 119.0 | 10.1 | 82.0 | 7.6 | 28.6 | 23.3 | 6.9 |

* (P) indicates that the model has batch norm fusion and pruning applied
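
As a rough check of the speedup figures quoted in this README, the ratios can be recomputed from the table above. The sketch below only reuses the forward pass times listed for ResNet-50, SelecSLS-60 and SelecSLS-60 (P); the dictionary layout and the `speedups` helper are illustrative names and not part of the repository.

```python
# Forward pass times in ms, copied from the table above.
# Column order: 512x512 (batch 1, batch 16), 400x400 (batch 1, batch 16),
# 224x224 (batch 1, batch 16).
TIMES_MS = {
    "ResNet-50":       [15.0, 175.0, 11.0, 114.0, 7.2, 39.0],
    "SelecSLS-60":     [11.0, 115.0, 9.5, 85.0, 7.3, 29.0],
    "SelecSLS-60 (P)": [10.2, 102.0, 8.2, 71.0, 6.1, 25.0],
}


def speedups(baseline, candidate):
    """Return baseline_time / candidate_time for every column, rounded to 2 decimals."""
    return [round(b / c, 2) for b, c in zip(TIMES_MS[baseline], TIMES_MS[candidate])]


# Values above 1.0 mean the candidate is faster than the baseline.
resnet_vs_selecsls = speedups("ResNet-50", "SelecSLS-60")
pruning_gain = speedups("SelecSLS-60", "SelecSLS-60 (P)")

print("ResNet-50 / SelecSLS-60:       ", resnet_vs_selecsls)
print("SelecSLS-60 / SelecSLS-60 (P): ", pruning_gain)
```

With these numbers, SelecSLS-60 comes out at roughly 1.36x (batch size 1) and 1.52x (batch size 16) faster than ResNet-50 at 512x512, while at 224x224 with batch size 1 the two are about on par.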
The following models are trained using Cosine LR, Random Erasing, EMA, *Bicubic* Interpolation, and Color Jitter using [rwightman/pytorch-image-models](https://github.com/rwightman/pytorch-image-models). The inference time for models here is measured on a TITAN Xp GPU using the accompanying scripts. The script for evaluating ImageNet performance uses *Bilinear* interpolation, hence the results reported here are marginally worse than they would be with Bicubic interpolation at inference.

Forward pass time (ms) for different image resolutions and batch sizes, and ImageNet error (%):

| Model | 512x512, batch 1 | 512x512, batch 16 | 400x400, batch 1 | 400x400, batch 16 | 224x224, batch 1 | 224x224, batch 16 | Top-1 error | Top-5 error |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SelecSLS-42_B | 6.4 | 60.8 | 5.8 | 42.1 | 5.7 | 14.7 | 22.9 | 6.6 |
| SelecSLS-60 | 7.4 | 69.4 | 7.3 | 47.6 | 7.1 | 16.8 | 22.1 | 6.1 |
| SelecSLS-60_B | 7.5 | 70.5 | 7.3 | 49.3 | 7.2 | 17.0 | 21.6 | 5.8 |
174 | 175 | 176 | 177 | # SelecSLS (Selective Short and Long Range Skip Connections) 178 | The key feature of the proposed architecture is that unlike the full dense connectivity in DenseNets, SelecSLS uses a much sparser skip connectivity pattern that uses both long and short-range concatenative-skip connections. Additionally, the network architecture is more amenable to [filter/channel pruning](http://openaccess.thecvf.com/content_CVPR_2019/html/Mehta_On_Implicit_Filter_Level_Sparsity_in_Convolutional_Neural_Networks_CVPR_2019_paper.html) than ResNets. 179 | You can find more details about the architecture in the following [paper](https://arxiv.org/abs/1907.00837), and details about implicit pruning in the [CVPR 2019 paper](http://openaccess.thecvf.com/content_CVPR_2019/html/Mehta_On_Implicit_Filter_Level_Sparsity_in_Convolutional_Neural_Networks_CVPR_2019_paper.html). 180 | 181 | Another recent paper proposed the VoVNet architecture, which shares some design similarities with our architecture. However, as shown in the above table, our architecture is significantly faster than both VoVNet-39 and ResNet-50 for larger batch sizes as well as larger image sizes. 182 | 183 | ## Usage 184 | This repo provides the model definition in Pytorch, trained weights for ImageNet, and code for evaluating the forward pass time 185 | and the accuracy of the trained model on ImageNet validation set. 186 | In the paper, the model has been used for the task of human pose estimation, and can also be applied to a myriad of other problems as a drop in replacement for ResNet-50. 
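
For readers unfamiliar with the batch norm fusion and gamma-threshold pruning mentioned in this README (the `(P)` rows and the `--pruned_and_fused` / `--gamma_thresh` options), the following is a minimal, framework-free sketch of the arithmetic for a single output channel. It is an illustration under simplifying assumptions (plain Python lists instead of tensors, hypothetical helper names `fuse_bn_channel` and `keep_mask`), not the repository's `bn_fuse` or `prune_and_fuse` code.

```python
import math


def fuse_bn_channel(w, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold one BatchNorm channel into the conv channel that precedes it.

    w is the flattened list of conv weights for this output channel.
    The returned (fused_w, fused_bias) satisfy
        fused_w . x + fused_bias == gamma * ((w . x + bias) - mean) / sqrt(var + eps) + beta
    """
    scale = gamma / math.sqrt(var + eps)
    fused_w = [wi * scale for wi in w]
    fused_bias = (bias - mean) * scale + beta
    return fused_w, fused_bias


def keep_mask(gammas, gamma_thresh):
    """Keep a channel only if the magnitude of its BN scale exceeds the threshold."""
    return [abs(g) > gamma_thresh for g in gammas]


# One conv output channel followed by BatchNorm, evaluated on a toy input.
w, bias = [0.5, -1.0], 0.0
gamma, beta, mean, var = 2.0, 0.1, 0.3, 4.0
x = [1.0, 2.0]

conv_out = sum(wi * xi for wi, xi in zip(w, x)) + bias
bn_out = gamma * (conv_out - mean) / math.sqrt(var + 1e-5) + beta

fused_w, fused_bias = fuse_bn_channel(w, bias, gamma, beta, mean, var)
fused_out = sum(wi * xi for wi, xi in zip(fused_w, x)) + fused_bias

# Channels with near-zero gamma are candidates for removal.
mask = keep_mask([0.5, 1e-5, -0.2, 0.0], gamma_thresh=0.001)
```

A negative threshold keeps every channel in this sketch, which is consistent with the scripts' hint that `--gamma_thresh -1` only fuses batch norm without pruning.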
187 | 188 | ``` 189 | wget http://gvv.mpi-inf.mpg.de/projects/XNectDemoV2/content/SelecSLS60_statedict.pth -O ./weights/SelecSLS60_statedict.pth 190 | python evaluate_timing.py --num_iter 100 --model_class selecsls --model_config SelecSLS60 --model_weights ./weights/SelecSLS60_statedict.pth --input_size 512 --gpu_id 191 | python evaluate_imagenet.py --model_class selecsls --model_config SelecSLS60 --model_weights ./weights/SelecSLS60_statedict.pth --gpu_id --imagenet_base_path 192 | 193 | #For pruning the model, and evaluating the pruned model (Using SelecSLS60 or other pretrained models) 194 | python evaluate_timing.py --num_iter 100 --model_class selecsls --model_config SelecSLS84 --model_weights ./weights/SelecSLS84_statedict.pth --input_size 512 --pruned_and_fused True --gamma_thresh 0.001 --gpu_id 195 | python evaluate_imagenet.py --model_class selecsls --model_config SelecSLS84 --model_weights ./weights/SelecSLS84_statedict.pth --pruned_and_fused True --gamma_thresh 0.001 --gpu_id --imagenet_base_path 196 | ``` 197 | 198 | ## Older Pretrained Models 199 | - [SelecSLS-60](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict.pth) 200 | - [SelecSLS-84](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS84_statedict.pth) 201 | 202 | ## Newer Pretrained Models (More Accurate) 203 | - [SelecSLS-42_B](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B_statedict.pth) 204 | - [SelecSLS-60](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict_better.pth) 205 | - [SelecSLS-60_B](http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B_statedict.pth) 206 | 207 | ## Requirements 208 | - Python 3.5 209 | - Pytorch >= 1.1 210 | 211 | ## License 212 | The contents of this repository, and the pretrained models are made available under CC BY 4.0. Please read the [license terms](https://creativecommons.org/licenses/by/4.0/legalcode). 
213 | 214 | ### Citing 215 | If you use the model or the implicit-sparsity-based pruning in your work, please cite: 216 | 217 | ``` 218 | @inproceedings{XNect_SIGGRAPH2020, 219 | author = {Mehta, Dushyant and Sotnychenko, Oleksandr and Mueller, Franziska and Xu, Weipeng and Elgharib, Mohamed and Fua, Pascal and Seidel, Hans-Peter and Rhodin, Helge and Pons-Moll, Gerard and Theobalt, Christian}, 220 | title = {{XNect}: Real-time Multi-Person {3D} Motion Capture with a Single {RGB} Camera}, 221 | journal = {ACM Transactions on Graphics}, 222 | url = {http://gvv.mpi-inf.mpg.de/projects/XNect/}, 223 | numpages = {17}, 224 | volume={39}, 225 | number={4}, 226 | month = July, 227 | year = {2020}, 228 | doi={10.1145/3386569.3392410} 229 | } 230 | 231 | @InProceedings{Mehta_2019_CVPR, 232 | author = {Mehta, Dushyant and Kim, Kwang In and Theobalt, Christian}, 233 | title = {On Implicit Filter Level Sparsity in Convolutional Neural Networks}, 234 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 235 | month = {June}, 236 | year = {2019} 237 | } 238 | ``` 239 | 240 | 241 | 242 | -------------------------------------------------------------------------------- /evaluate_imagenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Script for evaluating accuracy on Imagenet Validation Set. 
5 | ''' 6 | import os 7 | import logging 8 | import sys 9 | import time 10 | from argparse import ArgumentParser 11 | import importlib 12 | 13 | import numpy as np 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | torch.backends.cudnn.benchmark = True 19 | from util.imagenet_data_loader import get_data_loader 20 | 21 | 22 | 23 | def opts_parser(): 24 | usage = 'Configure the dataset using imagenet_data_loader' 25 | parser = ArgumentParser(description=usage) 26 | parser.add_argument( 27 | '--model_class', type=str, default='selecsls', metavar='FILE', 28 | help='Select model type to use (DenseNet, SelecSLS, ResNet etc.)') 29 | parser.add_argument( 30 | '--model_config', type=str, default='SelecSLS60', metavar='NET_CONFIG', 31 | help='Select the model configuration') 32 | parser.add_argument( 33 | '--model_weights', type=str, default='./weights/SelecSLS60_statedict.pth', metavar='FILE', 34 | help='Path to model weights') 35 | parser.add_argument( 36 | '--imagenet_base_path', type=str, default='', metavar='FILE', 37 | help='Path to ImageNet dataset') 38 | parser.add_argument( 39 | '--gpu_id', type=int, default=0, 40 | help='Which GPU to use.') 41 | parser.add_argument( 42 | '--simulate_pruning', type=bool, default=False, 43 | help='Whether to zero out features with gamma below a certain threshold') 44 | parser.add_argument( 45 | '--pruned_and_fused', type=bool, default=False, 46 | help='Whether to prune based on gamma below a certain threshold and fuse BN') 47 | parser.add_argument( 48 | '--gamma_thresh', type=float, default=1e-4, 49 | help='gamma threshold to use for simulating pruning') 50 | return parser 51 | 52 | 53 | def accuracy(output, target, topk=(1,)): 54 | """Computes the precision@k for the specified values of k""" 55 | maxk = max(topk) 56 | batch_size = target.size(0) 57 | 58 | _, pred = output.topk(maxk, 1, True, True) 59 | pred = pred.t() 60 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 61 | 62 | res = 
[] 63 | for k in topk: 64 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 65 | res.append(correct_k.mul_(100.0 / batch_size)) 66 | return res 67 | 68 | 69 | def evaluate_imagenet_validation_accuracy(model_class, model_config, model_weights, imagenet_base_path, gpu_id, simulate_pruning, pruned_and_fused, gamma_thresh): 70 | model_module = importlib.import_module('models.'+model_class) 71 | net = model_module.Net(nClasses=1000, config=model_config) 72 | net.load_state_dict(torch.load(model_weights, map_location= lambda storage, loc: storage)) 73 | 74 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 75 | net = net.to(device) 76 | if pruned_and_fused: 77 | print('Fusing BN and pruning channels based on gamma ' + str(gamma_thresh)) 78 | net.prune_and_fuse(gamma_thresh) 79 | 80 | if simulate_pruning: 81 | print('Simulating pruning by zeroing all features with gamma less than '+str(gamma_thresh)) 82 | with torch.no_grad(): 83 | for n, m in net.named_modules(): 84 | if isinstance(m, nn.BatchNorm2d): 85 | m.weight[abs(m.weight) < gamma_thresh] = 0 86 | m.bias[abs(m.weight) < gamma_thresh] = 0 87 | 88 | net.eval() 89 | _,test_loader = get_data_loader(augment=False, batch_size=100, base_path=imagenet_base_path) 90 | with torch.no_grad(): 91 | val1_err = [] 92 | val5_err = [] 93 | for x, y in test_loader: 94 | pred = F.log_softmax(net(x.to(device)), dim=1) 95 | top1, top5 = accuracy(pred, y.to(device), topk=(1, 5)) 96 | val1_err.append(100-top1.item()) 97 | val5_err.append(100-top5.item()) 98 | avg1_err= float(np.sum(val1_err)) / len(val1_err) 99 | avg5_err= float(np.sum(val5_err)) / len(val5_err) 100 | print('Top-1 Error: {} Top-5 Error {}'.format(avg1_err, avg5_err)) 101 | 102 | 103 | def main(): 104 | # parse command line 105 | torch.manual_seed(1234) 106 | parser = opts_parser() 107 | args = parser.parse_args() 108 | 109 | # run 110 | evaluate_imagenet_validation_accuracy(**vars(args)) 111 | 112 | if __name__ == '__main__': 113 | main() 114 | 
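
A note on the boolean options declared with `type=bool` in `opts_parser` (`--simulate_pruning`, `--pruned_and_fused`): `argparse` applies `bool()` to the raw string, and any non-empty string is truthy, so passing `False` on the command line still enables the option. The sketch below demonstrates the behaviour with the standard library only; `str2bool` is a hypothetical helper shown as one possible workaround and is not part of the repository.

```python
from argparse import ArgumentParser, ArgumentTypeError


def str2bool(value):
    """Hypothetical helper: parse common spellings of true/false."""
    v = value.strip().lower()
    if v in ("true", "1", "yes"):
        return True
    if v in ("false", "0", "no"):
        return False
    raise ArgumentTypeError("expected a boolean, got {!r}".format(value))


# Same declaration style as in opts_parser.
naive = ArgumentParser()
naive.add_argument('--pruned_and_fused', type=bool, default=False)

# Alternative declaration that parses the string explicitly.
strict = ArgumentParser()
strict.add_argument('--pruned_and_fused', type=str2bool, default=False)

argv = ['--pruned_and_fused', 'False']
naive_flag = naive.parse_args(argv).pruned_and_fused    # bool('False') is truthy
strict_flag = strict.parse_args(argv).pruned_and_fused  # parsed as a real False

print(naive_flag, strict_flag)
```

With the scripts as written, the safest way to leave these options off is to omit them entirely so that the `default=False` applies.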
-------------------------------------------------------------------------------- /evaluate_timing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Script for timing models in eval mode and torchscript eval modes. 5 | ''' 6 | 7 | import os 8 | import logging 9 | import sys 10 | import time 11 | from argparse import ArgumentParser 12 | import importlib 13 | 14 | import numpy as np 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | torch.backends.cudnn.benchmark = True 20 | 21 | 22 | 23 | def opts_parser(): 24 | usage = 'Pass the model and' 25 | parser = ArgumentParser(description=usage) 26 | parser.add_argument( 27 | '--num_iter', type=int, default=50, 28 | help='Number of iterations to average over.') 29 | parser.add_argument( 30 | '--model_class', type=str, default='selecsls', metavar='FILE', 31 | help='Select model type to use (DenseNet, SelecSLS, ResNet etc.)') 32 | parser.add_argument( 33 | '--model_config', type=str, default='SelecSLS84', metavar='NET_CONFIG', 34 | help='Select the model configuration') 35 | parser.add_argument( 36 | '--model_weights', type=str, default='./weights/SelecSLS84_statedict.pth', metavar='FILE', 37 | help='Path to model weights') 38 | parser.add_argument( 39 | '--input_size', type=int, default=400, 40 | help='Input image size.') 41 | parser.add_argument( 42 | '--gpu_id', type=int, default=0, 43 | help='Which GPU to use.') 44 | parser.add_argument( 45 | '--pruned_and_fused', type=bool, default=False, 46 | help='Whether to zero out features with gamma below a certain threshold') 47 | parser.add_argument( 48 | '--gamma_thresh', type=float, default=1e-3, 49 | help='gamma threshold to use for simulating pruning. 
Set this to -1 to only fuse BN without pruning') 50 | return parser 51 | 52 | 53 | def measure_cpu(model, x): 54 | # synchronize gpu time and measure fp 55 | model.eval() 56 | with torch.no_grad(): 57 | t0 = time.time() 58 | y_pred = model(x) 59 | elapsed_fp_nograd = time.time()-t0 60 | return elapsed_fp_nograd 61 | 62 | def measure_gpu(model, x): 63 | # synchronize gpu time and measure fp 64 | model.eval() 65 | with torch.no_grad(): 66 | torch.cuda.synchronize() 67 | t0 = time.time() 68 | y_pred = model(x) 69 | torch.cuda.synchronize() 70 | elapsed_fp_nograd = time.time()-t0 71 | return elapsed_fp_nograd 72 | 73 | 74 | def benchmark(model_class, model_config, gpu_id, num_iter, model_weights, input_size, pruned_and_fused, gamma_thresh): 75 | # Import the model module 76 | model_module = importlib.import_module('models.'+model_class) 77 | net = model_module.Net(nClasses=1000, config=model_config) 78 | net.load_state_dict(torch.load(model_weights, map_location= lambda storage, loc: storage)) 79 | 80 | if pruned_and_fused: 81 | print('Pruning and fusing the model') 82 | net.prune_and_fuse(gamma_thresh, True) 83 | 84 | device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() else "cpu") 85 | net = net.to(device) 86 | print('\nEvaluating on GPU {}'.format(device)) 87 | 88 | print('\nGPU, Batch Size: 1') 89 | x = torch.randn(1, 3, input_size, input_size) 90 | #Warm up 91 | for i in range(10): 92 | _ = measure_gpu(net, x.to(device)) 93 | fp = [] 94 | for i in range(num_iter): 95 | t = measure_gpu(net, x.to(device)) 96 | fp.append(t) 97 | print('Model FP: '+str(np.mean(np.asarray(fp)*1000))+'ms') 98 | 99 | jit_net = torch.jit.trace(net, x.to(device)) 100 | for i in range(10): 101 | _ = measure_gpu(jit_net, x.to(device)) 102 | fp = [] 103 | for i in range(num_iter): 104 | t = measure_gpu(jit_net, x.to(device)) 105 | fp.append(t) 106 | print('JIT FP: '+str(np.mean(np.asarray(fp)*1000))+'ms') 107 | 108 | 109 | print('\nGPU, Batch Size: 16') 110 | x = 
torch.randn(16, 3, input_size, input_size) 111 | #Warm up 112 | for i in range(10): 113 | _ = measure_gpu(net, x.to(device)) 114 | fp = [] 115 | for i in range(num_iter): 116 | t = measure_gpu(net, x.to(device)) 117 | fp.append(t) 118 | print('Model FP: '+str(np.mean(np.asarray(fp)*1000))+'ms') 119 | 120 | jit_net = torch.jit.trace(net, x.to(device)) 121 | for i in range(10): 122 | _ = measure_gpu(jit_net, x.to(device)) 123 | fp = [] 124 | for i in range(num_iter): 125 | t = measure_gpu(jit_net, x.to(device)) 126 | fp.append(t) 127 | print('JIT FP: '+str(np.mean(np.asarray(fp)*1000))+'ms') 128 | 129 | device = torch.device("cpu") 130 | print('\nEvaluating on {}'.format(device)) 131 | net = net.to(device) 132 | 133 | print('\nCPU, Batch Size: 1') 134 | x = torch.randn(1, 3, input_size, input_size) 135 | #Warm up 136 | for i in range(10): 137 | _ = measure_cpu(net, x.to(device)) 138 | fp = [] 139 | for i in range(num_iter): 140 | t = measure_cpu(net, x.to(device)) 141 | fp.append(t) 142 | print('Model FP: '+str(np.mean(np.asarray(fp)*1000))+'ms') 143 | 144 | jit_net = torch.jit.trace(net, x.to(device)) 145 | for i in range(10): 146 | _ = measure_cpu(jit_net, x.to(device)) 147 | fp = [] 148 | for i in range(num_iter): 149 | t = measure_cpu(jit_net, x.to(device)) 150 | fp.append(t) 151 | print('JIT FP: '+str(np.mean(np.asarray(fp)*1000))+'ms') 152 | 153 | 154 | 155 | def main(): 156 | # parse command line 157 | torch.manual_seed(1234) 158 | parser = opts_parser() 159 | args = parser.parse_args() 160 | 161 | # run 162 | benchmark(**vars(args)) 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mehtadushy/SelecSLS-Pytorch/3852734af392b8fa69834984c76856dbd39179a3/models/__init__.py 
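
The measurement pattern used by `benchmark` in `evaluate_timing.py` (a few warm-up passes, then the mean over `num_iter` timed passes, with device synchronization around the timed call on GPU) can be summarized in a framework-free sketch. The helper name `mean_forward_ms` and its `sync` parameter are assumptions made for illustration; the sketch also uses `time.perf_counter` rather than the script's `time.time`.

```python
import time


def mean_forward_ms(forward, x, num_iter=50, warmup=10, sync=None):
    """Average wall-clock time of forward(x) in milliseconds after warm-up.

    sync is an optional callable invoked right before the clock starts and
    right after the call returns; on a GPU this role is played by
    torch.cuda.synchronize in measure_gpu.
    """
    # Warm-up passes are executed but not timed.
    for _ in range(warmup):
        forward(x)

    elapsed = []
    for _ in range(num_iter):
        if sync is not None:
            sync()
        t0 = time.perf_counter()
        forward(x)
        if sync is not None:
            sync()
        elapsed.append(time.perf_counter() - t0)
    return 1000.0 * sum(elapsed) / len(elapsed)


# Toy workload standing in for a model forward pass.
def toy_forward(n):
    return sum(i * i for i in range(n))


print('Toy FP: {:.3f}ms'.format(mean_forward_ms(toy_forward, 10000, num_iter=20)))
```

Without the synchronization step, asynchronous GPU kernels could still be running when the clock stops, which is why the script uses separate `measure_gpu` and `measure_cpu` functions.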
-------------------------------------------------------------------------------- /models/selecsls.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Pytorch implementation of SelecSLS Network architecture as described in 3 | 'XNect: Real-time Multi-person 3D Human Pose Estimation with a Single RGB Camera, Mehta et al. 2019'. 4 | The network architecture performs comparable to ResNet-50 while being 1.4-1.8x faster, 5 | particularly with larger image sizes. The network architecture has a much smaller memory 6 | footprint, and can be used as a drop in replacement for ResNet-50 in various tasks. 7 | This Pytorch implementation establishes an official baseline of the model on ImageNet 8 | 9 | This model also provides functionality to prune channels based on implicit sparsity, as 10 | described in 'On Implicit Filter Level Sparsity in Convolutional Neural Networks, Mehta et al. CVPR 2019'. 11 | This gives a 10-15% speedup depending on the model used. 12 | 13 | Author: Dushyant Mehta (dmehta[at]mpi-inf.mpg.de) 14 | 15 | This code is made available under CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/legalcode) 16 | ''' 17 | from __future__ import absolute_import 18 | import torch 19 | import torch.nn as nn 20 | import torch.optim as optim 21 | import torch.nn.functional as F 22 | import math 23 | import fractions 24 | 25 | 26 | def conv_bn(inp, oup, stride): 27 | return nn.Sequential( 28 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 29 | nn.BatchNorm2d(oup), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | def conv_1x1_bn(inp, oup): 34 | return nn.Sequential( 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.ReLU(inplace=True) 38 | ) 39 | 40 | def bn_fuse(c, b): 41 | ''' BN fusion code adapted from my Caffe BN fusion code and code from @MIPT-Oulu. 
This function assumes everything is on the cpu''' 42 | with torch.no_grad(): 43 | # BatchNorm params 44 | eps = b.eps 45 | mu = b.running_mean 46 | var = b.running_var 47 | gamma = b.weight 48 | 49 | if 'bias' in b.state_dict(): 50 | beta = b.bias 51 | else: 52 | #beta = torch.zeros(gamma.size(0)).float().to(gamma.device) 53 | beta = torch.zeros(gamma.size(0)).float() 54 | 55 | # Conv params 56 | W = c.weight 57 | if 'bias' in c.state_dict(): 58 | bias = c.bias 59 | else: 60 | bias = torch.zeros(W.size(0)).float() 61 | 62 | denom = torch.sqrt(var + eps) 63 | b = beta - gamma.mul(mu).div(denom) 64 | A = gamma.div(denom) 65 | bias *= A 66 | A = A.expand_as(W.transpose(0, -1)).transpose(0, -1) 67 | 68 | W.mul_(A) 69 | bias.add_(b) 70 | 71 | return W.clone().detach(), bias.clone().detach() 72 | 73 | class SelecSLSBlock(nn.Module): 74 | def __init__(self, inp, skip, k, oup, isFirst, stride): 75 | super(SelecSLSBlock, self).__init__() 76 | self.stride = stride 77 | self.isFirst = isFirst 78 | assert stride in [1, 2] 79 | 80 | #Process input with 4 conv blocks with the same number of input and output channels 81 | self.conv1 = nn.Sequential( 82 | nn.Conv2d(inp, k, 3, stride, 1,groups= 1, bias=False, dilation=1), 83 | nn.BatchNorm2d(k), 84 | nn.ReLU(inplace=True) 85 | ) 86 | self.conv2 = nn.Sequential( 87 | nn.Conv2d(k, k, 1, 1, 0,groups= 1, bias=False, dilation=1), 88 | nn.BatchNorm2d(k), 89 | nn.ReLU(inplace=True) 90 | ) 91 | self.conv3 = nn.Sequential( 92 | nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1), 93 | nn.BatchNorm2d(k//2), 94 | nn.ReLU(inplace=True) 95 | ) 96 | self.conv4 = nn.Sequential( 97 | nn.Conv2d(k//2, k, 1, 1, 0,groups= 1, bias=False, dilation=1), 98 | nn.BatchNorm2d(k), 99 | nn.ReLU(inplace=True) 100 | ) 101 | self.conv5 = nn.Sequential( 102 | nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1), 103 | nn.BatchNorm2d(k//2), 104 | nn.ReLU(inplace=True) 105 | ) 106 | self.conv6 = nn.Sequential( 107 | nn.Conv2d(2*k + (0 if isFirst 
else skip), oup, 1, 1, 0,groups= 1, bias=False, dilation=1), 108 | nn.BatchNorm2d(oup), 109 | nn.ReLU(inplace=True) 110 | ) 111 | 112 | def forward(self, x): 113 | assert isinstance(x,list) 114 | assert len(x) in [1,2] 115 | 116 | d1 = self.conv1(x[0]) 117 | d2 = self.conv3(self.conv2(d1)) 118 | d3 = self.conv5(self.conv4(d2)) 119 | if self.isFirst: 120 | out = self.conv6(torch.cat([d1, d2, d3], 1)) 121 | return [out, out] 122 | else: 123 | return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]] 124 | 125 | class SelecSLSBlockFused(nn.Module): 126 | def __init__(self, inp, skip, a,b,c,d,e, oup, isFirst, stride): 127 | super(SelecSLSBlockFused, self).__init__() 128 | self.stride = stride 129 | self.isFirst = isFirst 130 | assert stride in [1, 2] 131 | 132 | #Process input with 4 conv blocks with the same number of input and output channels 133 | self.conv1 = nn.Sequential( 134 | nn.Conv2d(inp, a, 3, stride, 1,groups= 1, bias=True, dilation=1), 135 | nn.ReLU(inplace=True) 136 | ) 137 | self.conv2 = nn.Sequential( 138 | nn.Conv2d(a, b, 1, 1, 0,groups= 1, bias=True, dilation=1), 139 | nn.ReLU(inplace=True) 140 | ) 141 | self.conv3 = nn.Sequential( 142 | nn.Conv2d(b, c, 3, 1, 1,groups= 1, bias=True, dilation=1), 143 | nn.ReLU(inplace=True) 144 | ) 145 | self.conv4 = nn.Sequential( 146 | nn.Conv2d(c, d, 1, 1, 0,groups= 1, bias=True, dilation=1), 147 | nn.ReLU(inplace=True) 148 | ) 149 | self.conv5 = nn.Sequential( 150 | nn.Conv2d(d, e, 3, 1, 1,groups= 1, bias=True, dilation=1), 151 | nn.ReLU(inplace=True) 152 | ) 153 | self.conv6 = nn.Sequential( 154 | nn.Conv2d(a+c+e + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=True, dilation=1), 155 | nn.ReLU(inplace=True) 156 | ) 157 | 158 | def forward(self, x): 159 | assert isinstance(x,list) 160 | assert len(x) in [1,2] 161 | 162 | d1 = self.conv1(x[0]) 163 | d2 = self.conv3(self.conv2(d1)) 164 | d3 = self.conv5(self.conv4(d2)) 165 | if self.isFirst: 166 | out = self.conv6(torch.cat([d1, d2, d3], 1)) 167 | return 
[out, out] 168 | else: 169 | return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]] 170 | 171 | class Net(nn.Module): 172 | def __init__(self, nClasses=1000, config='SelecSLS60'): 173 | super(Net, self).__init__() 174 | 175 | #Stem 176 | self.stem = conv_bn(3, 32, 2) 177 | 178 | #Core Network 179 | self.features = [] 180 | if config=='SelecSLS42': 181 | print('SelecSLS42') 182 | #Define configuration of the network after the initial neck 183 | self.selecSLS_config = [ 184 | #inp,skip, k, oup, isFirst, stride 185 | [ 32, 0, 64, 64, True, 2], 186 | [ 64, 64, 64, 128, False, 1], 187 | [128, 0, 144, 144, True, 2], 188 | [144, 144, 144, 288, False, 1], 189 | [288, 0, 304, 304, True, 2], 190 | [304, 304, 304, 480, False, 1], 191 | ] 192 | #Head can be replaced with alternative configurations depending on the problem 193 | self.head = nn.Sequential( 194 | conv_bn(480, 960, 2), 195 | conv_bn(960, 1024, 1), 196 | conv_bn(1024, 1024, 2), 197 | conv_1x1_bn(1024, 1280), 198 | ) 199 | self.num_features = 1280 200 | elif config=='SelecSLS42_B': 201 | print('SelecSLS42_B') 202 | #Define configuration of the network after the initial neck 203 | self.selecSLS_config = [ 204 | #inp,skip, k, oup, isFirst, stride 205 | [ 32, 0, 64, 64, True, 2], 206 | [ 64, 64, 64, 128, False, 1], 207 | [128, 0, 144, 144, True, 2], 208 | [144, 144, 144, 288, False, 1], 209 | [288, 0, 304, 304, True, 2], 210 | [304, 304, 304, 480, False, 1], 211 | ] 212 | #Head can be replaced with alternative configurations depending on the problem 213 | self.head = nn.Sequential( 214 | conv_bn(480, 960, 2), 215 | conv_bn(960, 1024, 1), 216 | conv_bn(1024, 1280, 2), 217 | conv_1x1_bn(1280, 1024), 218 | ) 219 | self.num_features = 1024 220 | elif config=='SelecSLS60': 221 | print('SelecSLS60') 222 | #Define configuration of the network after the initial neck 223 | self.selecSLS_config = [ 224 | #inp,skip, k, oup, isFirst, stride 225 | [ 32, 0, 64, 64, True, 2], 226 | [ 64, 64, 64, 128, False, 1], 227 | [128, 0, 
128, 128, True, 2], 228 | [128, 128, 128, 128, False, 1], 229 | [128, 128, 128, 288, False, 1], 230 | [288, 0, 288, 288, True, 2], 231 | [288, 288, 288, 288, False, 1], 232 | [288, 288, 288, 288, False, 1], 233 | [288, 288, 288, 416, False, 1], 234 | ] 235 | #Head can be replaced with alternative configurations depending on the problem 236 | self.head = nn.Sequential( 237 | conv_bn(416, 756, 2), 238 | conv_bn(756, 1024, 1), 239 | conv_bn(1024, 1024, 2), 240 | conv_1x1_bn(1024, 1280), 241 | ) 242 | self.num_features = 1280 243 | elif config=='SelecSLS60_B': 244 | print('SelecSLS60_B') 245 | #Define configuration of the network after the initial neck 246 | self.selecSLS_config = [ 247 | #inp,skip, k, oup, isFirst, stride 248 | [ 32, 0, 64, 64, True, 2], 249 | [ 64, 64, 64, 128, False, 1], 250 | [128, 0, 128, 128, True, 2], 251 | [128, 128, 128, 128, False, 1], 252 | [128, 128, 128, 288, False, 1], 253 | [288, 0, 288, 288, True, 2], 254 | [288, 288, 288, 288, False, 1], 255 | [288, 288, 288, 288, False, 1], 256 | [288, 288, 288, 416, False, 1], 257 | ] 258 | #Head can be replaced with alternative configurations depending on the problem 259 | self.head = nn.Sequential( 260 | conv_bn(416, 756, 2), 261 | conv_bn(756, 1024, 1), 262 | conv_bn(1024, 1280, 2), 263 | conv_1x1_bn(1280, 1024), 264 | ) 265 | self.num_features = 1024 266 | elif config=='SelecSLS84': 267 | print('SelecSLS84') 268 | #Define configuration of the network after the initial neck 269 | self.selecSLS_config = [ 270 | #inp,skip, k, oup, isFirst, stride 271 | [ 32, 0, 64, 64, True, 2], 272 | [ 64, 64, 64, 144, False, 1], 273 | [144, 0, 144, 144, True, 2], 274 | [144, 144, 144, 144, False, 1], 275 | [144, 144, 144, 144, False, 1], 276 | [144, 144, 144, 144, False, 1], 277 | [144, 144, 144, 304, False, 1], 278 | [304, 0, 304, 304, True, 2], 279 | [304, 304, 304, 304, False, 1], 280 | [304, 304, 304, 304, False, 1], 281 | [304, 304, 304, 304, False, 1], 282 | [304, 304, 304, 304, False, 1], 283 | [304, 304, 
304, 512, False, 1], 284 | ] 285 | #Head can be replaced with alternative configurations depending on the problem 286 | self.head = nn.Sequential( 287 | conv_bn(512, 960, 2), 288 | conv_bn(960, 1024, 1), 289 | conv_bn(1024, 1024, 2), 290 | conv_1x1_bn(1024, 1280), 291 | ) 292 | self.num_features = 1280 293 | elif config=='SelecSLS102': 294 | print('SelecSLS102') 295 | #Define configuration of the network after the initial neck 296 | self.selecSLS_config = [ 297 | #inp,skip, k, oup, isFirst, stride 298 | [ 32, 0, 64, 64, True, 2], 299 | [ 64, 64, 64, 64, False, 1], 300 | [ 64, 64, 64, 64, False, 1], 301 | [ 64, 64, 64, 128, False, 1], 302 | [128, 0, 128, 128, True, 2], 303 | [128, 128, 128, 128, False, 1], 304 | [128, 128, 128, 128, False, 1], 305 | [128, 128, 128, 128, False, 1], 306 | [128, 128, 128, 288, False, 1], 307 | [288, 0, 288, 288, True, 2], 308 | [288, 288, 288, 288, False, 1], 309 | [288, 288, 288, 288, False, 1], 310 | [288, 288, 288, 288, False, 1], 311 | [288, 288, 288, 288, False, 1], 312 | [288, 288, 288, 288, False, 1], 313 | [288, 288, 288, 480, False, 1], 314 | ] 315 | #Head can be replaced with alternative configurations depending on the problem 316 | self.head = nn.Sequential( 317 | conv_bn(480, 960, 2), 318 | conv_bn(960, 1024, 1), 319 | conv_bn(1024, 1024, 2), 320 | conv_1x1_bn(1024, 1280), 321 | ) 322 | self.num_features = 1280 323 | else: 324 | raise ValueError('Invalid net configuration '+config+' !!!') 325 | 326 | #Build SelecSLS Core 327 | for inp, skip, k, oup, isFirst, stride in self.selecSLS_config: 328 | self.features.append(SelecSLSBlock(inp, skip, k, oup, isFirst, stride)) 329 | self.features = nn.Sequential(*self.features) 330 | 331 | #Classifier To Produce Inputs to Softmax 332 | self.classifier = nn.Sequential( 333 | nn.Linear(self.num_features, nClasses), 334 | ) 335 | 336 | 337 | def forward(self, x): 338 | x = self.stem(x) 339 | x = self.features([x]) 340 | x = self.head(x[0]) 341 | x = x.mean(3).mean(2) 342 | x = 
self.classifier(x)
        #x = F.log_softmax(x)
        return x



    def prune_and_fuse(self, gamma_thresh, verbose=False):
        ''' Function that iterates over the modules in the model and prunes different parts by name. Sparsity emerges implicitly due to the use of
        adaptive gradient descent approaches such as Adam, in conjunction with L2 or WD regularization on the parameters. The filters
        that are implicitly zeroed out can be explicitly pruned without any impact on the model accuracy (accuracy might even improve in some cases).
        '''
        #This function assumes a specific structure. If the structure of stem or head is changed, this code would need to be changed too
        #Also, this be ugly. Needs to be written better, but is at least functional
        #Perhaps one need not worry about the layers made redundant, they can be removed from storage by tracing with the JIT module??

        #We bring everything to the CPU, then later restore the device
        device = next(self.parameters()).device
        self.to("cpu")
        with torch.no_grad():
            #Assumes that stem is flat and has conv,bn,relu in order. Can handle one or more of these if one wants to deepen the stem.
            new_stem = []
            input_validity = torch.ones(3)
            for i in range(0,len(self.stem),3):
                input_size = sum(input_validity.int()).item()
                #Calculate the extent of sparsity
                out_validity = abs(self.stem[i+1].weight) > gamma_thresh
                out_size = sum(out_validity.int()).item()
                W, b = bn_fuse(self.stem[i],self.stem[i+1])
                new_stem.append(nn.Conv2d(input_size, out_size, kernel_size = self.stem[i].kernel_size, stride=self.stem[i].stride, padding = self.stem[i].padding))
                new_stem.append(nn.ReLU(inplace=True))
                new_stem[-2].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(input_validity).squeeze()), 0, torch.nonzero(out_validity).squeeze()))
                new_stem[-2].bias.copy_(b[out_validity])
                input_validity = out_validity.clone().detach()
                if verbose:
                    #Integer division keeps the printed layer index an int (0, 1, ...) rather than a float
                    print('Stem '+str(len(new_stem)//2 -1)+': Pruned '+str(len(out_validity) - out_size) + ' from '+str(len(out_validity)))
            self.stem = nn.Sequential(*new_stem)

            new_features = []
            skip_validity = 0
            for i in range(len(self.features)):
                inp = int(sum(input_validity.int()).item())
                if self.features[i].isFirst:
                    skip = 0
                a_validity = abs(self.features[i].conv1[1].weight) > gamma_thresh
                b_validity = abs(self.features[i].conv2[1].weight) > gamma_thresh
                c_validity = abs(self.features[i].conv3[1].weight) > gamma_thresh
                d_validity = abs(self.features[i].conv4[1].weight) > gamma_thresh
                e_validity = abs(self.features[i].conv5[1].weight) > gamma_thresh
                out_validity = abs(self.features[i].conv6[1].weight) > gamma_thresh

                new_features.append(SelecSLSBlockFused(
                    inp, skip,
                    int(sum(a_validity.int()).item()),
                    int(sum(b_validity.int()).item()),
                    int(sum(c_validity.int()).item()),
                    int(sum(d_validity.int()).item()),
                    int(sum(e_validity.int()).item()),
                    int(sum(out_validity.int()).item()),
                    self.features[i].isFirst, self.features[i].stride))

                #Conv1
                i_validity =
input_validity.clone().detach()
                o_validity = a_validity.clone().detach()
                W, bias = bn_fuse(self.features[i].conv1[0], self.features[i].conv1[1])
                new_features[i].conv1[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
                new_features[i].conv1[0].bias.copy_(bias[o_validity])
                if verbose:
                    print('features.'+str(i)+'.conv1: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
                #Conv2
                i_validity = o_validity.clone().detach()
                o_validity = b_validity.clone().detach()
                W, bias = bn_fuse(self.features[i].conv2[0], self.features[i].conv2[1])
                new_features[i].conv2[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
                new_features[i].conv2[0].bias.copy_(bias[o_validity])
                if verbose:
                    print('features.'+str(i)+'.conv2: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
                #Conv3
                i_validity = o_validity.clone().detach()
                o_validity = c_validity.clone().detach()
                W, bias = bn_fuse(self.features[i].conv3[0], self.features[i].conv3[1])
                new_features[i].conv3[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
                new_features[i].conv3[0].bias.copy_(bias[o_validity])
                if verbose:
                    print('features.'+str(i)+'.conv3: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
                #Conv4
                i_validity = o_validity.clone().detach()
                o_validity = d_validity.clone().detach()
                W, bias = bn_fuse(self.features[i].conv4[0], self.features[i].conv4[1])
                new_features[i].conv4[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0,
torch.nonzero(o_validity).squeeze()))
                new_features[i].conv4[0].bias.copy_(bias[o_validity])
                if verbose:
                    print('features.'+str(i)+'.conv4: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
                #Conv5
                i_validity = o_validity.clone().detach()
                o_validity = e_validity.clone().detach()
                W, bias = bn_fuse(self.features[i].conv5[0], self.features[i].conv5[1])
                new_features[i].conv5[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
                new_features[i].conv5[0].bias.copy_(bias[o_validity])
                if verbose:
                    print('features.'+str(i)+'.conv5: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))
                #Conv6
                i_validity = torch.cat([a_validity.clone().detach(), c_validity.clone().detach(), e_validity.clone().detach()], 0)
                if self.features[i].isFirst:
                    skip = int(sum(out_validity.int()).item())
                    skip_validity = out_validity.clone().detach()
                else:
                    i_validity = torch.cat([i_validity, skip_validity], 0)
                o_validity = out_validity.clone().detach()
                W, bias = bn_fuse(self.features[i].conv6[0], self.features[i].conv6[1])
                new_features[i].conv6[0].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(i_validity).squeeze()), 0, torch.nonzero(o_validity).squeeze()))
                new_features[i].conv6[0].bias.copy_(bias[o_validity])
                if verbose:
                    print('features.'+str(i)+'.conv6: Pruned '+str(len(o_validity) - sum(o_validity.int()).item()) + ' from '+str(len(o_validity)))

                input_validity = out_validity.clone().detach()
            self.features = nn.Sequential(*new_features)

            new_head = []
            for i in range(len(self.head)):
                input_size = int(sum(input_validity.int()).item())
                #Calculate the extent of sparsity
                out_validity = abs(self.head[i][1].weight) > gamma_thresh
                out_size = int(sum(out_validity.int()).item())
                W, b = bn_fuse(self.head[i][0],self.head[i][1])
                new_head.append(nn.Conv2d(input_size, out_size, kernel_size = self.head[i][0].kernel_size, stride=self.head[i][0].stride, padding = self.head[i][0].padding))
                new_head.append(nn.ReLU(inplace=True))
                new_head[-2].weight.copy_( torch.index_select(torch.index_select(W, 1, torch.nonzero(input_validity).squeeze()), 0, torch.nonzero(out_validity).squeeze()))
                new_head[-2].bias.copy_(b[out_validity])
                input_validity = out_validity.clone().detach()
                if verbose:
                    #Integer division keeps the printed layer index an int (0, 1, ...) rather than a float
                    print('Head '+str(len(new_head)//2 -1)+': Pruned '+str(len(out_validity) - out_size) + ' from '+str(len(out_validity)))
            self.head = nn.Sequential(*new_head)

            new_classifier = []
            new_classifier.append(nn.Linear(int(sum(input_validity.int()).item()), self.classifier[0].weight.shape[0]))
            new_classifier[0].weight.copy_(torch.index_select(self.classifier[0].weight, 1, torch.nonzero(input_validity).squeeze()))
            new_classifier[0].bias.copy_(self.classifier[0].bias)
            self.classifier = nn.Sequential(*new_classifier)

        self.to(device)


--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mehtadushy/SelecSLS-Pytorch/3852734af392b8fa69834984c76856dbd39179a3/util/__init__.py


--------------------------------------------------------------------------------
/util/imagenet_data_loader.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import sys
import time
import path
import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
from
torch.utils.data import DataLoader

def get_data_loader(augment=False, batch_size=50, base_path="path_to_ImageNet"):

    print('Loading ImageNet in all its glory...')
    dataset = dset.ImageFolder

    # Prepare transforms and data augmentation
    norm_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm_transform
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        norm_transform
    ])
    kwargs = {'num_workers': 8, 'pin_memory': True}

    train_set = dataset(
        root=base_path+'/train/',
        transform=train_transform if augment else test_transform)
    test_set = dataset(base_path+'/val/',
        transform=test_transform)

    # Prepare data loaders
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_set, batch_size=batch_size,
                             shuffle=False, **kwargs)

    return train_loader, test_loader

--------------------------------------------------------------------------------
/weights/readme.txt:
--------------------------------------------------------------------------------
Get pretrained imagenet models from:

Old Training Results
SelecSLS60: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict.pth
SelecSLS84: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS84_statedict.pth
New Fangled Training Results
SelecSLS60: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_statedict_better.pth
SelecSLS42_B: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS42_B_statedict.pth
SelecSLS60_B: http://gvv.mpi-inf.mpg.de/projects/XNect/assets/models/SelecSLS60_B_statedict.pth

--------------------------------------------------------------------------------
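
Note on `prune_and_fuse` in models/selecsls.py: each layer is handled in two steps. The BatchNorm statistics are first folded into the preceding convolution, and output channels whose BatchNorm scale (gamma) has magnitude at or below `gamma_thresh` are then dropped, together with the matching input columns of the next layer. The sketch below illustrates that idea in plain Python for a 1x1 convolution written as a matrix. It is not the repository's `bn_fuse`; the names `fuse_and_prune`, `conv_bn_reference` and `apply_linear` are hypothetical, and it assumes that the BatchNorm bias (beta) of a pruned channel is negligible as well.

```python
import math


def fuse_and_prune(W, gamma, beta, mean, var, in_keep, gamma_thresh, eps=1e-5):
    """Fold BatchNorm into a 1x1 conv and drop near-zero-gamma channels.

    W is an (out x in) weight matrix given as nested lists; gamma, beta, mean
    and var are per-output-channel BatchNorm parameters. in_keep marks which
    input channels survived pruning of the previous layer.
    Hypothetical helper, not the repository's bn_fuse.
    """
    out_keep = [abs(g) > gamma_thresh for g in gamma]
    W_fused, b_fused = [], []
    for o, keep in enumerate(out_keep):
        if not keep:
            continue  # assumption: the pruned channel contributes nothing downstream
        scale = gamma[o] / math.sqrt(var[o] + eps)
        # Keep only the columns of input channels that still exist.
        W_fused.append([W[o][i] * scale for i, k in enumerate(in_keep) if k])
        b_fused.append(beta[o] - mean[o] * scale)
    return W_fused, b_fused, out_keep


def conv_bn_reference(W, gamma, beta, mean, var, x, eps=1e-5):
    """Unfused reference: 1x1 conv (no bias) followed by inference-mode BatchNorm."""
    out = []
    for o, row in enumerate(W):
        y = sum(w * xi for w, xi in zip(row, x))
        out.append(gamma[o] * (y - mean[o]) / math.sqrt(var[o] + eps) + beta[o])
    return out


def apply_linear(W, b, x):
    """Apply the fused layer: y = W x + b."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


# Toy layer: 2 input channels, 3 output channels, middle channel has gamma ~ 0.
W = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
gamma, beta = [1.0, 1e-6, 2.0], [0.1, 0.0, -0.2]
mean, var = [0.5, 0.0, 1.0], [4.0, 1.0, 0.25]

W_f, b_f, out_keep = fuse_and_prune(W, gamma, beta, mean, var, [True, True], 1e-3)
print(out_keep)  # [True, False, True]
print(apply_linear(W_f, b_f, [1.0, -1.0]))
```

On the surviving channels the fused layer reproduces the conv plus BatchNorm output up to floating-point error, which is why pruning with a small threshold is expected to leave accuracy unchanged. The returned `out_keep` mask plays the role of `input_validity` for the next layer.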