├── .gitignore ├── LICENSE ├── README.md ├── classify.py ├── data └── pascal_seg_colormap.mat ├── docker_classify.py ├── figures └── figure.png ├── models ├── __init__.py ├── densenet.py ├── dla.py ├── dla_up.py ├── resnext.py └── utils.py ├── multilabel_classify.py ├── segment.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ELASTIC 2 | This repo contains the original PyTorch implementation of Elastic introduced in the following paper 3 | 4 | [ELASTIC: Improving CNNs with Dynamic Scaling Policies](https://arxiv.org/abs/1812.05262) (CVPR 2019, Oral) 5 | 6 | [Huiyu Wang](https://csrhddlam.github.io/), [Aniruddha Kembhavi](https://anikem.github.io/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Alan Yuille](http://www.cs.jhu.edu/~ayuille/), and [Mohammad Rastegari](https://allenai.org/team/mohammadr/) 7 | 8 | It is compatible with PyTorch 1.0-stable, PyTorch 1.0-preview and PyTorch 0.4.1. All released models are exactly the models evaluated in the paper. 9 | 10 | ## Contents 11 | * [ImageNet Classification](#imagenet-classification) 12 | * [MSCOCO Multi-label Classification](#mscoco-multi-label-classification) 13 | * [PASCAL VOC Semantic Segmentation](#pascal-voc-semantic-segmentation) 14 | 15 | ## ImageNet Classification 16 | We prepare our data following https://github.com/pytorch/examples/tree/master/imagenet 17 | 18 | Pretrained models available at 19 | ``` 20 | for a in resnext50 resnext50_elastic resnext101 resnext101_elastic dla60x dla60x_elastic dla102x se_resnext50_elastic densenet201 densenet201_elastic; do 21 | wget http://ai2-vision.s3.amazonaws.com/elastic/imagenet_models/"$a".pth.tar 22 | done 23 | ``` 24 | ### Testing 25 | ``` 26 | python classify.py /path/to/imagenet/ --evaluate --resume /path/to/model.pth.tar 27 | ``` 28 | ### Training 29 | ``` 30 | python classify.py /path/to/imagenet/ 31 | ``` 32 | ### Multi-processing distributed training in Docker (recommended): 33 | We train all the models in docker containers: https://docs.nvidia.com/deeplearning/dgx/pytorch-release-notes/rel_18.07.html 34 | 35 | You may need to follow instructions in the link above to install [docker](https://www.docker.com/) and [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) if you haven't done so. 36 | 37 | After pulling the docker image, we run a docker container: 38 | ``` 39 | nvidia-docker run -it -e NVIDIA_VISIBLE_DEVICES=0,1 --ipc=host --rm -v /path/to/code:/path/to/code -v /path/to/imagenet:/path/to/imagenet nvcr.io/nvidia/pytorch:18.07-py3 40 | ``` 41 | Then run this training script inside the docker container. 42 | ``` 43 | python -m apex.parallel.multiproc docker_classify.py /path/to/imagenet 44 | ``` 45 | ## MSCOCO Multi-label Classification 46 | We extract data into this structure and use python cocoapi to load data: https://github.com/cocodataset/cocoapi 47 | ``` 48 | /path/to/mscoco/annotations/instances_train2014.json 49 | /path/to/mscoco/annotations/instances_val2014.json 50 | /path/to/mscoco/train2014 51 | /path/to/mscoco/val2014 52 | ``` 53 | Pretrained models available at 54 | ``` 55 | for a in resnext50 resnext50_elastic resnext101 resnext101_elastic dla60x dla60x_elastic densenet201 densenet201_elastic; do 56 | wget http://ai2-vision.s3.amazonaws.com/elastic/coco_models/coco_"$a".pth.tar 57 | done 58 | ``` 59 | ### Testing 60 | ``` 61 | python multilabel_classify.py /path/to/mscoco --resume /path/to/model.pth.tar --evaluate 62 | ``` 63 | ### Finetuning or resume training 64 | ``` 65 | python multilabel_classify.py /path/to/mscoco --resume /path/to/model.pth.tar 66 | ``` 67 | ## PASCAL VOC Semantic Segmentation 68 | We prepare PASCAL VOC data following https://github.com/chenxi116/DeepLabv3.pytorch 69 | 70 | Pretrained models available at 71 | ``` 72 | for a in resnext50 resnext50_elastic resnext101 resnext101_elastic dla60x dla60x_elastic; do 73 | wget http://ai2-vision.s3.amazonaws.com/elastic/pascal_models/deeplab_"$a"_pascal_v3_original_epoch50.pth 74 | done 75 | ``` 76 | ### Testing 77 | Models should be put at data/deeplab_*.pth 78 | ``` 79 | CUDA_VISIBLE_DEVICES=0 python segment.py --exp original 80 | ``` 81 | ### Finetuning or resume training 82 | All PASCAL VOC semantic segmentation models are trained on one GPU. 83 | ``` 84 | CUDA_VISIBLE_DEVICES=0 python segment.py --exp my_exp --train --resume /path/to/model.pth.tar 85 | ``` 86 | ## Note 87 | Distributed training maintains batchnorm statistics on each GPU/worker/process without synchronization, which leads to different performances on different GPUs. At the end of each epoch, our distributed script reports averaged performance (top-1, top-5) by evaluating the whole validation set on all GPUs, and saves the model on the first GPU (throws away models on other GPUs). As a result, evaluating the saved model after training leads to slightly (<0.1%) different (could be either better or worse) numbers. In the paper, we reported the average performances for all models. Averaging batchnorm statistics before evaluation may lead to marginally better numbers. 88 | 89 | ## Citation 90 | Please consider citing this paper if you find this project useful in your research. 91 | ``` 92 | @article{wang2019elastic, 93 | title={ELASTIC: Improving CNNs with Dynamic Scaling Policies}, 94 | author={Huiyu Wang, Aniruddha Kembhavi, Ali Farhadi, Alan Yuille, Mohammad Rastegari}, 95 | journal={The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 96 | year={2019} 97 | } 98 | ``` 99 | ## Credits 100 | * ImageNet training script is modified from https://github.com/pytorch/pytorch 101 | * ImageNet distributed training script is modified from https://github.com/NVIDIA/apex 102 | * Pascal segmentation code is modified from https://github.com/chenxi116/DeepLabv3.pytorch 103 | * ResNext model is modified form https://github.com/last-one/tools 104 | * DLA models are modified from https://github.com/ucbdrive/dla 105 | * DenseNet model is modified from https://github.com/csrhddlam/pytorch-checkpoint 106 | 107 | -------------------------------------------------------------------------------- /classify.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.optim 11 | import torch.utils.data as data 12 | import torch.utils.data.distributed 13 | import torchvision.transforms as transforms 14 | import torchvision.datasets as datasets 15 | import models 16 | import os 17 | import datetime 18 | from utils import add_flops_counting_methods, accuracy, save_checkpoint, AverageMeter 19 | 20 | 21 | model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic', 22 | 'dla60x', 'dla60x_elastic', 'dla102x', 'dla102x_elastic', 23 | 'se_resnext50', 'se_resnext50_elastic', 'densenet201', 'densenet201_elastic'] 24 | 25 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 26 | parser.add_argument('data', metavar='DIR', help='path to dataset') 27 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names, 28 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)') 29 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 30 | help='number of data loading workers (default: 16)') 31 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 32 | help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch-size', default=256, type=int, 36 | metavar='N', help='mini-batch size (default: 256)') 37 | parser.add_argument('-g', '--num-gpus', default=8, type=int, 38 | metavar='N', help='number of GPUs to match (default: 8)') 39 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 40 | metavar='LR', help='initial learning rate') 41 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 42 | help='momentum') 43 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 44 | metavar='W', help='weight decay (default: 1e-4)') 45 | parser.add_argument('--print-freq', '-p', default=117, type=int, 46 | metavar='N', help='print frequency (default: 117)') 47 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 48 | help='path to latest checkpoint (default: none)') 49 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 50 | help='evaluate model on validation set') 51 | parser.add_argument('--world-size', default=1, type=int, 52 | help='number of distributed processes') 53 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 54 | help='url used to set up distributed training') 55 | parser.add_argument('--dist-backend', default='gloo', type=str, 56 | help='distributed backend') 57 | 58 | best_err1 = 100 59 | 60 | 61 | def main(): 62 | global args, best_err1 63 | args = parser.parse_args() 64 | print('config: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'num_gpus', args.num_gpus) 65 | iteration_size = args.num_gpus // torch.cuda.device_count() # do multiple iterations 66 | assert iteration_size >= 1 67 | args.weight_decay = args.weight_decay * iteration_size # will cancel out with lr 68 | args.lr = args.lr / iteration_size 69 | args.batch_size = args.batch_size // iteration_size 70 | print('real: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'iteration_size', iteration_size) 71 | 72 | args.distributed = args.world_size > 1 73 | 74 | if args.distributed: 75 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 76 | world_size=args.world_size) 77 | 78 | # create model 79 | print("=> creating model '{}'".format(args.arch)) 80 | model = models.__dict__[args.arch]() 81 | 82 | # count number of parameters 83 | count = 0 84 | params = list() 85 | for n, p in model.named_parameters(): 86 | if '.ups.' not in n: 87 | params.append(p) 88 | count += np.prod(p.size()) 89 | print('Parameters:', count) 90 | 91 | # count flops 92 | model = add_flops_counting_methods(model) 93 | model.eval() 94 | image = torch.randn(1, 3, 224, 224) 95 | 96 | model.start_flops_count() 97 | model(image).sum() 98 | model.stop_flops_count() 99 | print("GFLOPs", model.compute_average_flops_cost() / 1000000000.0) 100 | 101 | # normal code 102 | if not args.distributed: 103 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 104 | model.features = torch.nn.DataParallel(model.features) 105 | model.cuda() 106 | else: 107 | model = torch.nn.DataParallel(model).cuda() 108 | else: 109 | model.cuda() 110 | model = torch.nn.parallel.DistributedDataParallel(model) 111 | 112 | # cuda warm up 113 | model = model.cuda() 114 | image = torch.randn(args.batch_size, 3, 224, 224) 115 | image_cuda = image.cuda() 116 | 117 | for i in range(3): 118 | start = time.time() 119 | model(image_cuda).sum().backward() # Warmup CUDA memory allocator 120 | print(time.time() - start) 121 | 122 | # with torch.autograd.profiler.profile(use_cuda=True) as prof: 123 | # start = time.time() 124 | # model(image_cuda).sum().backward() 125 | # print(time.time() - start) 126 | # prof.export_chrome_trace('trace_gpu') 127 | 128 | # import cProfile, pstats, io 129 | # pr = cProfile.Profile(time.perf_counter) 130 | # pr.enable() 131 | # model(image_cuda).sum().backward() 132 | # pr.disable() 133 | # s = io.StringIO() 134 | # sortby = 'cumulative' 135 | # ps = pstats.Stats(pr, stream=s).sort_stats(sortby) 136 | # ps.print_stats() 137 | # print(s.getvalue()) 138 | 139 | # define loss function (criterion) and optimizer 140 | criterion = nn.CrossEntropyLoss().cuda() 141 | optimizer = torch.optim.SGD([{'params': iter(params), 'lr': args.lr}, 142 | ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 143 | 144 | # optionally resume from a checkpoint 145 | if args.resume: 146 | if os.path.isfile(args.resume): 147 | print("=> loading checkpoint '{}'".format(args.resume)) 148 | checkpoint = torch.load(args.resume) 149 | 150 | model.load_state_dict(checkpoint['state_dict'], strict=False) if 'state_dict' in checkpoint else print('no state_dict found') 151 | optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else print('no optimizer found') 152 | args.start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else args.start_epoch 153 | best_err1 = checkpoint['best_err1'] if 'best_err' in checkpoint else best_err1 154 | 155 | print("=> loaded checkpoint '{}' (epoch {})" 156 | .format(args.resume, checkpoint['epoch'] if 'epoch' in checkpoint else 'unknown')) 157 | else: 158 | print("=> no checkpoint found at '{}'".format(args.resume)) 159 | 160 | cudnn.benchmark = True 161 | 162 | # Data loading code 163 | traindir = os.path.join(args.data, 'train') 164 | valdir = os.path.join(args.data, 'val') 165 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 166 | std=[0.229, 0.224, 0.225]) 167 | 168 | train_dataset = datasets.ImageFolder( 169 | traindir, 170 | transforms.Compose([ 171 | transforms.RandomResizedCrop(224), 172 | transforms.RandomHorizontalFlip(), 173 | transforms.ToTensor(), 174 | normalize, 175 | ])) 176 | 177 | if args.distributed: 178 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 179 | else: 180 | train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset) 181 | 182 | train_loader = torch.utils.data.DataLoader( 183 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 184 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 185 | 186 | val_loader = torch.utils.data.DataLoader( 187 | datasets.ImageFolder(valdir, transforms.Compose([ 188 | transforms.Resize(256), 189 | transforms.CenterCrop(224), 190 | transforms.ToTensor(), 191 | normalize, 192 | ])), 193 | batch_size=args.batch_size, shuffle=False, 194 | num_workers=args.workers, pin_memory=True) 195 | 196 | if args.evaluate: 197 | validate(val_loader, model, criterion) 198 | return 199 | 200 | for epoch in range(args.start_epoch, args.epochs): 201 | if args.distributed: 202 | train_sampler.set_epoch(epoch) 203 | adjust_learning_rate(optimizer, epoch) 204 | 205 | # train for one epoch 206 | train(train_loader, model, criterion, optimizer, epoch, iteration_size) 207 | 208 | # evaluate on validation set 209 | err1 = validate(val_loader, model, criterion) 210 | 211 | # remember best err@1 and save checkpoint 212 | is_best = err1 < best_err1 213 | best_err1 = min(err1, best_err1) 214 | save_checkpoint({ 215 | 'epoch': epoch + 1, 216 | 'arch': args.arch, 217 | 'state_dict': model.state_dict(), 218 | 'best_err1': best_err1, 219 | 'optimizer': optimizer.state_dict(), 220 | }, is_best, filename=args.arch + '_checkpoint.pth.tar') 221 | print(str(float(best_err1))) 222 | 223 | 224 | def train(train_loader, model, criterion, optimizer, epoch, iteration_size): 225 | batch_time = AverageMeter() 226 | data_time = AverageMeter() 227 | losses = AverageMeter() 228 | top1 = AverageMeter() 229 | top5 = AverageMeter() 230 | 231 | # switch to train mode 232 | model.train() 233 | optimizer.zero_grad() 234 | 235 | end = time.time() 236 | for i, (input, target) in enumerate(train_loader): 237 | # measure data loading time 238 | data_time.update(time.time() - end) 239 | target = target.cuda(non_blocking=True) 240 | 241 | # compute output 242 | output = model(input) 243 | loss = criterion(output, target) 244 | 245 | # measure accuracy and record loss 246 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 247 | losses.update(float(loss), input.size(0)) 248 | top1.update(100 - float(prec1), input.size(0)) 249 | top5.update(100 - float(prec5), input.size(0)) 250 | # compute gradient and do SGD step 251 | loss.backward() 252 | 253 | if i % iteration_size == iteration_size - 1: 254 | optimizer.step() 255 | optimizer.zero_grad() 256 | # measure elapsed time 257 | batch_time.update(time.time() - end) 258 | end = time.time() 259 | 260 | if i % args.print_freq == 0: 261 | print('Epoch: [{0}][{1}/{2}]\t' 262 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 263 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 264 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 265 | 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' 266 | 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 267 | epoch, i, len(train_loader), batch_time=batch_time, 268 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 269 | 270 | 271 | def validate(val_loader, model, criterion): 272 | batch_time = AverageMeter() 273 | losses = AverageMeter() 274 | top1 = AverageMeter() 275 | top5 = AverageMeter() 276 | 277 | # switch to evaluate mode 278 | model.eval() 279 | 280 | end = time.time() 281 | for i, (input, target) in enumerate(val_loader): 282 | target = target.cuda(non_blocking=True) 283 | 284 | # compute output 285 | with torch.no_grad(): 286 | output = model(input) 287 | loss = criterion(output, target) 288 | 289 | # measure accuracy and record loss 290 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 291 | losses.update(float(loss), input.size(0)) 292 | top1.update(100 - float(prec1), input.size(0)) 293 | top5.update(100 - float(prec5), input.size(0)) 294 | 295 | # measure elapsed time 296 | batch_time.update(time.time() - end) 297 | end = time.time() 298 | 299 | if i % args.print_freq == 0: 300 | print('Test: [{0}/{1}]\t' 301 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 302 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 303 | 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' 304 | 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 305 | i, len(val_loader), batch_time=batch_time, loss=losses, 306 | top1=top1, top5=top5)) 307 | 308 | print(str(datetime.datetime.now()) + ' * Err@1 {top1.avg:.3f} Err@5 {top5.avg:.3f}' 309 | .format(top1=top1, top5=top5)) 310 | return top1.avg 311 | 312 | 313 | def adjust_learning_rate(optimizer, epoch): 314 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 315 | lr = args.lr * (0.1 ** (epoch // 30)) 316 | for param_group in optimizer.param_groups: 317 | param_group['lr'] = lr 318 | 319 | 320 | if __name__ == '__main__': 321 | main() 322 | -------------------------------------------------------------------------------- /data/pascal_seg_colormap.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/elastic/57345c600c63fbde163c41929d6d6dd894d408ce/data/pascal_seg_colormap.mat -------------------------------------------------------------------------------- /docker_classify.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | # import torchvision.models as models 18 | import models 19 | import numpy as np 20 | import gc 21 | 22 | from utils import add_flops_counting_methods, save_checkpoint, AverageMeter, accuracy 23 | 24 | try: 25 | from apex.parallel import DistributedDataParallel as DDP 26 | from apex.fp16_utils import * 27 | except ImportError: 28 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 29 | 30 | model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic', 31 | 'dla60x', 'dla60x_elastic', 'dla102x', 'dla102x_elastic', 32 | 'se_resnext50', 'se_resnext50_elastic', 'densenet201', 'densenet201_elastic'] 33 | 34 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 35 | parser.add_argument('data', metavar='DIR', help='path to dataset') 36 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names, 37 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)') 38 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 39 | help='number of data loading workers (default: 8)') 40 | parser.add_argument('--epochs', default=120, type=int, metavar='N', 41 | help='number of total epochs to run') 42 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 43 | help='manual epoch number (useful on restarts)') 44 | parser.add_argument('-b', '--batch-size', default=32, type=int, 45 | metavar='N', help='mini-batch size (default: 32)') 46 | parser.add_argument('-g', '--num-gpus', default=8, type=int, 47 | metavar='N', help='number of GPUs we pretend to have (default: 8)') 48 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 49 | metavar='LR', help='initial learning rate') 50 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 51 | help='momentum') 52 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 53 | metavar='W', help='weight decay (default: 1e-4)') 54 | parser.add_argument('--print-freq', '-p', default=117, type=int, 55 | metavar='N', help='print frequency (default: 117)') 56 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 57 | help='path to latest checkpoint (default: none)') 58 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 59 | help='evaluate model on validation set') 60 | parser.add_argument('--fp16', action='store_true', 61 | help='Run model fp16 mode.') 62 | parser.add_argument('--static-loss-scale', type=float, default=1, 63 | help='Static loss scale, positive power of 2 values can improve fp16 convergence.') 64 | parser.add_argument('--dist-url', default='file://sync.file', type=str, 65 | help='url used to set up distributed training') 66 | parser.add_argument('--dist-backend', default='nccl', type=str, 67 | help='distributed backend') 68 | parser.add_argument('--world-size', default=1, type=int, 69 | help='Number of GPUs to use. Can either be manually set ' + 70 | 'or automatically set by using \'python -m multiproc\'.') 71 | parser.add_argument('--rank', default=0, type=int, 72 | help='Used for multi-process training. Can either be manually set ' + 73 | 'or automatically set by using \'python -m multiproc\'.') 74 | 75 | cudnn.benchmark = True 76 | 77 | 78 | def fast_collate(batch): 79 | imgs = [img[0] for img in batch] 80 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) 81 | w = imgs[0].size[0] 82 | h = imgs[0].size[1] 83 | tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 ) 84 | for i, img in enumerate(imgs): 85 | nump_array = np.asarray(img, dtype=np.uint8) 86 | tens = torch.from_numpy(nump_array) 87 | if(nump_array.ndim < 3): 88 | nump_array = np.expand_dims(nump_array, axis=-1) 89 | nump_array = np.rollaxis(nump_array, 2) 90 | 91 | tensor[i] += torch.from_numpy(nump_array) 92 | 93 | return tensor, targets 94 | 95 | 96 | best_err1 = 100 97 | args = parser.parse_args() 98 | 99 | 100 | def main(): 101 | global best_err1, args 102 | 103 | iteration_size = args.num_gpus // args.world_size 104 | args.weight_decay = args.weight_decay * iteration_size # will cancel out with lr 105 | args.lr = args.lr / iteration_size 106 | print('real: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'iteration_size', iteration_size) 107 | args.distributed = args.world_size > 1 108 | args.gpu = 0 109 | if args.distributed: 110 | args.gpu = args.rank % torch.cuda.device_count() 111 | 112 | if args.distributed: 113 | torch.cuda.set_device(args.gpu) 114 | dist.init_process_group(backend=args.dist_backend, 115 | init_method=args.dist_url, 116 | world_size=args.world_size, 117 | rank=args.rank) 118 | 119 | if args.fp16: 120 | assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 121 | 122 | # create model 123 | print("=> creating model '{}'".format(args.arch)) 124 | model = models.__dict__[args.arch]() 125 | 126 | # count number of parameters 127 | count = 0 128 | params = list() 129 | for n, p in model.named_parameters(): 130 | if '.ups.' not in n: 131 | params.append(p) 132 | count += np.prod(p.size()) 133 | print('Parameters:', count) 134 | 135 | # count flops 136 | model = add_flops_counting_methods(model) 137 | model.eval() 138 | image = torch.randn(1, 3, 224, 224) 139 | 140 | model.start_flops_count() 141 | model(image).sum() 142 | model.stop_flops_count() 143 | print("GFLOPs", model.compute_average_flops_cost() / 1000000000.0) 144 | 145 | model = model.cuda() 146 | if args.fp16: 147 | model = network_to_half(model) 148 | if args.distributed: 149 | #shared param turns off bucketing in DDP, for lower latency runs this can improve perf 150 | model = DDP(model, shared_param=True) 151 | 152 | global model_params, master_params 153 | if args.fp16: 154 | model_params, master_params = prep_param_lists(model) 155 | else: 156 | master_params = list(model.parameters()) 157 | 158 | # define loss function (criterion) and optimizer 159 | criterion = nn.CrossEntropyLoss().cuda() 160 | optimizer = torch.optim.SGD([{'params': iter(params), 'lr': args.lr}, 161 | ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 162 | 163 | # optionally resume from a checkpoint 164 | if args.resume: 165 | if os.path.isfile(args.resume): 166 | print("=> loading checkpoint '{}'".format(args.resume)) 167 | checkpoint = torch.load(args.resume) 168 | 169 | model.load_state_dict(checkpoint['state_dict'], strict=False) if 'state_dict' in checkpoint else print('no state_dict found') 170 | optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else print('no optimizer found') 171 | args.start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else args.start_epoch 172 | best_err1 = checkpoint['best_err1'] if 'best_err' in checkpoint else best_err1 173 | 174 | print("=> loaded checkpoint '{}' (epoch {})" 175 | .format(args.resume, checkpoint['epoch'] if 'epoch' in checkpoint else 'unknown')) 176 | else: 177 | print("=> no checkpoint found at '{}'".format(args.resume)) 178 | 179 | # Data loading code 180 | traindir = os.path.join(args.data, 'train') 181 | valdir = os.path.join(args.data, 'val') 182 | 183 | crop_size = 224 184 | val_size = 256 185 | 186 | train_dataset = datasets.ImageFolder( 187 | traindir, 188 | transforms.Compose([ 189 | transforms.RandomResizedCrop(crop_size), 190 | transforms.RandomHorizontalFlip(), 191 | # transforms.ToTensor(), Too slow 192 | # normalize, 193 | ])) 194 | val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ 195 | transforms.Resize(val_size), 196 | transforms.CenterCrop(crop_size), 197 | ])) 198 | if args.distributed: 199 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 200 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 201 | else: 202 | train_sampler = None 203 | 204 | train_loader = torch.utils.data.DataLoader( 205 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 206 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate, drop_last=True) 207 | 208 | val_loader = torch.utils.data.DataLoader( 209 | datasets.ImageFolder(valdir, transforms.Compose([ 210 | transforms.Resize(val_size), 211 | transforms.CenterCrop(crop_size), 212 | ])), 213 | batch_size=args.batch_size, shuffle=False, 214 | num_workers=args.workers, pin_memory=True, 215 | collate_fn=fast_collate) 216 | # print(len(train_loader), len(val_loader)) 217 | if args.evaluate: 218 | validate(val_loader, model, criterion) 219 | return 220 | 221 | for epoch in range(args.start_epoch, args.epochs): 222 | if args.distributed: 223 | train_sampler.set_epoch(epoch) 224 | adjust_learning_rate(optimizer, epoch) 225 | print('allocated before', torch.cuda.memory_allocated()) 226 | print('cached before', torch.cuda.memory_cached()) 227 | gc.collect() 228 | torch.cuda.empty_cache() 229 | print('allocated after', torch.cuda.memory_allocated()) 230 | print('cached after', torch.cuda.memory_cached()) 231 | # train for one epoch 232 | train(train_loader, model, criterion, optimizer, epoch, iteration_size) 233 | 234 | # # sync models on multiple GPUs 235 | # if args.rank == 0: 236 | # save_checkpoint({ 237 | # 'epoch': epoch + 1, 238 | # 'arch': args.arch, 239 | # 'state_dict': model.state_dict(), 240 | # 'optimizer' : optimizer.state_dict(), 241 | # }, False, 'temp.pth.tar') 242 | # # barrier 243 | # loss = torch.FloatTensor([args.rank]).cuda() 244 | # reduced_loss = reduce_tensor(loss.data) 245 | # print(loss.data, reduced_loss) 246 | # if os.path.isfile('temp.pth.tar'): 247 | # print("=> loading checkpoint '{}'".format('temp.pth.tar')) 248 | # checkpoint = torch.load('temp.pth.tar', map_location = lambda storage, loc: storage.cuda(args.gpu)) 249 | # model.load_state_dict(checkpoint['state_dict'], strict=False) 250 | # optimizer.load_state_dict(checkpoint['optimizer']) 251 | # print("=> loaded checkpoint '{}' (epoch {})" 252 | # .format('temp.pth.tar', checkpoint['epoch'])) 253 | # assert checkpoint['epoch'] == epoch + 1 254 | 255 | # evaluate on validation set 256 | err1 = validate(val_loader, model, criterion) 257 | # remember best err@1 and save checkpoint 258 | if args.rank == 0: 259 | is_best = err1 < best_err1 260 | best_err1 = min(err1, best_err1) 261 | save_checkpoint({ 262 | 'epoch': epoch + 1, 263 | 'arch': args.arch, 264 | 'state_dict': model.state_dict(), 265 | 'best_err1': best_err1, 266 | 'optimizer': optimizer.state_dict(), 267 | }, is_best) 268 | print(str(float(best_err1))) 269 | 270 | 271 | class data_prefetcher(): 272 | def __init__(self, loader): 273 | self.loader = iter(loader) 274 | self.stream = torch.cuda.Stream() 275 | self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) 276 | self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) 277 | if args.fp16: 278 | self.mean = self.mean.half() 279 | self.std = self.std.half() 280 | self.preload() 281 | 282 | def preload(self): 283 | try: 284 | self.next_input, self.next_target = next(self.loader) 285 | except StopIteration: 286 | self.next_input = None 287 | self.next_target = None 288 | return 289 | with torch.cuda.stream(self.stream): 290 | self.next_input = self.next_input.cuda(async=True) 291 | self.next_target = self.next_target.cuda(async=True) 292 | if args.fp16: 293 | self.next_input = self.next_input.half() 294 | else: 295 | self.next_input = self.next_input.float() 296 | self.next_input = self.next_input.sub_(self.mean).div_(self.std) 297 | 298 | def next(self): 299 | torch.cuda.current_stream().wait_stream(self.stream) 300 | input = self.next_input 301 | target = self.next_target 302 | self.preload() 303 | return input, target 304 | 305 | 306 | def train(train_loader, model, criterion, optimizer, epoch, iteration_size): 307 | batch_time = AverageMeter() 308 | data_time = AverageMeter() 309 | losses = AverageMeter() 310 | top1 = AverageMeter() 311 | top5 = AverageMeter() 312 | 313 | # switch to train mode 314 | model.train() 315 | optimizer.zero_grad() 316 | 317 | end = time.time() 318 | 319 | prefetcher = data_prefetcher(train_loader) 320 | input, target = prefetcher.next() 321 | i = -1 322 | while input is not None: 323 | i += 1 324 | 325 | # measure data loading time 326 | data_time.update(time.time() - end) 327 | input_var = Variable(input) 328 | target_var = Variable(target) 329 | 330 | # compute output 331 | output = model(input_var) 332 | loss = criterion(output, target_var) 333 | 334 | # measure accuracy and record loss 335 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 336 | 337 | if args.distributed: 338 | reduced_loss = reduce_tensor(loss.data) 339 | prec1 = reduce_tensor(prec1) 340 | prec5 = reduce_tensor(prec5) 341 | else: 342 | reduced_loss = loss.data 343 | 344 | losses.update(to_python_float(reduced_loss), input.size(0)) 345 | top1.update(100 - to_python_float(prec1), input.size(0)) 346 | top5.update(100 - to_python_float(prec5), input.size(0)) 347 | 348 | loss = loss*args.static_loss_scale 349 | # compute gradient and do SGD step 350 | loss.backward() 351 | if i % iteration_size == iteration_size - 1: 352 | optimizer.step() 353 | optimizer.zero_grad() 354 | 355 | torch.cuda.synchronize() 356 | # measure elapsed time 357 | batch_time.update(time.time() - end) 358 | 359 | end = time.time() 360 | input, target = prefetcher.next() 361 | if args.rank == 0 and i % args.print_freq == 0: 362 | print('Epoch: [{0}][{1}/{2}]\t' 363 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 364 | 'Speed {3:.3f} ({4:.3f})\t' 365 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 366 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 367 | 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' 368 | 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 369 | epoch, i, len(train_loader), 370 | args.world_size * args.batch_size / batch_time.val, 371 | args.world_size * args.batch_size / batch_time.avg, 372 | batch_time=batch_time, 373 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 374 | 375 | 376 | def validate(val_loader, model, criterion): 377 | batch_time = AverageMeter() 378 | losses = AverageMeter() 379 | top1 = AverageMeter() 380 | top5 = AverageMeter() 381 | 382 | # switch to evaluate mode 383 | model.eval() 384 | 385 | end = time.time() 386 | 387 | prefetcher = data_prefetcher(val_loader) 388 | input, target = prefetcher.next() 389 | i = -1 390 | while input is not None: 391 | i += 1 392 | 393 | target = target.cuda(async=True) 394 | input_var = Variable(input) 395 | target_var = Variable(target) 396 | 397 | # compute output 398 | with torch.no_grad(): 399 | output = model(input_var) 400 | loss = criterion(output, target_var) 401 | 402 | # measure accuracy and record loss 403 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 404 | 405 | if args.distributed: 406 | reduced_loss = reduce_tensor(loss.data) 407 | prec1 = reduce_tensor(prec1) 408 | prec5 = reduce_tensor(prec5) 409 | else: 410 | reduced_loss = loss.data 411 | 412 | losses.update(to_python_float(reduced_loss), input.size(0)) 413 | top1.update(100 - to_python_float(prec1), input.size(0)) 414 | top5.update(100 - to_python_float(prec5), input.size(0)) 415 | 416 | # measure elapsed time 417 | batch_time.update(time.time() - end) 418 | end = time.time() 419 | 420 | if args.rank == 0 and i % args.print_freq == 0: 421 | print('Test: [{0}/{1}]\t' 422 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 423 | 'Speed {2:.3f} ({3:.3f})\t' 424 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 425 | 'Err@1 {top1.val:.3f} ({top1.avg:.3f})\t' 426 | 'Err@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 427 | i, len(val_loader), 428 | args.world_size * args.batch_size / batch_time.val, 429 | args.world_size * args.batch_size / batch_time.avg, 430 | batch_time=batch_time, loss=losses, 431 | top1=top1, top5=top5)) 432 | input, target = prefetcher.next() 433 | print(' * Err@1 {top1.avg:.3f} Err@5 {top5.avg:.3f}' 434 | .format(top1=top1, top5=top5)) 435 | return top1.avg 436 | 437 | 438 | def adjust_learning_rate(optimizer, epoch): 439 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 440 | lr = args.lr * (0.1 ** (epoch // 30)) 441 | for param_group in optimizer.param_groups: 442 | param_group['lr'] = lr 443 | 444 | 445 | def reduce_tensor(tensor): 446 | rt = tensor.clone() 447 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 448 | rt /= args.world_size 449 | return rt 450 | 451 | 452 | if __name__ == '__main__': 453 | main() 454 | -------------------------------------------------------------------------------- /figures/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/elastic/57345c600c63fbde163c41929d6d6dd894d408ce/figures/figure.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densenet import * 2 | from .dla import * 3 | from .dla_up import * 4 | from .resnext import * 5 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | from .utils import CheckpointFunction, CpBatchNorm2d 7 | 8 | 9 | class _DenseLayerElastic(nn.Module): 10 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): 11 | super(_DenseLayerElastic, self).__init__() 12 | self.pool = nn.AvgPool2d(2, stride=2) 13 | self.dummy = nn.Sequential() 14 | self.add_module('conv1_d', nn.Conv2d(num_input_features, bn_size * 15 | growth_rate // 2, kernel_size=1, stride=1, bias=False)), 16 | self.add_module('norm2_d', CpBatchNorm2d(bn_size * growth_rate // 2)), 17 | self.add_module('relu2_d', nn.ReLU(inplace=True)), 18 | self.add_module('conv2_d', nn.Conv2d(bn_size * growth_rate // 2, growth_rate, 19 | kernel_size=3, stride=1, padding=1, bias=False)), 20 | self.add_module('norm1', CpBatchNorm2d(num_input_features)), 21 | self.add_module('relu1', nn.ReLU(inplace=True)), 22 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 23 | growth_rate // 2, kernel_size=1, stride=1, bias=False)), 24 | self.add_module('norm2', CpBatchNorm2d(bn_size * growth_rate // 2)), 25 | self.add_module('relu2', nn.ReLU(inplace=True)), 26 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate // 2, growth_rate, 27 | kernel_size=3, stride=1, padding=1, bias=False)), 28 | self.drop_rate = drop_rate 29 | self.efficient = efficient 30 | 31 | def forward(self, *prev_features): 32 | concated_features = torch.cat(prev_features, 1) 33 | bottleneck_output = self.relu1(self.norm1(concated_features)) 34 | bottleneck_output_d = bottleneck_output 35 | if prev_features[0].size(2) != 7: 36 | bottleneck_output_d = self.pool(bottleneck_output_d) 37 | bottleneck_output_d = self.conv1_d(bottleneck_output_d) 38 | bottleneck_output = self.conv1(bottleneck_output) 39 | new_features_d = self.conv2_d(self.relu2_d(self.norm2_d(bottleneck_output_d))) 40 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 41 | if self.drop_rate > 0: 42 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 43 | if prev_features[0].size(2) != 7: 44 | new_features_d = F.upsample(new_features_d, None, 2, 'bilinear', False) 45 | return new_features + new_features_d 46 | 47 | 48 | def _bn_function_factory(norm, relu, conv): 49 | def bn_function(*inputs): 50 | concated_features = torch.cat(inputs, 1) 51 | bottleneck_output = conv(relu(norm(concated_features))) 52 | return bottleneck_output 53 | return bn_function 54 | 55 | 56 | class _DenseLayer(nn.Module): 57 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): 58 | super(_DenseLayer, self).__init__() 59 | self.add_module('norm1', CpBatchNorm2d(num_input_features)), 60 | self.add_module('relu1', nn.ReLU(inplace=True)), 61 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 62 | growth_rate, kernel_size=1, stride=1, bias=False)), 63 | self.add_module('norm2', CpBatchNorm2d(bn_size * growth_rate)), 64 | self.add_module('relu2', nn.ReLU(inplace=True)), 65 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 66 | kernel_size=3, stride=1, padding=1, bias=False)), 67 | self.drop_rate = drop_rate 68 | self.efficient = efficient 69 | 70 | def forward(self, *prev_features): 71 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 72 | if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 73 | args = prev_features + tuple(self.norm1.parameters()) + tuple(self.conv1.parameters()) 74 | bottleneck_output = CheckpointFunction.apply(bn_function, len(prev_features), *args) 75 | else: 76 | bottleneck_output = bn_function(*prev_features) 77 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 78 | if self.drop_rate > 0: 79 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 80 | return new_features 81 | 82 | 83 | class _Transition(nn.Sequential): 84 | def __init__(self, num_input_features, num_output_features): 85 | super(_Transition, self).__init__() 86 | self.add_module('norm', CpBatchNorm2d(num_input_features)) 87 | self.add_module('relu', nn.ReLU(inplace=True)) 88 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 89 | kernel_size=1, stride=1, bias=False)) 90 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 91 | 92 | 93 | class _DenseBlock(nn.Module): 94 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, 95 | efficient=False, dense_layer=_DenseLayer): 96 | super(_DenseBlock, self).__init__() 97 | for i in range(num_layers): 98 | layer = dense_layer( 99 | num_input_features + i * growth_rate, 100 | growth_rate=growth_rate, 101 | bn_size=bn_size, 102 | drop_rate=drop_rate, 103 | efficient=efficient, 104 | ) 105 | self.add_module('denselayer%d' % (i + 1), layer) 106 | 107 | def forward(self, init_features): 108 | features = [init_features] 109 | for name, layer in self.named_children(): 110 | new_features = layer(*features) 111 | features.append(new_features) 112 | return torch.cat(features, 1) 113 | 114 | 115 | class DenseNet(nn.Module): 116 | r"""Densenet-BC model class, based on 117 | `"Densely Connected Convolutional Networks" ` 118 | Args: 119 | growth_rate (int) - how many filters to add each layer (`k` in paper) 120 | block_config (list of 3 or 4 ints) - how many layers in each pooling block 121 | num_init_features (int) - the number of filters to learn in the first convolution layer 122 | bn_size (int) - multiplicative factor for number of bottle neck layers 123 | (i.e. bn_size * k features in the bottleneck layer) 124 | drop_rate (float) - dropout rate after each dense layer 125 | num_classes (int) - number of classification classes 126 | small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. 127 | efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. 128 | """ 129 | def __init__(self, growth_rate=32, block_config=(16, 16, 16), compression=0.5, 130 | num_init_features=64, bn_size=4, drop_rate=0, 131 | num_classes=1000, small_inputs=False, efficient=True, elastic=False): 132 | 133 | super(DenseNet, self).__init__() 134 | assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' 135 | self.avgpool_size = 8 if small_inputs else 7 136 | 137 | # First convolution 138 | if small_inputs: 139 | self.features = nn.Sequential(OrderedDict([ 140 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), 141 | ])) 142 | else: 143 | self.features = nn.Sequential(OrderedDict([ 144 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 145 | ])) 146 | self.features.add_module('norm0', CpBatchNorm2d(num_init_features)) 147 | self.features.add_module('relu0', nn.ReLU(inplace=True)) 148 | self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, 149 | ceil_mode=False)) 150 | 151 | # Each denseblock 152 | num_features = num_init_features 153 | for i, num_layers in enumerate(block_config): 154 | block = _DenseBlock( 155 | num_layers=num_layers, 156 | num_input_features=num_features, 157 | bn_size=bn_size, 158 | growth_rate=growth_rate, 159 | drop_rate=drop_rate, 160 | efficient=efficient and i == 0, 161 | dense_layer=_DenseLayer if not elastic else _DenseLayerElastic 162 | ) 163 | self.features.add_module('denseblock%d' % (i + 1), block) 164 | num_features = num_features + num_layers * growth_rate 165 | if i != len(block_config) - 1: 166 | trans = _Transition(num_input_features=num_features, 167 | num_output_features=int(num_features * compression)) 168 | self.features.add_module('transition%d' % (i + 1), trans) 169 | num_features = int(num_features * compression) 170 | 171 | # Final batch norm 172 | self.features.add_module('norm_final', CpBatchNorm2d(num_features)) 173 | 174 | # Linear layer 175 | self.classifier = nn.Linear(num_features, num_classes) 176 | 177 | # Initialization 178 | for name, param in self.named_parameters(): 179 | if 'conv' in name and 'weight' in name: 180 | n = param.size(0) * param.size(2) * param.size(3) 181 | param.data.normal_().mul_(math.sqrt(2. / n)) 182 | elif 'norm' in name and 'weight' in name: 183 | param.data.fill_(1) 184 | elif 'norm' in name and 'bias' in name: 185 | param.data.fill_(0) 186 | elif 'classifier' in name and 'bias' in name: 187 | param.data.fill_(0) 188 | 189 | def forward(self, x): 190 | features = self.features(x) 191 | out = F.relu(features, inplace=True) 192 | out = F.avg_pool2d(out, kernel_size=self.avgpool_size).view(features.size(0), -1) 193 | out = self.classifier(out) 194 | return out 195 | 196 | 197 | def densenet201(**kwargs): 198 | model = DenseNet(block_config=(6, 12, 48, 32), elastic=False, **kwargs) 199 | return model 200 | 201 | 202 | def densenet201_elastic(**kwargs): 203 | model = DenseNet(block_config=(10, 20, 40, 30), elastic=True, **kwargs) 204 | return model 205 | -------------------------------------------------------------------------------- /models/dla.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import math 5 | from torch.utils.checkpoint import * 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | # from .utils import fill_up_weights, CpBatchNorm2d 11 | BatchNorm = nn.BatchNorm2d 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | "3x3 convolution with padding" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BottleneckX(nn.Module): 21 | expansion = 2 22 | cardinality = 32 23 | 24 | def __init__(self, inplanes, planes, stride=1, dilation=1): 25 | super(BottleneckX, self).__init__() 26 | cardinality = BottleneckX.cardinality 27 | bottle_planes = planes * cardinality // 32 28 | self.conv1 = nn.Conv2d(inplanes, bottle_planes, 29 | kernel_size=1, bias=False) 30 | self.bn1 = BatchNorm(bottle_planes) 31 | self.conv2 = nn.Conv2d(bottle_planes, bottle_planes, kernel_size=3, 32 | stride=stride, padding=dilation, bias=False, 33 | dilation=dilation, groups=cardinality) 34 | self.bn2 = BatchNorm(bottle_planes) 35 | self.conv3 = nn.Conv2d(bottle_planes, planes, 36 | kernel_size=1, bias=False) 37 | self.bn3 = BatchNorm(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.stride = stride 40 | 41 | def forward(self, x, residual=None): 42 | if residual is None: 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv3(out) 54 | out = self.bn3(out) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class BottleneckXElastic(nn.Module): 63 | expansion = 2 64 | cardinality = 32 65 | 66 | def __init__(self, inplanes, planes, stride=1, dilation=1): 67 | super(BottleneckXElastic, self).__init__() 68 | cardinality = BottleneckX.cardinality 69 | self.elastic = (stride == 1 and planes < 1024) 70 | if self.elastic: 71 | # self.ups = nn.ConvTranspose2d( 72 | # inplanes, inplanes, 4, stride=2, padding=1, 73 | # output_padding=0, groups=inplanes, bias=False) 74 | # fill_up_weights(self.ups) 75 | self.down = nn.AvgPool2d(2, stride=2) 76 | self.ups = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 77 | 78 | bottle_planes = planes * cardinality // 32 79 | 80 | self.conv1_d = nn.Conv2d(inplanes, bottle_planes // 2, 81 | kernel_size=1, bias=False) 82 | self.bn1_d = BatchNorm(bottle_planes // 2) 83 | self.conv2_d = nn.Conv2d(bottle_planes // 2, bottle_planes // 2, kernel_size=3, 84 | stride=stride, padding=dilation, bias=False, 85 | dilation=dilation, groups=cardinality // 2) 86 | self.bn2_d = BatchNorm(bottle_planes // 2) 87 | self.conv3_d = nn.Conv2d(bottle_planes // 2, planes, 88 | kernel_size=1, bias=False) 89 | 90 | self.conv1 = nn.Conv2d(inplanes, bottle_planes // 2, 91 | kernel_size=1, bias=False) 92 | self.bn1 = BatchNorm(bottle_planes // 2) 93 | self.conv2 = nn.Conv2d(bottle_planes // 2, bottle_planes // 2, kernel_size=3, 94 | stride=stride, padding=dilation, bias=False, 95 | dilation=dilation, groups=cardinality // 2) 96 | self.bn2 = BatchNorm(bottle_planes // 2) 97 | self.conv3 = nn.Conv2d(bottle_planes // 2, planes, 98 | kernel_size=1, bias=False) 99 | 100 | self.bn3 = BatchNorm(planes) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.stride = stride 103 | self.__flops__ = 0 104 | 105 | def forward(self, x, residual=None): 106 | if residual is None: 107 | residual = x 108 | out_d = x 109 | if self.elastic: 110 | if x.size(2) % 2 > 0 or x.size(3) % 2 > 0: 111 | out_d = F.pad(out_d, (0, x.size(3) % 2, 0, x.size(2) % 2), mode='replicate') 112 | out_d = self.down(out_d) 113 | 114 | out_d = self.conv1_d(out_d) 115 | out_d = self.bn1_d(out_d) 116 | out_d = self.relu(out_d) 117 | 118 | out_d = self.conv2_d(out_d) 119 | out_d = self.bn2_d(out_d) 120 | out_d = self.relu(out_d) 121 | 122 | out_d = self.conv3_d(out_d) 123 | if self.elastic: 124 | out_d = self.ups(out_d) 125 | self.__flops__ += np.prod(out_d[0].shape) * 8 126 | if out_d.size(2) > x.size(2) or out_d.size(3) > x.size(3): 127 | out_d = out_d[:, :, :x.size(2), :x.size(3)] 128 | 129 | out = self.conv1(x) 130 | out = self.bn1(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv2(out) 134 | out = self.bn2(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv3(out) 138 | 139 | out = out + out_d 140 | out = self.bn3(out) 141 | 142 | out += residual 143 | out = self.relu(out) 144 | 145 | return out 146 | 147 | 148 | class Root(nn.Module): 149 | def __init__(self, in_channels, out_channels, kernel_size, residual): 150 | super(Root, self).__init__() 151 | self.conv = nn.Conv2d( 152 | in_channels, out_channels, 1, 153 | stride=1, bias=False, padding=(kernel_size - 1) // 2) 154 | self.bn = BatchNorm(out_channels) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.residual = residual 157 | 158 | def forward(self, *x): 159 | children = x 160 | x = self.conv(torch.cat(x, 1)) 161 | x = self.bn(x) 162 | if self.residual: 163 | x += children[0] 164 | x = self.relu(x) 165 | 166 | return x 167 | 168 | 169 | class Tree(nn.Module): 170 | def __init__(self, levels, block, in_channels, out_channels, stride=1, 171 | level_root=False, root_dim=0, root_kernel_size=1, 172 | dilation=1, root_residual=False, seg=False): 173 | super(Tree, self).__init__() 174 | if root_dim == 0: 175 | root_dim = 2 * out_channels 176 | if level_root: 177 | root_dim += in_channels 178 | if levels == 1: 179 | self.tree1 = block(in_channels, out_channels, stride, 180 | dilation=dilation) 181 | self.tree2 = block(out_channels, out_channels, 1, 182 | dilation=dilation) 183 | else: 184 | self.tree1 = Tree(levels - 1, block, in_channels, out_channels, 185 | stride, root_dim=0, 186 | root_kernel_size=root_kernel_size, 187 | dilation=dilation, root_residual=root_residual, seg=seg) 188 | self.tree2 = Tree(levels - 1, block, out_channels, out_channels, 189 | root_dim=root_dim + out_channels, 190 | root_kernel_size=root_kernel_size, 191 | dilation=dilation, root_residual=root_residual, seg=seg) 192 | if levels == 1: 193 | self.root = Root(root_dim, out_channels, root_kernel_size, 194 | root_residual) 195 | self.level_root = level_root 196 | self.root_dim = root_dim 197 | self.downsample = None 198 | self.project = None 199 | self.levels = levels 200 | if stride > 1: 201 | self.downsample = nn.MaxPool2d(stride, stride=stride, ceil_mode=seg) 202 | if in_channels != out_channels: 203 | self.project = nn.Sequential( 204 | nn.Conv2d(in_channels, out_channels, 205 | kernel_size=1, stride=1, bias=False), 206 | BatchNorm(out_channels) 207 | ) 208 | 209 | def forward(self, x, residual=None, children=None): 210 | children = [] if children is None else children 211 | bottom = self.downsample(x) if self.downsample else x 212 | residual = self.project(bottom) if self.project else bottom 213 | if self.level_root: 214 | children.append(bottom) 215 | x1 = self.tree1(x, residual) 216 | if self.levels == 1: 217 | x2 = self.tree2(x1) 218 | x = self.root(x2, x1, *children) 219 | else: 220 | children.append(x1) 221 | x = self.tree2(x1, children=children) 222 | return x 223 | 224 | 225 | class DLA(nn.Module): 226 | def __init__(self, levels, channels, num_classes=1000, 227 | block=BottleneckX, residual_root=False, return_levels=False, 228 | pool_size=7, linear_root=False, seg=False): 229 | super(DLA, self).__init__() 230 | self.channels = channels 231 | self.seg = seg 232 | self.return_levels = return_levels 233 | self.num_classes = num_classes 234 | self.base_layer = nn.Sequential( 235 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 236 | padding=3, bias=False), 237 | BatchNorm(channels[0]), 238 | nn.ReLU(inplace=True)) 239 | self.level0 = self._make_conv_level( 240 | channels[0], channels[0], levels[0]) 241 | self.level1 = self._make_conv_level( 242 | channels[0], channels[1], levels[1], stride=2) 243 | self.level2 = Tree(levels[2], block, channels[1], channels[2], 2, 244 | level_root=False, root_residual=residual_root, seg=seg) 245 | self.level3 = Tree(levels[3], block, channels[2], channels[3], 2, 246 | level_root=True, root_residual=residual_root, seg=seg) 247 | self.level4 = Tree(levels[4], block, channels[3], channels[4], 2, 248 | level_root=True, root_residual=residual_root, seg=seg) 249 | self.level5 = Tree(levels[5], block, channels[4], channels[5], 2, 250 | level_root=True, root_residual=residual_root, seg=seg) 251 | 252 | self.avgpool = nn.AvgPool2d(pool_size) 253 | self.fc = nn.Conv2d(channels[-1], num_classes, kernel_size=1, 254 | stride=1, padding=0, bias=True) 255 | 256 | for m in self.modules(): 257 | if isinstance(m, nn.Conv2d): 258 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 259 | m.weight.data.normal_(0, math.sqrt(2. / n)) 260 | elif isinstance(m, BatchNorm): 261 | m.weight.data.fill_(1) 262 | m.bias.data.zero_() 263 | 264 | def _make_level(self, block, inplanes, planes, blocks, stride=1): 265 | downsample = None 266 | if stride != 1 or inplanes != planes: 267 | downsample = nn.Sequential( 268 | nn.MaxPool2d(stride, stride=stride, ceil_mode=self.seg), 269 | nn.Conv2d(inplanes, planes, 270 | kernel_size=1, stride=1, bias=False), 271 | BatchNorm(planes), 272 | ) 273 | 274 | layers = [] 275 | layers.append(block(inplanes, planes, stride, downsample=downsample)) 276 | for i in range(1, blocks): 277 | layers.append(block(inplanes, planes)) 278 | 279 | return nn.Sequential(*layers) 280 | 281 | def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): 282 | modules = [] 283 | for i in range(convs): 284 | modules.extend([ 285 | nn.Conv2d(inplanes, planes, kernel_size=3, 286 | stride=stride if i == 0 else 1, 287 | padding=dilation, bias=False, dilation=dilation), 288 | BatchNorm(planes), 289 | nn.ReLU(inplace=True)]) 290 | inplanes = planes 291 | return nn.Sequential(*modules) 292 | 293 | def forward(self, x): 294 | y = [] 295 | x = self.base_layer(x) 296 | for i in range(6): 297 | if self.seg: 298 | x = checkpoint(getattr(self, 'level{}'.format(i)), x) 299 | else: 300 | x = getattr(self, 'level{}'.format(i))(x) 301 | y.append(x) 302 | if self.return_levels: 303 | return y 304 | else: 305 | x = self.avgpool(x) 306 | x = self.fc(x) 307 | x = x.view(x.size(0), -1) 308 | return x 309 | 310 | 311 | def dla60x(**kwargs): 312 | model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], 313 | block=BottleneckX, **kwargs) 314 | return model 315 | 316 | 317 | def dla102x(**kwargs): 318 | model = DLA([1, 1, 1, 3, 4, 1], [16, 32, 128, 256, 512, 1024], 319 | block=BottleneckX, residual_root=True, **kwargs) 320 | return model 321 | 322 | 323 | def dla60x_elastic(**kwargs): 324 | model = DLA([1, 1, 1, 2, 3, 1], [16, 32, 128, 256, 512, 1024], 325 | block=BottleneckXElastic, **kwargs) 326 | return model 327 | 328 | 329 | def dla102x_elastic(**kwargs): 330 | BottleneckX.cardinality = 50 331 | model = DLA([1, 1, 3, 3, 3, 1], [16, 32, 128, 256, 512, 1024], 332 | block=BottleneckXElastic, residual_root=True, **kwargs) 333 | return model 334 | -------------------------------------------------------------------------------- /models/dla_up.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from . import dla 7 | from .utils import fill_up_weights 8 | BatchNorm = nn.BatchNorm2d 9 | 10 | 11 | class Identity(nn.Module): 12 | def __init__(self): 13 | super(Identity, self).__init__() 14 | 15 | def forward(self, x): 16 | return x 17 | 18 | 19 | class IDAUp(nn.Module): 20 | def __init__(self, node_kernel, out_dim, channels, up_factors): 21 | super(IDAUp, self).__init__() 22 | self.channels = channels 23 | self.out_dim = out_dim 24 | for i, c in enumerate(channels): 25 | if c == out_dim: 26 | proj = Identity() 27 | else: 28 | proj = nn.Sequential( 29 | nn.Conv2d(c, out_dim, 30 | kernel_size=1, stride=1, bias=False), 31 | BatchNorm(out_dim), 32 | nn.ReLU(inplace=True)) 33 | f = int(up_factors[i]) 34 | if f == 1: 35 | up = Identity() 36 | else: 37 | up = nn.ConvTranspose2d( 38 | out_dim, out_dim, f * 2, stride=f, padding=f // 2, 39 | output_padding=0, groups=out_dim, bias=False) 40 | fill_up_weights(up) 41 | setattr(self, 'proj_' + str(i), proj) 42 | setattr(self, 'up_' + str(i), up) 43 | 44 | for i in range(1, len(channels)): 45 | node = nn.Sequential( 46 | nn.Conv2d(out_dim * 2, out_dim, 47 | kernel_size=node_kernel, stride=1, 48 | padding=node_kernel // 2, bias=False), 49 | BatchNorm(out_dim), 50 | nn.ReLU(inplace=True)) 51 | setattr(self, 'node_' + str(i), node) 52 | 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. / n)) 57 | elif isinstance(m, BatchNorm): 58 | m.weight.data.fill_(1) 59 | m.bias.data.zero_() 60 | 61 | def forward(self, layers): 62 | assert len(self.channels) == len(layers), \ 63 | '{} vs {} layers'.format(len(self.channels), len(layers)) 64 | layers = list(layers) 65 | for i, l in enumerate(layers): 66 | upsample = getattr(self, 'up_' + str(i)) 67 | project = getattr(self, 'proj_' + str(i)) 68 | layers[i] = upsample(project(l)) 69 | x = layers[0] 70 | y = [] 71 | for i in range(1, len(layers)): 72 | node = getattr(self, 'node_' + str(i)) 73 | x = node(torch.cat([x, layers[i][:, :, :x.size(2), :x.size(3)]], 1)) 74 | y.append(x) 75 | return x, y 76 | 77 | 78 | class DLAUp(nn.Module): 79 | def __init__(self, channels, scales=(1, 2, 4, 8, 16), in_channels=None): 80 | super(DLAUp, self).__init__() 81 | if in_channels is None: 82 | in_channels = channels 83 | self.channels = channels 84 | channels = list(channels) 85 | scales = np.array(scales, dtype=int) 86 | for i in range(len(channels) - 1): 87 | j = -i - 2 88 | setattr(self, 'ida_{}'.format(i), 89 | IDAUp(3, channels[j], in_channels[j:], 90 | scales[j:] // scales[j])) 91 | scales[j + 1:] = scales[j] 92 | in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]] 93 | 94 | def forward(self, layers): 95 | layers = list(layers) 96 | assert len(layers) > 1 97 | for i in range(len(layers) - 1): 98 | ida = getattr(self, 'ida_{}'.format(i)) 99 | x, y = ida(layers[-i - 2:]) 100 | layers[-i - 1:] = y 101 | return x 102 | 103 | 104 | class DLASeg(nn.Module): 105 | def __init__(self, base_name, classes, down_ratio=2): 106 | super(DLASeg, self).__init__() 107 | assert down_ratio in [2, 4, 8, 16] 108 | self.first_level = int(np.log2(down_ratio)) 109 | self.base = dla.__dict__[base_name](return_levels=True, seg=True) 110 | channels = self.base.channels 111 | # print(channels, self.first_level) 112 | scales = [2 ** i for i in range(len(channels[self.first_level:]))] 113 | self.dla_up = DLAUp(channels[self.first_level:], scales=scales) 114 | self.fc = nn.Sequential( 115 | nn.Conv2d(channels[self.first_level], classes, kernel_size=1, 116 | stride=1, padding=0, bias=True) 117 | ) 118 | up_factor = 2 ** self.first_level 119 | if up_factor > 1: 120 | up = nn.ConvTranspose2d(classes, classes, up_factor * 2, 121 | stride=up_factor, padding=up_factor // 2, 122 | output_padding=0, groups=classes, 123 | bias=False) 124 | fill_up_weights(up) 125 | up.weight.requires_grad = False 126 | else: 127 | up = Identity() 128 | self.up = up 129 | self.softmax = nn.LogSoftmax(dim=1) 130 | 131 | for m in self.fc.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | m.weight.data.normal_(0, math.sqrt(2. / n)) 135 | elif isinstance(m, BatchNorm): 136 | m.weight.data.fill_(1) 137 | m.bias.data.zero_() 138 | 139 | def forward(self, x): 140 | x = self.base(x) 141 | x = self.dla_up(x[self.first_level:]) 142 | x = self.fc(x) 143 | y = self.softmax(self.up(x)) 144 | return y[:, :, :-1, :-1] 145 | 146 | def optim_parameters(self, memo=None): 147 | for param in self.base.parameters(): 148 | yield param 149 | for param in self.dla_up.parameters(): 150 | yield param 151 | for param in self.fc.parameters(): 152 | yield param 153 | 154 | 155 | def dla60x_seg(classes, **kwargs): 156 | model = DLASeg('dla60x', classes, **kwargs) 157 | return model 158 | 159 | 160 | def dla60x_elastic_seg(classes, **kwargs): 161 | model = DLASeg('dla60x_elastic', classes, **kwargs) 162 | return model 163 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | from torch.utils.checkpoint import * 6 | import torch.nn.functional as F 7 | from torch.nn import init 8 | import numpy as np 9 | 10 | 11 | class ASPP(nn.Module): 12 | def __init__(self, C, depth, num_classes, norm=nn.BatchNorm2d, momentum=0.0003, mult=1): 13 | super(ASPP, self).__init__() 14 | self._C = C 15 | self._depth = depth 16 | self._num_classes = num_classes 17 | self._norm = norm 18 | 19 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.aspp1 = nn.Conv2d(C, depth, kernel_size=1, stride=1, bias=False) 22 | self.aspp2 = nn.Conv2d(C, depth, kernel_size=3, stride=1, 23 | dilation=int(6*mult), padding=int(6*mult), 24 | bias=False) 25 | self.aspp3 = nn.Conv2d(C, depth, kernel_size=3, stride=1, 26 | dilation=int(12*mult), padding=int(12*mult), 27 | bias=False) 28 | self.aspp4 = nn.Conv2d(C, depth, kernel_size=3, stride=1, 29 | dilation=int(18*mult), padding=int(18*mult), 30 | bias=False) 31 | self.aspp5 = nn.Conv2d(C, depth, kernel_size=1, stride=1, bias=False) 32 | self.aspp1_bn = self._norm(depth, momentum) 33 | self.aspp2_bn = self._norm(depth, momentum) 34 | self.aspp3_bn = self._norm(depth, momentum) 35 | self.aspp4_bn = self._norm(depth, momentum) 36 | self.aspp5_bn = self._norm(depth, momentum) 37 | self.conv2 = nn.Conv2d(depth * 5, depth, kernel_size=1, stride=1, 38 | bias=False) 39 | self.bn2 = self._norm(depth, momentum) 40 | self.conv3 = nn.Conv2d(depth, num_classes, kernel_size=1, stride=1) 41 | 42 | def forward(self, x): 43 | x1 = self.aspp1(x) 44 | x1 = self.aspp1_bn(x1) 45 | x1 = self.relu(x1) 46 | x2 = self.aspp2(x) 47 | x2 = self.aspp2_bn(x2) 48 | x2 = self.relu(x2) 49 | x3 = self.aspp3(x) 50 | x3 = self.aspp3_bn(x3) 51 | x3 = self.relu(x3) 52 | x4 = self.aspp4(x) 53 | x4 = self.aspp4_bn(x4) 54 | x4 = self.relu(x4) 55 | x5 = self.global_pooling(x) 56 | x5 = self.aspp5(x5) 57 | x5 = self.aspp5_bn(x5) 58 | x5 = self.relu(x5) 59 | x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear', 60 | align_corners=True)(x5) 61 | x = torch.cat((x1, x2, x3, x4, x5), 1) 62 | x = self.conv2(x) 63 | x = self.bn2(x) 64 | x = self.relu(x) 65 | x = self.conv3(x) 66 | 67 | return x 68 | 69 | 70 | class Selayer(nn.Module): 71 | 72 | def __init__(self, inplanes): 73 | super(Selayer, self).__init__() 74 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 75 | self.conv1 = nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1) 76 | self.conv2 = nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.sigmoid = nn.Sigmoid() 79 | 80 | def forward(self, x): 81 | 82 | out = self.global_avgpool(x) 83 | 84 | out = self.conv1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.sigmoid(out) 89 | 90 | return x * out 91 | 92 | 93 | class BottleneckX(nn.Module): 94 | expansion = 4 95 | def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None, dilation=1, norm=None, elastic=False, se=False): 96 | super(BottleneckX, self).__init__() 97 | self.se = se 98 | self.elastic = elastic and stride == 1 and planes < 512 99 | if self.elastic: 100 | self.down = nn.AvgPool2d(2, stride=2) 101 | self.ups = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 102 | # half resolution 103 | self.conv1_d = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 104 | self.bn1_d = norm(planes) 105 | self.conv2_d = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, groups=cardinality // 2, 106 | dilation=dilation, padding=dilation, bias=False) 107 | self.bn2_d = norm(planes) 108 | self.conv3_d = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 109 | # full resolution 110 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 111 | self.bn1 = norm(planes) 112 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, groups=cardinality // 2, 113 | dilation=dilation, padding=dilation, bias=False) 114 | self.bn2 = norm(planes) 115 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 116 | # after merging 117 | self.bn3 = norm(planes * self.expansion) 118 | if self.se: 119 | self.selayer = Selayer(planes * 4) 120 | self.relu = nn.ReLU(inplace=True) 121 | self.downsample = downsample 122 | self.stride = stride 123 | self.__flops__ = 0 124 | 125 | def forward(self, x): 126 | residual = x 127 | out_d = x 128 | if self.elastic: 129 | if x.size(2) % 2 > 0 or x.size(3) % 2 > 0: 130 | out_d = F.pad(out_d, (0, x.size(3) % 2, 0, x.size(2) % 2), mode='replicate') 131 | out_d = self.down(out_d) 132 | 133 | out_d = self.conv1_d(out_d) 134 | out_d = self.bn1_d(out_d) 135 | out_d = self.relu(out_d) 136 | 137 | out_d = self.conv2_d(out_d) 138 | out_d = self.bn2_d(out_d) 139 | out_d = self.relu(out_d) 140 | 141 | out_d = self.conv3_d(out_d) 142 | 143 | if self.elastic: 144 | out_d = self.ups(out_d) 145 | self.__flops__ += np.prod(out_d[0].shape) * 8 146 | if out_d.size(2) > x.size(2) or out_d.size(3) > x.size(3): 147 | out_d = out_d[:, :, :x.size(2), :x.size(3)] 148 | 149 | out = self.conv1(x) 150 | out = self.bn1(out) 151 | out = self.relu(out) 152 | 153 | out = self.conv2(out) 154 | out = self.bn2(out) 155 | out = self.relu(out) 156 | 157 | out = self.conv3(out) 158 | out = out + out_d 159 | out = self.bn3(out) 160 | 161 | if self.se: 162 | out = self.selayer(out) 163 | 164 | if self.downsample is not None: 165 | residual = self.downsample(x) 166 | 167 | out += residual 168 | out = self.relu(out) 169 | 170 | return out 171 | 172 | 173 | class ResNext(nn.Module): 174 | 175 | def __init__(self, block, layers, num_classes=1000, seg=False, elastic=False, se=False): 176 | self.inplanes = 64 177 | self.cardinality = 32 178 | self.seg = seg 179 | self._norm = lambda planes, momentum=0.05 if seg else 0.1: torch.nn.BatchNorm2d(planes, momentum=momentum) 180 | 181 | super(ResNext, self).__init__() 182 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 183 | self.bn1 = self._norm(64) 184 | self.relu = nn.ReLU(inplace=True) 185 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 186 | self.layer1 = self._make_layer(block, 64, layers[0], elastic=elastic, se=se) 187 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, elastic=elastic, se=se) 188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, elastic=elastic, se=se) 189 | if seg: 190 | self.layer4 = self._make_mg(block, 512, se=se) 191 | self.aspp = ASPP(512 * block.expansion, 256, num_classes, self._norm) 192 | for m in self.modules(): 193 | if isinstance(m, nn.Conv2d): 194 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 195 | m.weight.data.normal_(0, math.sqrt(2. / n)) 196 | elif isinstance(m, torch.nn.BatchNorm2d): 197 | m.weight.data.fill_(1) 198 | m.bias.data.zero_() 199 | else: 200 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, elastic=False, se=se) 201 | self.avgpool = nn.AdaptiveAvgPool2d(1) 202 | self.fc = nn.Linear(512 * block.expansion, num_classes) 203 | init.normal_(self.fc.weight, std=0.01) 204 | for n, p in self.named_parameters(): 205 | if n.split('.')[-1] == 'weight': 206 | if 'conv' in n: 207 | init.kaiming_normal_(p, mode='fan_in', nonlinearity='relu') 208 | if 'bn' in n: 209 | p.data.fill_(1) 210 | if 'bn3' in n: 211 | p.data.fill_(0) 212 | elif n.split('.')[-1] == 'bias': 213 | p.data.fill_(0) 214 | 215 | def _make_layer(self, block, planes, blocks, stride=1, elastic=False, se=False): 216 | downsample = None 217 | if stride != 1 or self.inplanes != planes * block.expansion: 218 | downsample = nn.Sequential( 219 | nn.Conv2d(self.inplanes, planes * block.expansion, 220 | kernel_size=1, stride=stride, bias=False), 221 | self._norm(planes * block.expansion), 222 | ) 223 | 224 | layers = list() 225 | layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample=downsample, norm=self._norm, elastic=elastic, se=se)) 226 | self.inplanes = planes * block.expansion 227 | for i in range(1, blocks): 228 | layers.append(block(self.inplanes, planes, self.cardinality, norm=self._norm, elastic=elastic, se=se)) 229 | return nn.Sequential(*layers) 230 | 231 | def _make_mg(self, block, planes, dilation=2, multi_grid=(1, 2, 4), se=False): 232 | downsample = nn.Sequential( 233 | nn.Conv2d(self.inplanes, planes * block.expansion, 234 | kernel_size=1, stride=1, dilation=1, bias=False), 235 | self._norm(planes * block.expansion), 236 | ) 237 | 238 | layers = list() 239 | layers.append(block(self.inplanes, planes, self.cardinality, downsample=downsample, dilation=dilation*multi_grid[0], norm=self._norm, se=se)) 240 | self.inplanes = planes * block.expansion 241 | layers.append(block(self.inplanes, planes, self.cardinality, dilation=dilation*multi_grid[1], norm=self._norm, se=se)) 242 | layers.append(block(self.inplanes, planes, self.cardinality, dilation=dilation*multi_grid[2], norm=self._norm, se=se)) 243 | return nn.Sequential(*layers) 244 | 245 | def forward(self, x): 246 | size = (x.shape[2], x.shape[3]) 247 | x = self.conv1(x) 248 | x = self.bn1(x) 249 | x = self.relu(x) 250 | x = self.maxpool(x) 251 | if self.seg: 252 | for module in self.layer1._modules.values(): 253 | x = checkpoint(module, x) 254 | for module in self.layer2._modules.values(): 255 | x = checkpoint(module, x) 256 | for module in self.layer3._modules.values(): 257 | x = checkpoint(module, x) 258 | for module in self.layer4._modules.values(): 259 | x = checkpoint(module, x) 260 | x = self.aspp(x) 261 | x = nn.Upsample(size, mode='bilinear', align_corners=True)(x) 262 | else: 263 | x = self.layer1(x) 264 | x = self.layer2(x) 265 | x = self.layer3(x) 266 | x = self.layer4(x) 267 | x = self.avgpool(x) 268 | x = x.view(x.size(0), -1) 269 | x = self.fc(x) 270 | return x 271 | 272 | 273 | def resnext50(seg=False, **kwargs): 274 | model = ResNext(BottleneckX, [3, 4, 6, 3], seg=seg, elastic=False, **kwargs) 275 | return model 276 | 277 | 278 | def se_resnext50(seg=False, **kwargs): 279 | model = ResNext(BottleneckX, [3, 4, 6, 3], seg=seg, elastic=False, se=True, **kwargs) 280 | return model 281 | 282 | 283 | def resnext50_elastic(seg=False, **kwargs): 284 | model = ResNext(BottleneckX, [6, 8, 5, 3], seg=seg, elastic=True, **kwargs) 285 | return model 286 | 287 | 288 | def se_resnext50_elastic(seg=False, **kwargs): 289 | model = ResNext(BottleneckX, [6, 8, 5, 3], seg=seg, elastic=True, se=True, **kwargs) 290 | return model 291 | 292 | 293 | def resnext101(seg=False, **kwargs): 294 | model = ResNext(BottleneckX, [3, 4, 23, 3], seg=seg, elastic=False, **kwargs) 295 | return model 296 | 297 | 298 | def resnext101_elastic(seg=False, **kwargs): 299 | model = ResNext(BottleneckX, [12, 14, 20, 3], seg=seg, elastic=True, **kwargs) 300 | return model 301 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class CpBatchNorm2d(torch.nn.BatchNorm2d): 8 | def __init__(self, *args, **kwargs): 9 | super(CpBatchNorm2d, self).__init__(*args, **kwargs) 10 | 11 | def forward(self, input): 12 | self._check_input_dim(input) 13 | if input.requires_grad: 14 | exponential_average_factor = 0.0 15 | if self.training and self.track_running_stats: 16 | self.num_batches_tracked += 1 17 | if self.momentum is None: # use cumulative moving average 18 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 19 | else: # use exponential moving average 20 | exponential_average_factor = self.momentum 21 | return F.batch_norm( 22 | input, self.running_mean, self.running_var, self.weight, self.bias, 23 | self.training or not self.track_running_stats, 24 | exponential_average_factor, self.eps) 25 | else: 26 | return F.batch_norm( 27 | input, self.running_mean, self.running_var, self.weight, self.bias, 28 | self.training or not self.track_running_stats, 0.0, self.eps) 29 | 30 | 31 | def detach_variable(inputs): 32 | if isinstance(inputs, tuple): 33 | out = [] 34 | for inp in inputs: 35 | x = inp.detach() 36 | x.requires_grad = inp.requires_grad 37 | out.append(x) 38 | return tuple(out) 39 | else: 40 | raise RuntimeError( 41 | "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) 42 | 43 | 44 | def check_backward_validity(inputs): 45 | if not any(inp.requires_grad for inp in inputs): 46 | warnings.warn("None of the inputs have requires_grad=True. Gradients will be None") 47 | 48 | 49 | class CheckpointFunction(torch.autograd.Function): 50 | @staticmethod 51 | def forward(ctx, run_function, length, *args): 52 | ctx.run_function = run_function 53 | ctx.input_tensors = list(args[:length]) 54 | ctx.input_params = list(args[length:]) 55 | with torch.no_grad(): 56 | output_tensors = ctx.run_function(*ctx.input_tensors) 57 | return output_tensors 58 | 59 | @staticmethod 60 | def backward(ctx, *output_grads): 61 | for i in range(len(ctx.input_tensors)): 62 | temp = ctx.input_tensors[i] 63 | ctx.input_tensors[i] = temp.detach() 64 | ctx.input_tensors[i].requires_grad = temp.requires_grad 65 | with torch.enable_grad(): 66 | output_tensors = ctx.run_function(*ctx.input_tensors) 67 | input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True) 68 | return (None, None) + input_grads 69 | 70 | 71 | def fill_up_weights(up): 72 | w = up.weight.data 73 | f = math.ceil(w.size(2) / 2) 74 | c = (2 * f - 1 - f % 2) / (2. * f) 75 | for i in range(w.size(2)): 76 | for j in range(w.size(3)): 77 | w[0, 0, i, j] = \ 78 | (1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c)) 79 | for c in range(1, w.size(0)): 80 | w[c, 0, :, :] = w[0, 0, :, :] -------------------------------------------------------------------------------- /multilabel_classify.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import pdb 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.optim 11 | import torch.utils.data as data 12 | import torch.utils.data.distributed 13 | import torchvision.transforms as transforms 14 | import torchvision.datasets as datasets 15 | # import torchvision.models as models 16 | import models 17 | import os 18 | from PIL import Image 19 | from utils import add_flops_counting_methods, save_checkpoint, AverageMeter 20 | 21 | model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic', 22 | 'dla60x', 'dla60x_elastic', 'dla102x', 'dla102x_elastic', 23 | 'se_resnext50', 'se_resnext50_elastic', 'densenet201', 'densenet201_elastic'] 24 | 25 | 26 | class CocoDetection(datasets.coco.CocoDetection): 27 | def __init__(self, root, annFile, transform=None, target_transform=None): 28 | from pycocotools.coco import COCO 29 | self.root = root 30 | self.coco = COCO(annFile) 31 | self.ids = list(self.coco.imgs.keys()) 32 | self.transform = transform 33 | self.target_transform = target_transform 34 | self.cat2cat = dict() 35 | for cat in self.coco.cats.keys(): 36 | self.cat2cat[cat] = len(self.cat2cat) 37 | # print(self.cat2cat) 38 | 39 | def __getitem__(self, index): 40 | coco = self.coco 41 | img_id = self.ids[index] 42 | ann_ids = coco.getAnnIds(imgIds=img_id) 43 | target = coco.loadAnns(ann_ids) 44 | 45 | output = torch.zeros((3, 80), dtype=torch.long) 46 | for obj in target: 47 | if obj['area'] < 32 * 32: 48 | output[0][self.cat2cat[obj['category_id']]] = 1 49 | elif obj['area'] < 96 * 96: 50 | output[1][self.cat2cat[obj['category_id']]] = 1 51 | else: 52 | output[2][self.cat2cat[obj['category_id']]] = 1 53 | target = output 54 | 55 | path = coco.loadImgs(img_id)[0]['file_name'] 56 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | 60 | if self.target_transform is not None: 61 | target = self.target_transform(target) 62 | return img, target 63 | 64 | 65 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 66 | parser.add_argument('data', metavar='DIR', help='path to dataset') 67 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names, 68 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)') 69 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 70 | help='number of data loading workers (default: 16)') 71 | parser.add_argument('--epochs', default=36, type=int, metavar='N', 72 | help='number of total epochs to run') 73 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 74 | help='manual epoch number (useful on restarts)') 75 | parser.add_argument('-b', '--batch-size', default=96, type=int, 76 | metavar='N', help='mini-batch size (default: 96)') 77 | parser.add_argument('-g', '--num-gpus', default=4, type=int, 78 | metavar='N', help='number of GPUs to match (default: 4)') 79 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 80 | metavar='LR', help='initial learning rate') 81 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 82 | help='momentum') 83 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 84 | metavar='W', help='weight decay (default: 5e-4)') 85 | parser.add_argument('--print-freq', '-p', default=117, type=int, 86 | metavar='N', help='print frequency (default: 117)') 87 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 88 | help='path to latest checkpoint (default: none)') 89 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 90 | help='evaluate model on validation set') 91 | parser.add_argument('--world-size', default=1, type=int, 92 | help='number of distributed processes') 93 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 94 | help='url used to set up distributed training') 95 | parser.add_argument('--dist-backend', default='gloo', type=str, 96 | help='distributed backend') 97 | 98 | 99 | def main(): 100 | global args 101 | args = parser.parse_args() 102 | print('config: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'num_gpus', args.num_gpus) 103 | iteration_size = args.num_gpus // torch.cuda.device_count() # do multiple iterations 104 | assert iteration_size >= 1 105 | args.weight_decay = args.weight_decay * iteration_size # will cancel out with lr 106 | args.lr = args.lr / iteration_size 107 | args.batch_size = args.batch_size // iteration_size 108 | print('real: wd', args.weight_decay, 'lr', args.lr, 'batch_size', args.batch_size, 'iteration_size', iteration_size) 109 | 110 | args.distributed = args.world_size > 1 111 | 112 | if args.distributed: 113 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 114 | world_size=args.world_size) 115 | 116 | # create model 117 | print("=> creating model '{}'".format(args.arch)) 118 | model = models.__dict__[args.arch](num_classes=80) 119 | 120 | # count number of parameters 121 | count = 0 122 | params = list() 123 | for n, p in model.named_parameters(): 124 | if '.ups.' not in n: 125 | params.append(p) 126 | count += np.prod(p.size()) 127 | print('Parameters:', count) 128 | 129 | # count flops 130 | model = add_flops_counting_methods(model) 131 | model.eval() 132 | image = torch.randn(1, 3, 224, 224) 133 | 134 | model.start_flops_count() 135 | model(image).sum() 136 | model.stop_flops_count() 137 | print("GFLOPs", model.compute_average_flops_cost() / 1000000000.0) 138 | 139 | # normal code 140 | model = torch.nn.DataParallel(model).cuda() 141 | 142 | criterion = nn.BCEWithLogitsLoss().cuda() 143 | optimizer = torch.optim.SGD([{'params': iter(params), 'lr': args.lr}, 144 | ], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 145 | 146 | # optionally resume from a checkpoint 147 | if args.resume: 148 | if os.path.isfile(args.resume): 149 | print("=> loading checkpoint '{}'".format(args.resume)) 150 | checkpoint = torch.load(args.resume) 151 | 152 | resume = ('module.fc.bias' in checkpoint['state_dict'] and 153 | checkpoint['state_dict']['module.fc.bias'].size() == model.module.fc.bias.size()) or \ 154 | ('module.classifier.bias' in checkpoint['state_dict'] and 155 | checkpoint['state_dict']['module.classifier.bias'].size() == model.module.classifier.bias.size()) 156 | if resume: 157 | # True resume: resume training on COCO 158 | model.load_state_dict(checkpoint['state_dict'], strict=False) 159 | optimizer.load_state_dict(checkpoint['optimizer']) if 'optimizer' in checkpoint else print('no optimizer found') 160 | args.start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else args.start_epoch 161 | else: 162 | # Fake resume: transfer from ImageNet 163 | for n, p in list(checkpoint['state_dict'].items()): 164 | if 'classifier' in n or 'fc' in n: 165 | print(n, 'deleted from state_dict') 166 | del checkpoint['state_dict'][n] 167 | model.load_state_dict(checkpoint['state_dict'], strict=False) 168 | 169 | print("=> loaded checkpoint '{}' (epoch {})" 170 | .format(args.resume, checkpoint['epoch'] if 'epoch' in checkpoint else 'unknown')) 171 | else: 172 | print("=> no checkpoint found at '{}'".format(args.resume)) 173 | 174 | cudnn.benchmark = True 175 | 176 | # Data loading code 177 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 178 | std=[0.229, 0.224, 0.225]) 179 | train_dataset = CocoDetection(os.path.join(args.data, 'train2014'), 180 | os.path.join(args.data, 'annotations/instances_train2014.json'), 181 | transforms.Compose([ 182 | transforms.RandomResizedCrop(224), 183 | transforms.RandomHorizontalFlip(), 184 | transforms.ToTensor(), 185 | normalize, 186 | ])) 187 | val_dataset = CocoDetection(os.path.join(args.data, 'val2014'), 188 | os.path.join(args.data, 'annotations/instances_val2014.json'), 189 | transforms.Compose([ 190 | transforms.Resize((224, 224)), 191 | transforms.ToTensor(), 192 | normalize, 193 | ])) 194 | 195 | train_sampler = torch.utils.data.sampler.RandomSampler(train_dataset) 196 | 197 | train_loader = torch.utils.data.DataLoader( 198 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 199 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 200 | val_loader = torch.utils.data.DataLoader( 201 | val_dataset, batch_size=args.batch_size, shuffle=False, 202 | num_workers=args.workers, pin_memory=True) 203 | 204 | if args.evaluate: 205 | validate_multi(val_loader, model, criterion) 206 | return 207 | 208 | for epoch in range(args.start_epoch, args.epochs): 209 | coco_adjust_learning_rate(optimizer, epoch) 210 | 211 | # train for one epoch 212 | train_multi(train_loader, model, criterion, optimizer, epoch, iteration_size) 213 | 214 | # evaluate on validation set 215 | validate_multi(val_loader, model, criterion) 216 | save_checkpoint({ 217 | 'epoch': epoch + 1, 218 | 'arch': args.arch, 219 | 'state_dict': model.state_dict(), 220 | 'optimizer': optimizer.state_dict(), 221 | }, False, filename='coco_' + args.arch + '_checkpoint.pth.tar') 222 | 223 | 224 | def train_multi(train_loader, model, criterion, optimizer, epoch, iteration_size): 225 | batch_time = AverageMeter() 226 | data_time = AverageMeter() 227 | losses = AverageMeter() 228 | prec = AverageMeter() 229 | rec = AverageMeter() 230 | 231 | # switch to train mode 232 | model.train() 233 | optimizer.zero_grad() 234 | end = time.time() 235 | tp, fp, fn, tn, count = 0, 0, 0, 0, 0 236 | for i, (input, target) in enumerate(train_loader): 237 | # measure data loading time 238 | data_time.update(time.time() - end) 239 | 240 | target = target.cuda(non_blocking=True) 241 | target = target.max(dim=1)[0] 242 | # compute output 243 | output = model(input) 244 | loss = criterion(output, target.float()) * 80.0 245 | 246 | # measure accuracy and record loss 247 | pred = output.data.gt(0.0).long() 248 | 249 | tp += (pred + target).eq(2).sum(dim=0) 250 | fp += (pred - target).eq(1).sum(dim=0) 251 | fn += (pred - target).eq(-1).sum(dim=0) 252 | tn += (pred + target).eq(0).sum(dim=0) 253 | count += input.size(0) 254 | 255 | this_tp = (pred + target).eq(2).sum() 256 | this_fp = (pred - target).eq(1).sum() 257 | this_fn = (pred - target).eq(-1).sum() 258 | this_tn = (pred + target).eq(0).sum() 259 | this_acc = (this_tp + this_tn).float() / (this_tp + this_tn + this_fp + this_fn).float() 260 | 261 | this_prec = this_tp.float() / (this_tp + this_fp).float() * 100.0 if this_tp + this_fp != 0 else 0.0 262 | this_rec = this_tp.float() / (this_tp + this_fn).float() * 100.0 if this_tp + this_fn != 0 else 0.0 263 | 264 | losses.update(float(loss), input.size(0)) 265 | prec.update(float(this_prec), input.size(0)) 266 | rec.update(float(this_rec), input.size(0)) 267 | # compute gradient and do SGD step 268 | loss.backward() 269 | 270 | if i % iteration_size == iteration_size - 1: 271 | optimizer.step() 272 | optimizer.zero_grad() 273 | # measure elapsed time 274 | batch_time.update(time.time() - end) 275 | end = time.time() 276 | 277 | p_c = [float(tp[i].float() / (tp[i] + fp[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))] 278 | r_c = [float(tp[i].float() / (tp[i] + fn[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))] 279 | f_c = [2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0 for i in range(len(tp))] 280 | 281 | mean_p_c = sum(p_c) / len(p_c) 282 | mean_r_c = sum(r_c) / len(r_c) 283 | mean_f_c = sum(f_c) / len(f_c) 284 | 285 | p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0 286 | r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0 287 | f_o = 2 * p_o * r_o / (p_o + r_o) 288 | 289 | if i % args.print_freq == 0: 290 | print('Epoch: [{0}][{1}/{2}]\t' 291 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 292 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 293 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 294 | 'Precision {prec.val:.2f} ({prec.avg:.2f})\t' 295 | 'Recall {rec.val:.2f} ({rec.avg:.2f})'.format( 296 | epoch, i, len(train_loader), batch_time=batch_time, 297 | data_time=data_time, loss=losses, prec=prec, rec=rec)) 298 | print('P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}' 299 | .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o)) 300 | 301 | 302 | def validate_multi(val_loader, model, criterion): 303 | batch_time = AverageMeter() 304 | losses = AverageMeter() 305 | prec = AverageMeter() 306 | rec = AverageMeter() 307 | 308 | # switch to evaluate mode 309 | model.eval() 310 | 311 | end = time.time() 312 | tp, fp, fn, tn, count = 0, 0, 0, 0, 0 313 | tp_size, fn_size = 0, 0 314 | for i, (input, target) in enumerate(val_loader): 315 | target = target.cuda(non_blocking=True) 316 | original_target = target 317 | target = target.max(dim=1)[0] 318 | # compute output 319 | with torch.no_grad(): 320 | output = model(input) 321 | loss = criterion(output, target.float()) 322 | 323 | # measure accuracy and record loss 324 | pred = output.data.gt(0.0).long() 325 | 326 | tp += (pred + target).eq(2).sum(dim=0) 327 | fp += (pred - target).eq(1).sum(dim=0) 328 | fn += (pred - target).eq(-1).sum(dim=0) 329 | tn += (pred + target).eq(0).sum(dim=0) 330 | three_pred = pred.unsqueeze(1).expand(-1, 3, -1) # n, 3, 80 331 | tp_size += (three_pred + original_target).eq(2).sum(dim=0) 332 | fn_size += (three_pred - original_target).eq(-1).sum(dim=0) 333 | count += input.size(0) 334 | 335 | this_tp = (pred + target).eq(2).sum() 336 | this_fp = (pred - target).eq(1).sum() 337 | this_fn = (pred - target).eq(-1).sum() 338 | this_tn = (pred + target).eq(0).sum() 339 | this_acc = (this_tp + this_tn).float() / (this_tp + this_tn + this_fp + this_fn).float() 340 | 341 | this_prec = this_tp.float() / (this_tp + this_fp).float() * 100.0 if this_tp + this_fp != 0 else 0.0 342 | this_rec = this_tp.float() / (this_tp + this_fn).float() * 100.0 if this_tp + this_fn != 0 else 0.0 343 | 344 | losses.update(float(loss), input.size(0)) 345 | prec.update(float(this_prec), input.size(0)) 346 | rec.update(float(this_rec), input.size(0)) 347 | 348 | # measure elapsed time 349 | batch_time.update(time.time() - end) 350 | end = time.time() 351 | 352 | p_c = [float(tp[i].float() / (tp[i] + fp[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))] 353 | r_c = [float(tp[i].float() / (tp[i] + fn[i]).float()) * 100.0 if tp[i] > 0 else 0.0 for i in range(len(tp))] 354 | f_c = [2 * p_c[i] * r_c[i] / (p_c[i] + r_c[i]) if tp[i] > 0 else 0.0 for i in range(len(tp))] 355 | 356 | mean_p_c = sum(p_c) / len(p_c) 357 | mean_r_c = sum(r_c) / len(r_c) 358 | mean_f_c = sum(f_c) / len(f_c) 359 | 360 | p_o = tp.sum().float() / (tp + fp).sum().float() * 100.0 361 | r_o = tp.sum().float() / (tp + fn).sum().float() * 100.0 362 | f_o = 2 * p_o * r_o / (p_o + r_o) 363 | 364 | if i % args.print_freq == 0: 365 | print('Test: [{0}/{1}]\t' 366 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 367 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 368 | 'Precision {prec.val:.2f} ({prec.avg:.2f})\t' 369 | 'Recall {rec.val:.2f} ({rec.avg:.2f})'.format( 370 | i, len(val_loader), batch_time=batch_time, loss=losses, 371 | prec=prec, rec=rec)) 372 | print('P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}' 373 | .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o)) 374 | 375 | print('--------------------------------------------------------------------') 376 | print(' * P_C {:.2f} R_C {:.2f} F_C {:.2f} P_O {:.2f} R_O {:.2f} F_O {:.2f}' 377 | .format(mean_p_c, mean_r_c, mean_f_c, p_o, r_o, f_o)) 378 | return 379 | 380 | 381 | def coco_adjust_learning_rate(optimizer, epoch): 382 | if isinstance(optimizer, torch.optim.Adam): 383 | return 384 | lr = args.lr 385 | # if epoch >= 12: 386 | # lr *= 0.1 387 | if epoch >= 24: 388 | lr *= 0.1 389 | if epoch >= 30: 390 | lr *= 0.1 391 | for param_group in optimizer.param_groups: 392 | param_group['lr'] = lr 393 | 394 | 395 | if __name__ == '__main__': 396 | main() 397 | -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import pdb 8 | from PIL import Image 9 | from scipy.io import loadmat 10 | from torch.autograd import Variable 11 | from utils import AverageMeter, inter_and_union, VOCSegmentation 12 | import models 13 | 14 | model_names = ['resnext50', 'resnext50_elastic', 'resnext101', 'resnext101_elastic', 'dla60x', 'dla60x_elastic'] 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--train', action='store_true', default=False, 18 | help='training mode') 19 | parser.add_argument('--exp', type=str, required=True, 20 | help='name of experiment') 21 | parser.add_argument('--gpu', type=int, default=0, 22 | help='test time gpu device id') 23 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnext50_elastic', choices=model_names, 24 | help='model architecture: ' + ' | '.join(model_names) + ' (default: resnext50_elastic)') 25 | parser.add_argument('--dataset', type=str, default='pascal', 26 | help='pascal') 27 | parser.add_argument('--epochs', type=int, default=50, 28 | help='num of training epochs') 29 | parser.add_argument('--batch_size', type=int, default=16, 30 | help='batch size') 31 | parser.add_argument('--base_lr', type=float, default=0.007, 32 | help='base learning rate') 33 | parser.add_argument('--last_mult', type=float, default=1.0, 34 | help='learning rate multiplier for last layers') 35 | parser.add_argument('--freeze_bn', action='store_true', default=False, 36 | help='freeze batch normalization parameters') 37 | parser.add_argument('--crop_size', type=int, default=513, 38 | help='image crop size') 39 | parser.add_argument('--resume', type=str, default=None, 40 | help='path to checkpoint to resume from') 41 | parser.add_argument('--workers', type=int, default=4, 42 | help='number of data loading workers') 43 | args = parser.parse_args() 44 | 45 | 46 | def main(): 47 | assert torch.cuda.is_available() 48 | model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format( 49 | args.arch, args.dataset, args.exp) 50 | if args.dataset == 'pascal': 51 | dataset = VOCSegmentation('data/VOCdevkit', 52 | train=args.train, crop_size=args.crop_size) 53 | else: 54 | raise ValueError('Unknown dataset: {}'.format(args.dataset)) 55 | 56 | if 'resnext' in args.arch: 57 | model = models.__dict__[args.arch](seg=True, num_classes=len(dataset.CLASSES)) 58 | elif 'dla' in args.arch: 59 | model = models.__dict__[args.arch + '_seg'](classes=len(dataset.CLASSES)) 60 | else: 61 | raise ValueError('Unknown arch: {}'.format(args.arch)) 62 | 63 | if args.train: 64 | criterion = nn.CrossEntropyLoss(ignore_index=255) 65 | model = nn.DataParallel(model).cuda() 66 | model.train() 67 | if args.freeze_bn: 68 | for m in model.modules(): 69 | if isinstance(m, nn.BatchNorm2d): 70 | m.eval() 71 | m.weight.requires_grad = False 72 | m.bias.requires_grad = False 73 | if 'resnext' in args.arch: 74 | arch_params = ( 75 | list(model.module.conv1.parameters()) + 76 | list(model.module.bn1.parameters()) + 77 | list(model.module.layer1.parameters()) + 78 | list(model.module.layer2.parameters()) + 79 | list(model.module.layer3.parameters()) + 80 | list(model.module.layer4.parameters())) 81 | last_params = list(model.module.aspp.parameters()) 82 | else: 83 | arch_params = list(model.module.base.parameters()) 84 | last_params = list() 85 | for n, p in model.named_parameters(): 86 | if 'base' not in n and 'up.weight' not in n: 87 | last_params.append(p) 88 | 89 | optimizer = optim.SGD([ 90 | {'params': filter(lambda p: p.requires_grad, arch_params)}, 91 | {'params': filter(lambda p: p.requires_grad, last_params)}], 92 | lr=args.base_lr, momentum=0.9, weight_decay=0.0005 if 'resnext' in args.arch else 0.0001) 93 | dataset_loader = torch.utils.data.DataLoader( 94 | dataset, batch_size=args.batch_size, shuffle=args.train, 95 | pin_memory=True, num_workers=args.workers) 96 | max_iter = args.epochs * len(dataset_loader) 97 | losses = AverageMeter() 98 | start_epoch = 0 99 | 100 | if args.resume: 101 | if os.path.isfile(args.resume): 102 | print('=> loading checkpoint {0}'.format(args.resume)) 103 | checkpoint = torch.load(args.resume) 104 | 105 | resume = False 106 | for n, p in list(checkpoint['state_dict'].items()): 107 | if 'aspp' in n or 'dla_up' in n: 108 | resume = True 109 | break 110 | if resume: 111 | # True resume: resume training on pascal 112 | model.load_state_dict(checkpoint['state_dict'], strict=True) 113 | optimizer.load_state_dict(checkpoint['optimizer']) 114 | start_epoch = checkpoint['epoch'] 115 | else: 116 | # Fake resume: transfer from ImageNet 117 | if 'resnext' in args.arch: 118 | model.load_state_dict(checkpoint['state_dict'], strict=False) 119 | else: 120 | pretrained_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()} 121 | model.module.base.load_state_dict(pretrained_dict, strict=False) 122 | print('=> loaded checkpoint {0} (epoch {1})'.format( 123 | args.resume, start_epoch)) 124 | else: 125 | print('=> no checkpoint found at {0}'.format(args.resume)) 126 | 127 | for epoch in range(start_epoch, args.epochs): 128 | for i, (inputs, target, _, _, _, _) in enumerate(dataset_loader): 129 | cur_iter = epoch * len(dataset_loader) + i 130 | lr = args.base_lr * (1 - float(cur_iter) / max_iter) ** 0.9 131 | optimizer.param_groups[0]['lr'] = lr 132 | optimizer.param_groups[1]['lr'] = lr * args.last_mult 133 | 134 | inputs = Variable(inputs.cuda()) 135 | target = Variable(target.cuda()) 136 | 137 | outputs = model(inputs) 138 | loss = criterion(outputs, target) 139 | if np.isnan(loss.item()) or np.isinf(loss.item()): 140 | pdb.set_trace() 141 | losses.update(loss.item(), args.batch_size) 142 | loss.backward() 143 | optimizer.step() 144 | optimizer.zero_grad() 145 | 146 | print('epoch: {0}\t' 147 | 'iter: {1}/{2}\t' 148 | 'lr: {3:.6f}\t' 149 | 'loss: {loss.val:.4f} ({loss.ema:.4f})'.format( 150 | epoch + 1, i + 1, len(dataset_loader), lr, loss=losses)) 151 | 152 | if epoch % 10 == 9: 153 | torch.save({ 154 | 'epoch': epoch + 1, 155 | 'state_dict': model.state_dict(), 156 | 'optimizer': optimizer.state_dict(), 157 | }, model_fname % (epoch + 1)) 158 | 159 | else: 160 | torch.cuda.set_device(args.gpu) 161 | model = model.cuda() 162 | model.eval() 163 | checkpoint = torch.load(model_fname % args.epochs) 164 | state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items() if 'tracked' not in k} 165 | model.load_state_dict(state_dict) 166 | cmap = loadmat('data/pascal_seg_colormap.mat')['colormap'] 167 | cmap = (cmap * 255).astype(np.uint8).flatten().tolist() 168 | 169 | inter_meter = AverageMeter() 170 | union_meter = AverageMeter() 171 | for i in range(len(dataset)): 172 | inputs, target, a, b, h, w = dataset[i] 173 | inputs = inputs.unsqueeze(0) 174 | inputs = Variable(inputs.cuda()) 175 | outputs = model(inputs) 176 | _, pred = torch.max(outputs, 1) 177 | pred = pred.data.cpu().numpy().squeeze().astype(np.uint8) 178 | mask = target.numpy().astype(np.uint8) 179 | imname = dataset.masks[i].split('/')[-1] 180 | 181 | inter, union = inter_and_union(pred, mask, len(dataset.CLASSES)) 182 | inter_meter.update(inter) 183 | union_meter.update(union) 184 | 185 | mask_pred = Image.fromarray(pred[a:a + h, b:b + w]) 186 | mask_pred.putpalette(cmap) 187 | mask_pred.save(os.path.join('data/val', imname)) 188 | print('eval: {0}/{1}'.format(i + 1, len(dataset))) 189 | 190 | iou = inter_meter.sum / (union_meter.sum + 1e-10) 191 | for i, val in enumerate(iou): 192 | print('IoU {0}: {1:.2f}'.format(dataset.CLASSES[i], val * 100)) 193 | print('Mean IoU: {0:.2f}'.format(iou.mean() * 100)) 194 | 195 | 196 | if __name__ == "__main__": 197 | main() 198 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import math 3 | import random 4 | import torchvision.transforms as transforms 5 | import warnings 6 | from torch.nn import functional as F 7 | import shutil 8 | import torch.utils.data as data 9 | import os 10 | from PIL import Image 11 | import torch 12 | import numpy as np 13 | 14 | 15 | def accuracy(output, target, topk=(1,)): 16 | """Computes the precision@k for the specified values of k""" 17 | maxk = max(topk) 18 | batch_size = target.size(0) 19 | 20 | _, pred = output.topk(maxk, 1, True, True) 21 | pred = pred.t() 22 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 23 | 24 | res = [] 25 | for k in topk: 26 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 27 | res.append(correct_k.mul_(100.0 / batch_size)) 28 | return res 29 | 30 | 31 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 32 | torch.save(state, filename) 33 | if is_best: 34 | shutil.copyfile(filename, 'model_best.pth.tar') 35 | 36 | 37 | class AverageMeter(object): 38 | def __init__(self): 39 | self.val = None 40 | self.sum = None 41 | self.cnt = None 42 | self.avg = None 43 | self.ema = None 44 | self.initialized = False 45 | 46 | def update(self, val, n=1): 47 | if not self.initialized: 48 | self.initialize(val, n) 49 | else: 50 | self.add(val, n) 51 | 52 | def initialize(self, val, n): 53 | self.val = val 54 | self.sum = val * n 55 | self.cnt = n 56 | self.avg = val 57 | self.ema = val 58 | self.initialized = True 59 | 60 | def add(self, val, n): 61 | self.val = val 62 | self.sum += val * n 63 | self.cnt += n 64 | self.avg = self.sum / self.cnt 65 | self.ema = self.ema * 0.99 + self.val * 0.01 66 | 67 | 68 | def inter_and_union(pred, mask, num_class): 69 | pred = np.asarray(pred, dtype=np.uint8).copy() 70 | mask = np.asarray(mask, dtype=np.uint8).copy() 71 | 72 | # 255 -> 0 73 | pred += 1 74 | mask += 1 75 | pred = pred * (mask > 0) 76 | 77 | inter = pred * (pred == mask) 78 | (area_inter, _) = np.histogram(inter, bins=num_class, range=(1, num_class)) 79 | (area_pred, _) = np.histogram(pred, bins=num_class, range=(1, num_class)) 80 | (area_mask, _) = np.histogram(mask, bins=num_class, range=(1, num_class)) 81 | area_union = area_pred + area_mask - area_inter 82 | 83 | return (area_inter, area_union) 84 | 85 | 86 | def preprocess(image, mask, flip=False, scale=None, crop=None): 87 | if flip: 88 | if random.random() < 0.5: 89 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 90 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 91 | if scale: 92 | w, h = image.size 93 | rand_log_scale = math.log(scale[0], 2) + random.random() * (math.log(scale[1], 2) - math.log(scale[0], 2)) 94 | random_scale = math.pow(2, rand_log_scale) 95 | new_size = (int(round(w * random_scale)), int(round(h * random_scale))) 96 | image = image.resize(new_size, Image.ANTIALIAS) 97 | mask = mask.resize(new_size, Image.NEAREST) 98 | 99 | data_transforms = transforms.Compose([ 100 | transforms.ToTensor(), 101 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 102 | ]) 103 | image = data_transforms(image) 104 | mask = torch.LongTensor(np.array(mask).astype(np.int64)) 105 | 106 | if crop: 107 | h, w = image.shape[1], image.shape[2] 108 | ori_h, ori_w = image.shape[1], image.shape[2] 109 | 110 | pad_tb = max(0, int((1 + crop[0] - h) / 2)) 111 | pad_lr = max(0, int((1 + crop[1] - w) / 2)) 112 | image = torch.nn.ZeroPad2d((pad_lr, pad_lr, pad_tb, pad_tb))(image) 113 | mask = torch.nn.ConstantPad2d((pad_lr, pad_lr, pad_tb, pad_tb), 255)(mask) 114 | 115 | h, w = image.shape[1], image.shape[2] 116 | i = random.randint(0, h - crop[0]) 117 | j = random.randint(0, w - crop[1]) 118 | image = image[:, i:i + crop[0], j:j + crop[1]] 119 | mask = mask[i:i + crop[0], j:j + crop[1]] 120 | 121 | return image, mask, pad_tb - j, pad_lr - i, ori_h, ori_w 122 | 123 | 124 | # pascal dataloader 125 | class VOCSegmentation(data.Dataset): 126 | CLASSES = [ 127 | 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 128 | 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 129 | 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 130 | 'tv/monitor' 131 | ] 132 | 133 | def __init__(self, root, train=True, transform=None, target_transform=None, download=False, crop_size=None): 134 | self.root = root 135 | _voc_root = os.path.join(self.root, 'VOC2012') 136 | _list_dir = os.path.join(_voc_root, 'list') 137 | self.transform = transform 138 | self.target_transform = target_transform 139 | self.train = train 140 | self.crop_size = crop_size 141 | 142 | if download: 143 | self.download() 144 | 145 | if self.train: 146 | _list_f = os.path.join(_list_dir, 'train_aug.txt') 147 | else: 148 | _list_f = os.path.join(_list_dir, 'val.txt') 149 | self.images = [] 150 | self.masks = [] 151 | with open(_list_f, 'r') as lines: 152 | for line in lines: 153 | _image = _voc_root + line.split()[0] 154 | _mask = _voc_root + line.split()[1] 155 | assert os.path.isfile(_image) 156 | assert os.path.isfile(_mask) 157 | self.images.append(_image) 158 | self.masks.append(_mask) 159 | 160 | def __getitem__(self, index): 161 | _img = Image.open(self.images[index]).convert('RGB') 162 | _target = Image.open(self.masks[index]) 163 | 164 | _img, _target, a, b, h, w = preprocess(_img, _target, 165 | flip=True if self.train else False, 166 | scale=(0.5, 2.0) if self.train else None, 167 | crop=(self.crop_size, self.crop_size)) 168 | 169 | if self.transform is not None: 170 | _img = self.transform(_img) 171 | 172 | if self.target_transform is not None: 173 | _target = self.target_transform(_target) 174 | 175 | return _img, _target, a, b, h, w # used for visualizing 176 | 177 | def __len__(self): 178 | return len(self.images) 179 | 180 | def download(self): 181 | raise NotImplementedError('Automatic download not yet implemented.') 182 | 183 | 184 | # flops counter 185 | def add_flops_counting_methods(net_main_module): 186 | """Adds flops counting functions to an existing model. After that 187 | the flops count should be activated and the model should be run on an input 188 | image. 189 | 190 | Example: 191 | 192 | fcn = add_flops_counting_methods(fcn) 193 | fcn = fcn.cuda().train() 194 | fcn.start_flops_count() 195 | 196 | 197 | _ = fcn(batch) 198 | 199 | fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch 200 | 201 | Important: dividing by 2 only works for resnet models -- see below for the details 202 | of flops computation. 203 | 204 | Attention: we are counting multiply-add as two flops in this work, because in 205 | most resnet models convolutions are bias-free (BN layers act as bias there) 206 | and it makes sense to count muliply and add as separate flops therefore. 207 | This is why in the above example we divide by 2 in order to be consistent with 208 | most modern benchmarks. For example in "Spatially Adaptive Computatin Time for Residual 209 | Networks" by Figurnov et al multiply-add was counted as two flops. 210 | 211 | This module computes the average flops which is necessary for dynamic networks which 212 | have different number of executed layers. For static networks it is enough to run the network 213 | once and get statistics (above example). 214 | 215 | Implementation: 216 | The module works by adding batch_count to the main module which tracks the sum 217 | of all batch sizes that were run through the network. 218 | 219 | Also each convolutional layer of the network tracks the overall number of flops 220 | performed. 221 | 222 | The parameters are updated with the help of registered hook-functions which 223 | are being called each time the respective layer is executed. 224 | 225 | Parameters 226 | ---------- 227 | net_main_module : torch.nn.Module 228 | Main module containing network 229 | 230 | Returns 231 | ------- 232 | net_main_module : torch.nn.Module 233 | Updated main module with new methods/attributes that are used 234 | to compute flops. 235 | """ 236 | 237 | # adding additional methods to the existing module object, 238 | # this is done this way so that each function has access to self object 239 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 240 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 241 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 242 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 243 | 244 | net_main_module.reset_flops_count() 245 | 246 | # Adding variables necessary for masked flops computation 247 | net_main_module.apply(add_flops_mask_variable_or_reset) 248 | 249 | return net_main_module 250 | 251 | 252 | def compute_average_flops_cost(self): 253 | """ 254 | A method that will be available after add_flops_counting_methods() is called 255 | on a desired net object. 256 | 257 | Returns current mean flops consumption per image. 258 | 259 | """ 260 | 261 | batches_count = self.__batch_counter__ 262 | flops_sum = 0 263 | for module in self.modules(): 264 | if hasattr(module, '__flops__'): # is_supported_instance(module) 265 | flops_sum += module.__flops__ 266 | 267 | return flops_sum / batches_count 268 | 269 | 270 | def start_flops_count(self): 271 | """ 272 | A method that will be available after add_flops_counting_methods() is called 273 | on a desired net object. 274 | 275 | Activates the computation of mean flops consumption per image. 276 | Call it before you run the network. 277 | 278 | """ 279 | add_batch_counter_hook_function(self) 280 | self.apply(add_flops_counter_hook_function) 281 | 282 | 283 | def stop_flops_count(self): 284 | """ 285 | A method that will be available after add_flops_counting_methods() is called 286 | on a desired net object. 287 | 288 | Stops computing the mean flops consumption per image. 289 | Call whenever you want to pause the computation. 290 | 291 | """ 292 | remove_batch_counter_hook_function(self) 293 | self.apply(remove_flops_counter_hook_function) 294 | 295 | 296 | def reset_flops_count(self): 297 | """ 298 | A method that will be available after add_flops_counting_methods() is called 299 | on a desired net object. 300 | 301 | Resets statistics computed so far. 302 | 303 | """ 304 | add_batch_counter_variables_or_reset(self) 305 | self.apply(add_flops_counter_variable_or_reset) 306 | 307 | 308 | def add_flops_mask(module, mask): 309 | def add_flops_mask_func(module): 310 | if isinstance(module, torch.nn.Conv2d): 311 | module.__mask__ = mask 312 | module.apply(add_flops_mask_func) 313 | 314 | 315 | def remove_flops_mask(module): 316 | module.apply(add_flops_mask_variable_or_reset) 317 | 318 | 319 | # ---- Internal functions 320 | def is_supported_instance(module): 321 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.ReLU) \ 322 | or isinstance(module, torch.nn.PReLU) or isinstance(module, torch.nn.ELU) \ 323 | or isinstance(module, torch.nn.LeakyReLU) or isinstance(module, torch.nn.ReLU6) \ 324 | or isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.MaxPool2d) \ 325 | or isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.BatchNorm2d): 326 | return True 327 | 328 | return False 329 | 330 | 331 | def empty_flops_counter_hook(module, input, output): 332 | module.__flops__ += 0 333 | 334 | 335 | def relu_flops_counter_hook(module, input, output): 336 | input = input[0] 337 | batch_size = input.shape[0] 338 | active_elements_count = batch_size 339 | for val in input.shape[1:]: 340 | active_elements_count *= val 341 | 342 | module.__flops__ += active_elements_count 343 | 344 | 345 | def linear_flops_counter_hook(module, input, output): 346 | input = input[0] 347 | batch_size = input.shape[0] 348 | module.__flops__ += batch_size * input.shape[1] * output.shape[1] 349 | 350 | 351 | def pool_flops_counter_hook(module, input, output): 352 | input = input[0] 353 | module.__flops__ += np.prod(input.shape) 354 | 355 | def bn_flops_counter_hook(module, input, output): 356 | module.affine 357 | input = input[0] 358 | 359 | batch_flops = np.prod(input.shape) 360 | if module.affine: 361 | batch_flops *= 2 362 | module.__flops__ += batch_flops 363 | 364 | def conv_flops_counter_hook(conv_module, input, output): 365 | # Can have multiple inputs, getting the first one 366 | input = input[0] 367 | 368 | batch_size = input.shape[0] 369 | output_height, output_width = output.shape[2:] 370 | 371 | kernel_height, kernel_width = conv_module.kernel_size 372 | in_channels = conv_module.in_channels 373 | out_channels = conv_module.out_channels 374 | groups = conv_module.groups 375 | 376 | filters_per_channel = out_channels // groups 377 | conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel 378 | 379 | active_elements_count = batch_size * output_height * output_width 380 | 381 | if conv_module.__mask__ is not None: 382 | # (b, 1, h, w) 383 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 384 | active_elements_count = flops_mask.sum() 385 | 386 | overall_conv_flops = conv_per_position_flops * active_elements_count 387 | 388 | bias_flops = 0 389 | 390 | if conv_module.bias is not None: 391 | 392 | bias_flops = out_channels * active_elements_count 393 | 394 | overall_flops = overall_conv_flops + bias_flops 395 | 396 | conv_module.__flops__ += overall_flops 397 | 398 | 399 | def batch_counter_hook(module, input, output): 400 | # Can have multiple inputs, getting the first one 401 | input = input[0] 402 | batch_size = input.shape[0] 403 | module.__batch_counter__ += batch_size 404 | 405 | 406 | def add_batch_counter_variables_or_reset(module): 407 | 408 | module.__batch_counter__ = 0 409 | 410 | 411 | def add_batch_counter_hook_function(module): 412 | if hasattr(module, '__batch_counter_handle__'): 413 | return 414 | 415 | handle = module.register_forward_hook(batch_counter_hook) 416 | module.__batch_counter_handle__ = handle 417 | 418 | 419 | def remove_batch_counter_hook_function(module): 420 | if hasattr(module, '__batch_counter_handle__'): 421 | module.__batch_counter_handle__.remove() 422 | del module.__batch_counter_handle__ 423 | 424 | 425 | def add_flops_counter_variable_or_reset(module): 426 | if is_supported_instance(module): 427 | module.__flops__ = 0 428 | 429 | 430 | def add_flops_counter_hook_function(module): 431 | if is_supported_instance(module): 432 | if hasattr(module, '__flops_handle__'): 433 | return 434 | 435 | if isinstance(module, torch.nn.Conv2d): 436 | handle = module.register_forward_hook(conv_flops_counter_hook) 437 | elif isinstance(module, torch.nn.ReLU) or isinstance(module, torch.nn.PReLU) \ 438 | or isinstance(module, torch.nn.ELU) or isinstance(module, torch.nn.LeakyReLU) \ 439 | or isinstance(module, torch.nn.ReLU6): 440 | handle = module.register_forward_hook(relu_flops_counter_hook) 441 | elif isinstance(module, torch.nn.Linear): 442 | handle = module.register_forward_hook(linear_flops_counter_hook) 443 | elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d): 444 | handle = module.register_forward_hook(pool_flops_counter_hook) 445 | elif isinstance(module, torch.nn.BatchNorm2d): 446 | handle = module.register_forward_hook(bn_flops_counter_hook) 447 | else: 448 | handle = module.register_forward_hook(empty_flops_counter_hook) 449 | module.__flops_handle__ = handle 450 | 451 | 452 | def remove_flops_counter_hook_function(module): 453 | if is_supported_instance(module): 454 | if hasattr(module, '__flops_handle__'): 455 | module.__flops_handle__.remove() 456 | del module.__flops_handle__ 457 | # --- Masked flops counting 458 | 459 | 460 | # Also being run in the initialization 461 | def add_flops_mask_variable_or_reset(module): 462 | if is_supported_instance(module): 463 | module.__mask__ = None 464 | --------------------------------------------------------------------------------