├── .gitignore ├── INSTALL.md ├── LICENSE ├── README.md ├── data └── list │ └── cityscapes │ ├── test.lst │ ├── train.lst │ ├── trainval.lst │ └── val.lst ├── experiments └── cityscapes │ ├── w18.yaml │ └── w48.yaml ├── lib ├── config │ ├── __init__.py │ ├── default.py │ └── models.py ├── core │ ├── criterion.py │ └── function.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ └── cityscapes.py ├── models │ ├── __init__.py │ ├── conv_mask.py │ ├── model_anytime.py │ └── sync_bn │ │ ├── __init__.py │ │ └── inplace_abn │ │ ├── __init__.py │ │ ├── bn.py │ │ ├── functions.py │ │ └── src │ │ ├── common.h │ │ ├── inplace_abn.cpp │ │ ├── inplace_abn.h │ │ ├── inplace_abn_cpu.cpp │ │ └── inplace_abn_cuda.cu └── utils │ ├── __init__.py │ ├── metric.py │ ├── modelsummary.py │ └── utils.py ├── requirements.txt └── tools ├── _init_paths.py ├── test_ee.py └── train_ee.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | pretrained_models/ 3 | scripts/ 4 | output_new/ -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | We provide installation instructions for Cityscapes segmentation experiments here. 3 | ## Dependency Setup 4 | Create a new conda virtual environment 5 | ``` 6 | conda create -n anytime python=3.8 -y 7 | conda activate anytime 8 | ``` 9 | Install `PyTorch=1.1.0` 10 | ``` 11 | pip install torch==1.1.0 12 | ``` 13 | Clone this repo and install required packages 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Data preparation 19 | Download the [Cityscapes](https://www.cityscapes-dataset.com/) dataset and place a symbolic link under the `data` folder. 20 | 21 | ``` 22 | mkdir data 23 | ln -s $DATA_ROOT data 24 | ``` 25 | 26 | Structure the data as follows 27 | ```` 28 | $ROOT/data 29 | └── cityscapes 30 | ├── gtFine 31 | │ ├── test 32 | │ ├── train 33 | │ └── val 34 | └── leftImg8bit 35 | ├── test 36 | ├── train 37 | └── val 38 | 39 | ```` 40 | 41 | ## Pretrained model preparation 42 | Create a folder named `pretrained_models` under the root directory. 43 | ``` 44 | mkdir pretrained_models 45 | ``` 46 | Download the [HRNet-W18-C-Small-v2](https://1drv.ms/u/s!Aus8VCZ_C_33gRmfdPR79WBS61Qn?e=HVZUi8) and [HRNet-W48-C](https://1drv.ms/u/s!Aus8VCZ_C_33dKvqI6pBZlifgJk) from [HRNet-Image-Classification](https://github.com/HRNet/HRNet-Image-Classification.git) 47 | and structure the directory as follows 48 | ``` 49 | pretrained_models 50 | ├── hrnet_w18_small_model_v2.pth 51 | └── hrnetv2_w48_imagenet_pretrained.pth 52 | ``` 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zhuang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Anytime Dense Prediction with Confidence Adaptivity](https://arxiv.org/abs/2104.00749) 2 | 3 | Official PyTorch implementation for the following paper: 4 | 5 | [Anytime Dense Prediction with Confidence Adaptivity](https://openreview.net/forum?id=kNKFOXleuC). ICLR 2022.\ 6 | [Zhuang Liu](https://liuzhuang13.github.io), [Zhiqiu Xu](https://www.linkedin.com/in/oscar-xu-1250821a1/), [Hung-ju Wang](https://www.linkedin.com/in/hungju-wang-5a5124172/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/) and [Evan Shelhamer](http://imaginarynumber.net/)\ 7 | UC Berkeley, Adobe Research 8 | 9 | Our implementation is based upon [HRNet-Semantic-Segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/pytorch-v1.1). 10 | 11 | --- 12 |


Our full method, named **Anytime Dense Prediction with Confidence (ADP-C)**, matches the final accuracy of HRNet-W48 while significantly reducing total computation.

### Main Results

| Setting (HRNet-W48) | model | mIoU exit1 | exit2 | exit3 | exit4 | mean mIoU | GFLOPs exit1 | exit2 | exit3 | exit4 | mean GFLOPs |
| ------------------------- | :---: | :---: | :---: | :---: | :------: | :---------: | :---: | :---: | :---: | :-------: | :---------: |
| HRNet-W48 | - | - | - | - | 80.7 | - | - | - | - | 696.2 | - |
| EE | [model](https://drive.google.com/file/d/1GOXuP0e-qDp1mqiilxdhL8FyP8E-6pZP/view?usp=sharing) | 34.3 | 59.0 | 76.9 | 80.4 | 62.7 | 521.6 | 717.9 | 914.2 | 1110.5 | 816.0 |
| EE + RH | [model](https://drive.google.com/file/d/11QNuEpq-oBErMKO3eMEddAU8ug9OYyts/view?usp=sharing) | 44.6 | 60.2 | 76.6 | 79.9 | 65.3 | 41.9 | 105.6 | 368.0 | 701.3 | 304.2 |
| ADP-C: EE + RH + CA | [model](https://drive.google.com/file/d/1zcKkKWuknrLHpOEVRvUm82xlowjafQ_u/view?usp=sharing) | 44.3 | 60.1 | 76.8 | **81.3** | **65.7** | 41.9 | 93.9 | 259.3 | **387.1** | **195.6** |

## Installation
Please check [INSTALL.md](INSTALL.md) for installation instructions.

## Evaluation on pretrained models

Download a pretrained model from the table above and pass its location via `TEST.MODEL_FILE`.

**Early Exits (EE)**
```bash
python tools/test_ee.py --cfg experiments/cityscapes/w48.yaml \
    TEST.MODEL_FILE <path/to/model>.pth
```
This should give
```
34.33   59.01   76.90   80.43   62.67
```

**Redesigned Heads (RH)**
```bash
python tools/test_ee.py --cfg experiments/cityscapes/w48.yaml \
    EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128 \
    TEST.MODEL_FILE <path/to/model>.pth
```

This should give
```
44.61   60.19   76.64   79.89   65.33
```

**ADP-C (EE + RH + CA)**
```bash
python tools/test_ee.py \
    --cfg experiments/cityscapes/w48.yaml MODEL.NAME model_anytime \
    EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128 \
    MASK.USE True MASK.CONF_THRE 0.998 \
    TEST.MODEL_FILE <path/to/model>.pth
```

This should give
```
44.34   60.13   76.82   81.31   65.65
```

**ADP-C (EE + RH + CA)** (w18) [Pretrained w18 with ADP-C](https://drive.google.com/file/d/1bU7spRV236OV7D5dgZzHy_AI0FN3oGAp/view?usp=sharing)
```bash
python tools/test_ee.py \
    --cfg experiments/cityscapes/w18.yaml MODEL.NAME model_anytime \
    EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 64 \
    MASK.USE True MASK.CONF_THRE 0.998 \
    TEST.MODEL_FILE <path/to/model>.pth
```

This should give
```
40.83   48.19   68.26   77.02   58.57
```


## Train

There are two configurations for the HRNet backbone, `w48.yaml` and `w18.yaml`, under `experiments/cityscapes`. The following commands use `HRNet-W48` as the backbone; change `EXIT.INTER_CHANNEL` to `64` when using `w18`.
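For example, applying that substitution, a full ADP-C training run with the `w18` backbone would look like this (a sketch mirroring the `w48` commands below; pass your usual `torch.distributed.launch` options such as `--nproc_per_node` as needed):

````bash
python -m torch.distributed.launch tools/train_ee.py \
    --cfg experiments/cityscapes/w18.yaml \
    EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 64 \
    MASK.USE True MASK.CONF_THRE 0.998
````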

**Early Exits (EE)**

````bash
python -m torch.distributed.launch tools/train_ee.py \
    --cfg experiments/cityscapes/w48.yaml
````


**Redesigned Heads (RH)**

````bash
python -m torch.distributed.launch tools/train_ee.py \
    --cfg experiments/cityscapes/w48.yaml \
    EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128
````


**Confidence Adaptivity (CA)**

````bash
python -m torch.distributed.launch tools/train_ee.py \
    --cfg experiments/cityscapes/w48.yaml \
    MASK.USE True MASK.CONF_THRE 0.998
````


**ADP-C (EE + RH + CA)**

````bash
python -m torch.distributed.launch tools/train_ee.py \
    --cfg experiments/cityscapes/w48.yaml \
    EXIT.TYPE 'flex' EXIT.INTER_CHANNEL 128 \
    MASK.USE True MASK.CONF_THRE 0.998
````

Evaluation results are generated at the end of training:

- `result.txt`: contains mIoU for each exit and the average mIoU over the four exits.

- `test_stats.json`: contains FLOPs and number of parameters.

- `final_state.pth`: the trained model file.

- `config.yaml`: the configuration file.

## Test

**Evaluation**

```
python tools/test_ee.py --cfg <output_dir>/config.yaml
```

## Acknowledgement
This repository is built upon [HRNet-Semantic-Segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/pytorch-v1.1).

## License
This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information.

## Citation
If you find this repository helpful, please consider citing:
```
@Article{liu2022anytime,
  author  = {Zhuang Liu and Zhiqiu Xu and Hung-Ju Wang and Trevor Darrell and Evan Shelhamer},
  title   = {Anytime Dense Prediction with Confidence Adaptivity},
  journal = {International Conference on Learning Representations (ICLR)},
  year    = {2022},
}
```
--------------------------------------------------------------------------------
/experiments/cityscapes/w18.yaml:
--------------------------------------------------------------------------------
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
GPUS: (0,1,2,3)
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 1
PRINT_FREQ: 10

MASK:
  USE: false
  INTERPOLATION: rbf
  P: 0.5
  CONF_THRE: 0.0
  ENTROPY_THRE: 0.0
  CRIT: conf_thre
  AGGR: copy

EXIT:
  TYPE: original
  FINAL_CONV_KERNEL: 1

DATASET:
  DATASET: cityscapes
  ROOT: 'data/'
  TEST_SET: 'list/cityscapes/val.lst'
  TRAIN_SET: 'list/cityscapes/train.lst'
  NUM_CLASSES: 19
MODEL:
  NAME: 'model_anytime'
  PRETRAINED: 'pretrained_models/hrnet_w18_small_model_v2.pth'
  LOAD_STAGE: 0
  EXTRA:
    EE_WEIGHTS: (1,1,1,1)
    AGGREGATION: none
    EARLY_DETACH: false
    EXIT_NORM: BN
    STAGE1:
      NUM_MODULES: 1
      NUM_BRANCHES: 1
      BLOCK: BOTTLENECK
      NUM_BLOCKS:
      - 2
      NUM_CHANNELS:
      - 64
      FUSE_METHOD: SUM
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
      - 2
      - 2
      NUM_CHANNELS:
      - 18
      - 36
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 3
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
      - 2
      - 2
      - 2
      NUM_CHANNELS:
      - 18
      - 36
      - 72
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 2
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
      - 2
      - 2
      - 2
      - 2
      NUM_CHANNELS:
      - 18
      - 36
      - 72
      - 144
      FUSE_METHOD: SUM
LOSS:
  USE_OHEM: false
  OHEMTHRES: 0.9
  OHEMKEEP: 131072
TRAIN:
  EE_ONLY: false
  ALLE_ONLY: false
  IMAGE_SIZE:
  - 1024
  - 512
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 3
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 484
  RESUME: false
  OPTIMIZER: sgd
  LR: 0.01
  WD: 0.0005
  MOMENTUM: 0.9
  NESTEROV: false
  FLIP: true
  MULTI_SCALE: true
  DOWNSAMPLERATE: 1
  IGNORE_LABEL: 255
  SCALE_FACTOR: 16
TEST:
  SUB_DIR: ''
  IMAGE_SIZE:
  - 2048
  - 1024
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 4
  CENTER_CROP_TEST: false

--------------------------------------------------------------------------------
/experiments/cityscapes/w48.yaml:
--------------------------------------------------------------------------------
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
GPUS: (0,1,2,3)
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 1
PRINT_FREQ: 10

MASK:
  USE: false
  INTERPOLATION: rbf
  P: 0.5
  CONF_THRE: 0.0
  ENTROPY_THRE: 0.0
  CRIT: conf_thre
  AGGR: copy

EXIT:
  TYPE: original
  FINAL_CONV_KERNEL: 1

DATASET:
  DATASET: cityscapes
  ROOT: 'data/'
  TEST_SET: 'list/cityscapes/val.lst'
  TRAIN_SET: 'list/cityscapes/train.lst'
  NUM_CLASSES: 19
MODEL:
  NAME: 'model_anytime'
  PRETRAINED: 'pretrained_models/hrnetv2_w48_imagenet_pretrained.pth'
  LOAD_STAGE: 0
  EXTRA:
    EE_WEIGHTS: (1,1,1,1)
    AGGREGATION: none
    EARLY_DETACH: false
    EXIT_NORM: BN
    STAGE1:
      NUM_MODULES: 1
      NUM_BRANCHES: 1
      BLOCK: BOTTLENECK
      NUM_BLOCKS:
      - 4
      NUM_CHANNELS:
      - 64
      FUSE_METHOD: SUM
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      - 192
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      - 192
      - 384
      FUSE_METHOD: SUM
LOSS:
  USE_OHEM: false
  OHEMTHRES: 0.9
  OHEMKEEP: 131072
TRAIN:
  EE_ONLY: false
  ALLE_ONLY: false
  IMAGE_SIZE:
  - 1024
  - 512
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 3
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 484
  RESUME: false
  OPTIMIZER: sgd
  LR: 0.01
  WD: 0.0005
  MOMENTUM: 0.9
  NESTEROV: false
  FLIP: true
  MULTI_SCALE: true
  DOWNSAMPLERATE: 1
  IGNORE_LABEL: 255
  SCALE_FACTOR: 16
TEST:
  SUB_DIR: ''
  IMAGE_SIZE:
  - 2048
  - 1024
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 4
  CENTER_CROP_TEST: false
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__
import print_function 4 | 5 | from .default import _C as config 6 | from .default import update_config 7 | from .models import MODEL_EXTRAS 8 | 9 | 10 | -------------------------------------------------------------------------------- /lib/config/default.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | from yacs.config import CfgNode as CN 8 | 9 | 10 | _C = CN(new_allowed=True) 11 | 12 | _C.OUTPUT_DIR = '' 13 | _C.LOG_DIR = '' 14 | _C.GPUS = (0,) 15 | _C.WORKERS = 4 16 | _C.PRINT_FREQ = 20 17 | _C.AUTO_RESUME = False 18 | _C.PIN_MEMORY = True 19 | _C.RANK = 0 20 | 21 | _C.CUDNN = CN() 22 | _C.CUDNN.BENCHMARK = True 23 | _C.CUDNN.DETERMINISTIC = False 24 | _C.CUDNN.ENABLED = True 25 | 26 | _C.MODEL = CN(new_allowed=True) 27 | _C.MODEL.NAME = 'seg_hrnet' 28 | _C.MODEL.PRETRAINED = '' 29 | _C.MODEL.LOAD_STAGE = 0 30 | _C.MODEL.EXTRA = CN(new_allowed=True) 31 | 32 | _C.LOSS = CN() 33 | _C.LOSS.USE_OHEM = False 34 | _C.LOSS.OHEMTHRES = 0.9 35 | _C.LOSS.OHEMKEEP = 100000 36 | _C.LOSS.CLASS_BALANCE = True 37 | 38 | _C.DATASET = CN() 39 | _C.DATASET.ROOT = '' 40 | _C.DATASET.DATASET = 'cityscapes' 41 | _C.DATASET.NUM_CLASSES = 19 42 | _C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst' 43 | _C.DATASET.EXTRA_TRAIN_SET = '' 44 | _C.DATASET.TEST_SET = 'list/cityscapes/val.lst' 45 | 46 | _C.TRAIN = CN(new_allowed=True) 47 | 48 | _C.TRAIN.IMAGE_SIZE = [1024, 512] 49 | _C.TRAIN.BASE_SIZE = 2048 50 | _C.TRAIN.DOWNSAMPLERATE = 1 51 | _C.TRAIN.FLIP = True 52 | _C.TRAIN.MULTI_SCALE = True 53 | _C.TRAIN.SCALE_FACTOR = 16 54 | 55 | _C.TRAIN.LR_FACTOR = 0.1 56 | _C.TRAIN.LR_STEP = [90, 110] 57 | _C.TRAIN.LR = 0.01 58 | _C.TRAIN.EXTRA_LR = 0.001 59 | 60 | _C.TRAIN.OPTIMIZER = 'sgd' 61 | _C.TRAIN.MOMENTUM = 0.9 62 | _C.TRAIN.WD = 0.0001 63 | _C.TRAIN.NESTEROV = False 64 | _C.TRAIN.IGNORE_LABEL = -1 65 | 66 | _C.TRAIN.BEGIN_EPOCH = 0 67 | _C.TRAIN.END_EPOCH = 484 68 | _C.TRAIN.EXTRA_EPOCH = 0 69 | 70 | _C.TRAIN.RESUME = False 71 | 72 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32 73 | _C.TRAIN.SHUFFLE = True 74 | _C.TRAIN.NUM_SAMPLES = 0 75 | 76 | _C.TEST = CN(new_allowed=True) 77 | 78 | _C.TEST.IMAGE_SIZE = [2048, 1024] 79 | _C.TEST.BASE_SIZE = 2048 80 | 81 | _C.TEST.BATCH_SIZE_PER_GPU = 32 82 | _C.TEST.NUM_SAMPLES = 0 83 | 84 | _C.TEST.MODEL_FILE = '' 85 | _C.TEST.FLIP_TEST = False 86 | _C.TEST.MULTI_SCALE = False 87 | _C.TEST.CENTER_CROP_TEST = False 88 | _C.TEST.SCALE_LIST = [1] 89 | 90 | _C.DEBUG = CN() 91 | _C.DEBUG.DEBUG = False 92 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 93 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 94 | _C.DEBUG.SAVE_HEATMAPS_GT = False 95 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 96 | 97 | _C.EXIT = CN(new_allowed=True) 98 | 99 | _C.EXIT.TYPE = 'original' 100 | _C.EXIT.FINAL_CONV_KERNEL = 1 101 | _C.EXIT.COMP_RATE = 1.0 102 | _C.EXIT.SMOOTH = False 103 | _C.EXIT.SMOOTH_KS = 3 104 | _C.EXIT.LAST_SAME = False 105 | _C.EXIT.FIX_INTER_CHANNEL = False 106 | _C.EXIT.INTER_CHANNEL = 64 107 | 108 | _C.MASK = CN(new_allowed=True) 109 | _C.MASK.ENTROPY_THRE = 0.0 110 | 111 | _C.PYRAMID_TEST = CN(new_allowed=True) 112 | _C.PYRAMID_TEST.USE = False 113 | _C.PYRAMID_TEST.SIZE = 512 114 | 115 | 116 | 117 | def update_config(cfg, args): 118 | cfg.defrost() 119 | 120 | cfg.merge_from_file(args.cfg) 121 | cfg.merge_from_list(args.opts) 122 | 123 | cfg.freeze() 124 | 125 | 126 | if __name__ == '__main__': 127 | import sys 128 | with 
open(sys.argv[1], 'w') as f: 129 | print(_C, file=f) 130 | 131 | -------------------------------------------------------------------------------- /lib/config/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | HIGH_RESOLUTION_NET = CN() 8 | HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] 9 | HIGH_RESOLUTION_NET.STEM_INPLANES = 64 10 | HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 11 | HIGH_RESOLUTION_NET.WITH_HEAD = True 12 | 13 | HIGH_RESOLUTION_NET.STAGE1 = CN() 14 | HIGH_RESOLUTION_NET.STAGE1.NUM_MODULES = 1 15 | HIGH_RESOLUTION_NET.STAGE1.NUM_BRANCHES = 1 16 | HIGH_RESOLUTION_NET.STAGE1.NUM_BLOCKS = [4] 17 | HIGH_RESOLUTION_NET.STAGE1.NUM_CHANNELS = [32] 18 | HIGH_RESOLUTION_NET.STAGE1.BLOCK = 'BASIC' 19 | HIGH_RESOLUTION_NET.STAGE1.FUSE_METHOD = 'SUM' 20 | 21 | HIGH_RESOLUTION_NET.STAGE2 = CN() 22 | HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 23 | HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 24 | HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] 25 | HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] 26 | HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' 27 | HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' 28 | 29 | HIGH_RESOLUTION_NET.STAGE3 = CN() 30 | HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 31 | HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 32 | HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] 33 | HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] 34 | HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' 35 | HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' 36 | 37 | HIGH_RESOLUTION_NET.STAGE4 = CN() 38 | HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 39 | HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 40 | HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 41 | HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 42 | HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' 43 | HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' 44 | 45 | MODEL_EXTRAS = { 46 | 'seg_hrnet': HIGH_RESOLUTION_NET, 47 | } 48 | -------------------------------------------------------------------------------- /lib/core/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class CrossEntropy(nn.Module): 6 | def __init__(self, ignore_label=-1, weight=None): 7 | super(CrossEntropy, self).__init__() 8 | self.ignore_label = ignore_label 9 | self.criterion = nn.CrossEntropyLoss(weight=weight, 10 | ignore_index=ignore_label) 11 | 12 | def forward(self, score, target): 13 | ph, pw = score.size(2), score.size(3) 14 | h, w = target.size(1), target.size(2) 15 | if ph != h or pw != w: 16 | score = F.upsample( 17 | input=score, size=(h, w), mode='bilinear') 18 | 19 | loss = self.criterion(score, target) 20 | 21 | return loss 22 | 23 | -------------------------------------------------------------------------------- /lib/core/function.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import numpy.ma as ma 7 | from tqdm import tqdm 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.distributed as dist 12 | from torch.nn import functional as F 13 | 14 | from utils.utils import AverageMeter 15 | from utils.utils import get_confusion_matrix_gpu 16 | from utils.utils import 
adjust_learning_rate 17 | from utils.utils import get_world_size, get_rank 18 | from utils.modelsummary import get_model_summary 19 | 20 | import pdb 21 | from PIL import Image 22 | import cv2 23 | import time 24 | 25 | def reduce_tensor(inp): 26 | world_size = get_world_size() 27 | if world_size < 2: 28 | return inp 29 | with torch.no_grad(): 30 | reduced_inp = inp 31 | dist.reduce(reduced_inp, dst=0) 32 | return reduced_inp 33 | 34 | 35 | def train_ee(config, epoch, num_epoch, epoch_iters, base_lr, num_iters, 36 | trainloader, optimizer, model, writer_dict, device): 37 | 38 | model.train() 39 | torch.manual_seed(get_rank() + epoch * 123) 40 | 41 | if config.TRAIN.EE_ONLY or config.TRAIN.ALLE_ONLY: 42 | model.eval() 43 | model.module.model.exit1.train() 44 | model.module.model.exit2.train() 45 | model.module.model.exit3.train() 46 | if config.TRAIN.ALLE_ONLY: 47 | model.module.model.last_layer.train() 48 | 49 | 50 | data_time = AverageMeter() 51 | batch_time = AverageMeter() 52 | ave_loss = AverageMeter() 53 | 54 | tic_data = time.time() 55 | tic = time.time() 56 | tic_total = time.time() 57 | cur_iters = epoch*epoch_iters 58 | writer = writer_dict['writer'] 59 | global_steps = writer_dict['train_global_steps'] 60 | rank = get_rank() 61 | world_size = get_world_size() 62 | 63 | for i_iter, batch in enumerate(trainloader): 64 | data_time.update(time.time() - tic_data) 65 | 66 | 67 | images, labels, _, _ = batch 68 | images = images.to(device) 69 | labels = labels.long().to(device) 70 | 71 | losses, _ = model(images, labels) 72 | 73 | loss = 0 74 | reduced_losses = [] 75 | for i, l in enumerate(losses): 76 | loss += config.MODEL.EXTRA.EE_WEIGHTS[i] * losses[i] 77 | reduced_losses.append(reduce_tensor(losses[i])) 78 | reduced_loss = reduce_tensor(loss) 79 | 80 | model.zero_grad() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | 85 | ave_loss.update(reduced_loss.item()) 86 | 87 | lr = adjust_learning_rate(optimizer, 88 | base_lr, 89 | num_iters, 90 | i_iter+cur_iters) 91 | 92 | batch_time.update(time.time() - tic) 93 | tic = time.time() 94 | 95 | 96 | if i_iter % config.PRINT_FREQ == 0 and rank == 0: 97 | 98 | print_loss = reduced_loss / world_size 99 | msg = 'Epoch: [{: >3d}/{}] Iter:[{: >3d}/{}], Time: {:.2f}, Data Time: {:.2f} ' \ 100 | 'lr: {:.6f}, Loss: {:.6f}' .format( 101 | epoch, num_epoch, i_iter, epoch_iters, 102 | batch_time.average(), data_time.average(), lr, print_loss) 103 | logging.info(msg) 104 | 105 | global_steps = writer_dict['train_global_steps'] 106 | writer.add_scalar('train_loss', print_loss, global_steps) 107 | 108 | writer.add_scalars('exit_train_loss', { 109 | 'exit1': reduced_losses[0].item() / world_size, 110 | 'exit2': reduced_losses[1].item() / world_size, 111 | 'exit3': reduced_losses[2].item() / world_size, 112 | 'exit4': reduced_losses[3].item() / world_size, 113 | }, 114 | global_steps) 115 | 116 | writer_dict['train_global_steps'] += 1 117 | 118 | tic_data = time.time() 119 | 120 | train_time = time.time() - tic_total 121 | 122 | if rank == 0: 123 | logging.info(f'Train time:{train_time}s') 124 | 125 | def validate_ee(config, testloader, model, writer_dict, device): 126 | 127 | torch.manual_seed(get_rank()) 128 | 129 | tic_data = time.time() 130 | tic = time.time() 131 | tic_total = time.time() 132 | rank = get_rank() 133 | world_size = get_world_size() 134 | model.eval() 135 | 136 | data_time = AverageMeter() 137 | batch_time = AverageMeter() 138 | ave_loss = AverageMeter() 139 | 140 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS) 141 | 142 | 
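    # One running loss meter and one confusion matrix per exit, so that
    # per-exit validation loss and mIoU can be reported independently.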
ave_losses = [AverageMeter() for i in range(num_exits)] 143 | 144 | confusion_matrices = [np.zeros((config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES)) for i in range(num_exits)] 145 | 146 | 147 | with torch.no_grad(): 148 | for i_iter, batch in enumerate(testloader): 149 | data_time.update(time.time() - tic_data) 150 | 151 | image, label, _, _ = batch 152 | size = label.size() 153 | image = image.to(device) 154 | label = label.long().to(device) 155 | 156 | losses, preds = model(image, label) 157 | 158 | for i, pred in enumerate(preds): 159 | if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]: 160 | pred = F.upsample(pred, (size[-2], size[-1]), 161 | mode='bilinear') 162 | 163 | confusion_matrices[i] += get_confusion_matrix_gpu( 164 | label, 165 | pred, 166 | size, 167 | config.DATASET.NUM_CLASSES, 168 | config.TRAIN.IGNORE_LABEL) 169 | 170 | loss = 0 171 | reduced_losses = [] 172 | for i, l in enumerate(losses): 173 | loss += config.MODEL.EXTRA.EE_WEIGHTS[i] * losses[i] 174 | reduced_losses.append(reduce_tensor(losses[i])) 175 | ave_losses[i].update(reduced_losses[i].item()) 176 | 177 | reduced_loss = reduce_tensor(loss) 178 | ave_loss.update(reduced_loss.item()) 179 | 180 | batch_time.update(time.time() - tic) 181 | tic = time.time() 182 | 183 | tic_data = time.time() 184 | 185 | if i_iter % config.PRINT_FREQ == 0 and rank == 0: 186 | print_loss = ave_loss.average() / world_size 187 | msg = 'Iter:[{: >3d}/{}], Time: {:.2f}, Data Time: {:.2f} ' \ 188 | 'Loss: {:.6f}' .format( 189 | i_iter, len(testloader), batch_time.average(), data_time.average(), print_loss) 190 | logging.info(msg) 191 | 192 | 193 | results = [] 194 | for i, confusion_matrix in enumerate(confusion_matrices): 195 | 196 | confusion_matrix = torch.from_numpy(confusion_matrix).to(device) 197 | reduced_confusion_matrix = reduce_tensor(confusion_matrix) 198 | confusion_matrix = reduced_confusion_matrix.cpu().numpy() 199 | 200 | pos = confusion_matrix.sum(1) 201 | res = confusion_matrix.sum(0) 202 | tp = np.diag(confusion_matrix) 203 | pixel_acc = tp.sum()/pos.sum() 204 | mean_acc = (tp/np.maximum(1.0, pos)).mean() 205 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 206 | mean_IoU = IoU_array.mean() 207 | 208 | results.append((mean_IoU, IoU_array, pixel_acc, mean_acc)) 209 | 210 | val_time = time.time() - tic_total 211 | 212 | if rank == 0: 213 | logging.info(f'Validation time:{val_time}s') 214 | mean_IoUs = [result[0] for result in results] 215 | mean_IoUs.append(np.mean(mean_IoUs)) 216 | print_result = '\t'.join(['{:.2f}'.format(m*100) for m in mean_IoUs]) 217 | logging.info(f'mean_IoUs: {print_result}') 218 | 219 | writer = writer_dict['writer'] 220 | global_steps = writer_dict['valid_global_steps'] 221 | writer.add_scalar('valid_loss', print_loss, global_steps) 222 | 223 | writer.add_scalars('exit_valid_loss', { 224 | 'exit1': ave_losses[0].average() / world_size, 225 | 'exit2': ave_losses[1].average() / world_size, 226 | 'exit3': ave_losses[2].average() / world_size, 227 | 'exit4': ave_losses[3].average() / world_size, 228 | }, 229 | global_steps) 230 | 231 | writer.add_scalars('valid_mIoUs', 232 | {f'valid_mIoU{i+1}': results[i][0] for i in range(num_exits)}, 233 | global_steps 234 | ) 235 | writer_dict['valid_global_steps'] += 1 236 | 237 | return results 238 | 239 | 240 | VIS_T = False 241 | VIS = False 242 | VIS_CONF = False 243 | TIMING = True 244 | 245 | def testval_ee(config, test_dataset, testloader, model, 246 | sv_dir='', sv_pred=False): 247 | model.eval() 248 | torch.manual_seed(get_rank()) 
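    # Full-resolution, single-scale evaluation: accumulate a confusion matrix
    # per exit and, when TIMING is set, time the forward pass between CUDA
    # synchronization points.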
    num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS)

    confusion_matrices = [np.zeros((config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES)) for i in range(num_exits)]

    total_time = 0

    with torch.no_grad():
        for index, batch in enumerate(tqdm(testloader)):
            image, label, _, name = batch
            if config.PYRAMID_TEST.USE:
                image = F.interpolate(image, (config.PYRAMID_TEST.SIZE//2, config.PYRAMID_TEST.SIZE), mode='bilinear')

            size = label.size()

            if TIMING:
                # Synchronize before starting the clock so kernels still queued
                # from the previous batch are not billed to this one.
                torch.cuda.synchronize()
                start = time.time()
            preds = model(image)

            if TIMING:
                torch.cuda.synchronize()
                total_time += time.time() - start

            for i, pred in enumerate(preds):
                if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
                    original_logits = pred
                    pred = F.upsample(pred, (size[-2], size[-1]),
                                      mode='bilinear')

                confusion_matrices[i] += get_confusion_matrix_gpu(
                    label,
                    pred,
                    size,
                    config.DATASET.NUM_CLASSES,
                    config.TRAIN.IGNORE_LABEL)

                if sv_pred and index % 20 == 0 and VIS:
                    print("Saving ... ", name)
                    sv_path = os.path.join(sv_dir, f'test_val_results/{i+1}')
                    os.makedirs(sv_path, exist_ok=True)
                    test_dataset.save_pred(pred, sv_path, name)

                if VIS_T or VIS_CONF:
                    def save_float_img(t, sv_path, name, normalize=False):
                        os.makedirs(sv_path, exist_ok=True)
                        if normalize:
                            t = t/t.max()
                        torch.save(t, os.path.join(sv_path, name[0]+'.pth'))
                        t = t[0][0]
                        t = t.cpu().numpy().copy()
                        np.save(os.path.join(sv_path, name[0]+'.npy'), t)
                        cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t*255)

                    def save_long_img(t, sv_path, name):
                        os.makedirs(sv_path, exist_ok=True)
                        t = t[0][0]
                        t = t.cpu().numpy().copy()
                        cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t)

                    def save_tensor(t, sv_path, name):
                        os.makedirs(sv_path, exist_ok=True)
                        torch.save(t, os.path.join(sv_path, name[0]+'.pth'))


                    if VIS_CONF:

                        out = F.softmax(original_logits, dim=1)

                        sv_path = os.path.join(sv_dir, f'test_val_original_conf/{i+1}')
                        original_conf_map, _ = out.max(dim=1)
                        save_float_img(original_conf_map.unsqueeze(0), sv_path, name, normalize=False)

                        sv_path = os.path.join(sv_dir, f'test_val_original_pred/{i+1}')
                        max_index = torch.max(out, dim=1)[1]
                        save_long_img(max_index.unsqueeze(0), sv_path, name)

                        sv_path = os.path.join(sv_dir, f'test_val_original_logits/{i+1}')
                        save_tensor(original_logits, sv_path, name)

                        sv_path = os.path.join(sv_dir, f'test_val_original_results/{i+1}')
                        os.makedirs(sv_path, exist_ok=True)
                        test_dataset.save_pred(original_logits, sv_path, name)

                        if hasattr(model.module, 'mask_dict'):
                            sv_path = os.path.join(sv_dir, f'test_val_masks/')
                            os.makedirs(sv_path, exist_ok=True)
                            torch.save(model.module.mask_dict, os.path.join(sv_path, name[0]+'.pth'))

                    if i == 0:
                        sv_path = os.path.join(sv_dir, f'test_val_gt/')
                        save_long_img(label.unsqueeze(0), sv_path, name)
                if index % 100 == 0:
                    logging.info(f'processing: {index} images with exit {i}')
                    pos = confusion_matrices[i].sum(1)
                    res = confusion_matrices[i].sum(0)
                    tp = np.diag(confusion_matrices[i])
                    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
                    mean_IoU = IoU_array.mean()
logging.info('mIoU: %.4f' % (mean_IoU)) 348 | 349 | results = [] 350 | for i, confusion_matrix in enumerate(confusion_matrices): 351 | pos = confusion_matrix.sum(1) 352 | res = confusion_matrix.sum(0) 353 | tp = np.diag(confusion_matrix) 354 | pixel_acc = tp.sum()/pos.sum() 355 | mean_acc = (tp/np.maximum(1.0, pos)).mean() 356 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 357 | mean_IoU = IoU_array.mean() 358 | 359 | results.append((mean_IoU, IoU_array, pixel_acc, mean_acc)) 360 | 361 | if TIMING: 362 | print("Total_time", total_time) 363 | 364 | return results 365 | 366 | 367 | def testval_ee_class(config, test_dataset, testloader, model, 368 | sv_dir='', sv_pred=False): 369 | model.eval() 370 | torch.manual_seed(get_rank()) 371 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS) 372 | 373 | confusion_matrices = [np.zeros((config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES)) for i in range(num_exits)] 374 | 375 | total_time = 0 376 | 377 | with torch.no_grad(): 378 | for index, batch in enumerate(tqdm(testloader)): 379 | image, label, _, name = batch 380 | 381 | size = label.size() 382 | preds = model(image) 383 | 384 | for i, pred in enumerate(preds): 385 | if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]: 386 | original_logits = pred 387 | pred = F.upsample(pred, (size[-2], size[-1]), 388 | mode='bilinear') 389 | 390 | confusion_matrices[i] += get_confusion_matrix_gpu( 391 | label, 392 | pred, 393 | size, 394 | config.DATASET.NUM_CLASSES, 395 | config.TRAIN.IGNORE_LABEL) 396 | 397 | if sv_pred and index % 20 == 0 and VIS: 398 | print("Saving ... ", name) 399 | sv_path = os.path.join(sv_dir, f'test_val_results/{i+1}') 400 | os.makedirs(sv_path, exist_ok=True) 401 | test_dataset.save_pred(pred, sv_path, name) 402 | 403 | if VIS_T or VIS_CONF: 404 | def save_float_img(t, sv_path, name, normalize=False): 405 | os.makedirs(sv_path, exist_ok=True) 406 | if normalize: 407 | t = t/t.max() 408 | torch.save(t, os.path.join(sv_path, name[0]+'.pth')) 409 | t = t[0][0] 410 | t = t.cpu().numpy().copy() 411 | np.save(os.path.join(sv_path, name[0]+'.npy'), t) 412 | cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t*255) 413 | 414 | def save_long_img(t, sv_path, name): 415 | os.makedirs(sv_path, exist_ok=True) 416 | t = t[0][0] 417 | t = t.cpu().numpy().copy() 418 | cv2.imwrite(os.path.join(sv_path, name[0]+'.png'), t) 419 | 420 | def save_tensor(t, sv_path, name): 421 | os.makedirs(sv_path, exist_ok=True) 422 | torch.save(t, os.path.join(sv_path, name[0]+'.pth')) 423 | if VIS_CONF: 424 | out = F.softmax(original_logits, dim=1) 425 | 426 | sv_path = os.path.join(sv_dir, f'test_val_original_conf/{i+1}') 427 | original_conf_map, _ = out.max(dim=1) 428 | save_float_img(original_conf_map.unsqueeze(0), sv_path, name, normalize=False) 429 | 430 | sv_path = os.path.join(sv_dir, f'test_val_original_pred/{i+1}') 431 | max_index = torch.max(out, dim=1)[1] 432 | save_long_img(max_index.unsqueeze(0), sv_path, name) 433 | 434 | sv_path = os.path.join(sv_dir, f'test_val_original_logits/{i+1}') 435 | save_tensor(original_logits, sv_path, name) 436 | 437 | sv_path = os.path.join(sv_dir, f'test_val_original_results/{i+1}') 438 | os.makedirs(sv_path, exist_ok=True) 439 | test_dataset.save_pred(original_logits, sv_path, name) 440 | 441 | if hasattr(model.module, 'mask_dict'): 442 | sv_path = os.path.join(sv_dir, f'test_val_masks/') 443 | os.makedirs(sv_path, exist_ok=True) 444 | torch.save(model.module.mask_dict, os.path.join(sv_path, name[0]+'.pth')) 445 | 446 | if i == 0: 447 | sv_path = 
os.path.join(sv_dir, f'test_val_gt/') 448 | save_long_img(label.unsqueeze(0), sv_path, name) 449 | 450 | if index % 100 == 0: 451 | logging.info(f'processing: {index} images with exit {i}') 452 | pos = confusion_matrices[i].sum(1) 453 | res = confusion_matrices[i].sum(0) 454 | tp = np.diag(confusion_matrices[i]) 455 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 456 | mean_IoU = IoU_array.mean() 457 | logging.info('mIoU: %.4f' % (mean_IoU)) 458 | 459 | results = [] 460 | for i, confusion_matrix in enumerate(confusion_matrices): 461 | pos = confusion_matrix.sum(1) 462 | res = confusion_matrix.sum(0) 463 | tp = np.diag(confusion_matrix) 464 | pixel_acc = tp.sum()/pos.sum() 465 | mean_acc = (tp/np.maximum(1.0, pos)).mean() 466 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 467 | mean_IoU = IoU_array.mean() 468 | 469 | results.append((mean_IoU, IoU_array, pixel_acc, mean_acc)) 470 | 471 | if TIMING: 472 | print("Total_time", total_time) 473 | 474 | return results 475 | def testval_ee_profiling(config, test_dataset, testloader, model, 476 | sv_dir='', sv_pred=False): 477 | model.eval() 478 | torch.manual_seed(get_rank()) 479 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS) 480 | total_time = 0 481 | 482 | gflops = [] 483 | with torch.no_grad(): 484 | for index, batch in enumerate(tqdm(testloader)): 485 | image, label, _, name = batch 486 | if config.PYRAMID_TEST.USE: 487 | image = F.interpolate(image, (config.PYRAMID_TEST.SIZE, config.PYRAMID_TEST.SIZE//2), mode='bilinear') 488 | stats = {} 489 | saved_stats = {} 490 | 491 | for i in range(4): 492 | setattr(model.module, f"stop{i+1}", "anY_RanDOM_ThiNg") 493 | summary, stats[i+1] = get_model_summary(model, image, verbose=False) 494 | delattr(model.module, f"stop{i+1}") 495 | 496 | saved_stats['params'] = [stats[i+1]['params'] for i in range(4)] 497 | saved_stats['flops'] = [stats[i+1]['flops'] for i in range(4)] 498 | saved_stats['counts'] = [stats[i+1]['counts'] for i in range(4)] 499 | saved_stats['Gflops'] = [f/(1024**3) for f in saved_stats['flops']] 500 | saved_stats['Mparams'] = [f/(10**6) for f in saved_stats['params']] 501 | gflops.append(saved_stats['Gflops']) 502 | 503 | final_stats = saved_stats 504 | final_stats['Gflops'] = [] 505 | for i in range(4): 506 | final_stats['Gflops'].append(np.mean([x[i] for x in gflops])) 507 | final_stats['Gflops_mean'] = np.mean(final_stats['Gflops']) 508 | return final_stats 509 | 510 | def testval_ee_profiling_actual(config, test_dataset, testloader, model, 511 | sv_dir='', sv_pred=False): 512 | model.eval() 513 | torch.manual_seed(get_rank()) 514 | num_exits = len(config.MODEL.EXTRA.EE_WEIGHTS) 515 | total_time = 0 516 | 517 | stats = {} 518 | stats['time'] = {} 519 | times = [] 520 | 521 | with torch.no_grad(): 522 | for index, batch in enumerate(tqdm(testloader)): 523 | image, label, _, name = batch 524 | t = [] 525 | for i in range(4): 526 | if isinstance(model, nn.DataParallel): 527 | setattr(model.module, f"stop{i+1}", "anY_RanDOM_ThiNg") 528 | else: 529 | setattr(model, f"stop{i+1}", "anY_RanDOM_ThiNg") 530 | 531 | torch.cuda.synchronize() 532 | start = time.time() 533 | out = model(image) 534 | torch.cuda.synchronize() 535 | t.append(time.time() - start) 536 | 537 | if isinstance(model, nn.DataParallel): 538 | delattr(model.module, f"stop{i+1}") 539 | else: 540 | delattr(model, f"stop{i+1}") 541 | 542 | if index > 5: 543 | times.append(t) 544 | if index > 20: 545 | break 546 | 547 | print(t) 548 | for i in range(4): 549 | stats['time'][i] = np.mean([t[i] for t in times]) 550 | 
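    # stats['time'][i] holds the mean seconds to run the network up to exit
    # i+1, averaged over the timed batches (the first few iterations serve as
    # warm-up and are skipped).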
print(stats) 551 | return stats 552 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cityscapes import Cityscapes as cityscapes -------------------------------------------------------------------------------- /lib/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import random 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | from torch.utils import data 10 | 11 | class BaseDataset(data.Dataset): 12 | def __init__(self, 13 | ignore_label=-1, 14 | base_size=2048, 15 | crop_size=(512, 1024), 16 | downsample_rate=1, 17 | scale_factor=16, 18 | mean=[0.485, 0.456, 0.406], 19 | std=[0.229, 0.224, 0.225]): 20 | 21 | self.base_size = base_size 22 | self.crop_size = crop_size 23 | self.ignore_label = ignore_label 24 | 25 | self.mean = mean 26 | self.std = std 27 | self.scale_factor = scale_factor 28 | self.downsample_rate = 1./downsample_rate 29 | 30 | self.files = [] 31 | 32 | def __len__(self): 33 | return len(self.files) 34 | 35 | def input_transform(self, image): 36 | image = image.astype(np.float32)[:, :, ::-1] 37 | image = image / 255.0 38 | image -= self.mean 39 | image /= self.std 40 | return image 41 | 42 | def label_transform(self, label): 43 | return np.array(label).astype('int32') 44 | 45 | def pad_image(self, image, h, w, size, padvalue): 46 | pad_image = image.copy() 47 | pad_h = max(size[0] - h, 0) 48 | pad_w = max(size[1] - w, 0) 49 | if pad_h > 0 or pad_w > 0: 50 | pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0, 51 | pad_w, cv2.BORDER_CONSTANT, 52 | value=padvalue) 53 | 54 | return pad_image 55 | 56 | def rand_crop(self, image, label): 57 | h, w = image.shape[:-1] 58 | image = self.pad_image(image, h, w, self.crop_size, 59 | (0.0, 0.0, 0.0)) 60 | label = self.pad_image(label, h, w, self.crop_size, 61 | (self.ignore_label,)) 62 | 63 | new_h, new_w = label.shape 64 | x = random.randint(0, new_w - self.crop_size[1]) 65 | y = random.randint(0, new_h - self.crop_size[0]) 66 | image = image[y:y+self.crop_size[0], x:x+self.crop_size[1]] 67 | label = label[y:y+self.crop_size[0], x:x+self.crop_size[1]] 68 | 69 | return image, label 70 | 71 | def center_crop(self, image, label): 72 | h, w = image.shape[:2] 73 | x = int(round((w - self.crop_size[1]) / 2.)) 74 | y = int(round((h - self.crop_size[0]) / 2.)) 75 | image = image[y:y+self.crop_size[0], x:x+self.crop_size[1]] 76 | label = label[y:y+self.crop_size[0], x:x+self.crop_size[1]] 77 | 78 | return image, label 79 | 80 | def image_resize(self, image, long_size, label=None): 81 | h, w = image.shape[:2] 82 | if h > w: 83 | new_h = long_size 84 | new_w = np.int(w * long_size / h + 0.5) 85 | else: 86 | new_w = long_size 87 | new_h = np.int(h * long_size / w + 0.5) 88 | 89 | image = cv2.resize(image, (new_w, new_h), 90 | interpolation = cv2.INTER_LINEAR) 91 | if label is not None: 92 | label = cv2.resize(label, (new_w, new_h), 93 | interpolation = cv2.INTER_NEAREST) 94 | else: 95 | return image 96 | 97 | return image, label 98 | 99 | def multi_scale_aug(self, image, label=None, 100 | rand_scale=1, rand_crop=True): 101 | long_size = np.int(self.base_size * rand_scale + 0.5) 102 | if label is not None: 103 | image, label = 
self.image_resize(image, long_size, label) 104 | if rand_crop: 105 | image, label = self.rand_crop(image, label) 106 | return image, label 107 | else: 108 | image = self.image_resize(image, long_size) 109 | return image 110 | 111 | def gen_sample(self, image, label, 112 | multi_scale=True, is_flip=True, center_crop_test=False): 113 | if multi_scale: 114 | rand_scale = 0.5 + random.randint(0, self.scale_factor) / 10.0 115 | image, label = self.multi_scale_aug(image, label, 116 | rand_scale=rand_scale) 117 | 118 | if center_crop_test: 119 | image, label = self.image_resize(image, 120 | self.base_size, 121 | label) 122 | image, label = self.center_crop(image, label) 123 | 124 | image = self.input_transform(image) 125 | label = self.label_transform(label) 126 | 127 | image = image.transpose((2, 0, 1)) 128 | 129 | if is_flip: 130 | flip = np.random.choice(2) * 2 - 1 131 | image = image[:, :, ::flip] 132 | label = label[:, ::flip] 133 | 134 | if self.downsample_rate != 1: 135 | label = cv2.resize(label, 136 | None, 137 | fx=self.downsample_rate, 138 | fy=self.downsample_rate, 139 | interpolation=cv2.INTER_NEAREST) 140 | 141 | return image, label 142 | 143 | def inference(self, model, image, flip=False): 144 | size = image.size() 145 | pred = model(image) 146 | pred = F.upsample(input=pred, 147 | size=(size[-2], size[-1]), 148 | mode='bilinear') 149 | if flip: 150 | flip_img = image.numpy()[:,:,:,::-1] 151 | flip_output = model(torch.from_numpy(flip_img.copy())) 152 | flip_output = F.upsample(input=flip_output, 153 | size=(size[-2], size[-1]), 154 | mode='bilinear') 155 | flip_pred = flip_output.cpu().numpy().copy() 156 | flip_pred = torch.from_numpy(flip_pred[:,:,:,::-1].copy()).cuda() 157 | pred += flip_pred 158 | pred = pred * 0.5 159 | return pred.exp() 160 | 161 | def multi_scale_inference(self, model, image, scales=[1], flip=False): 162 | batch, _, ori_height, ori_width = image.size() 163 | assert batch == 1, "only supporting batchsize 1." 
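        # Sliding-window inference: tile the (possibly padded) rescaled image
        # with crops at ~2/3-crop-size stride, accumulate logits in `preds`,
        # and divide by the per-pixel visit count to average the overlaps.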
164 | device = torch.device("cuda:%d" % model.device_ids[0]) 165 | image = image.numpy()[0].transpose((1,2,0)).copy() 166 | stride_h = np.int(self.crop_size[0] * 2.0 / 3.0) 167 | stride_w = np.int(self.crop_size[1] * 2.0 / 3.0) 168 | final_pred = torch.zeros([1, self.num_classes, 169 | ori_height,ori_width]).to(device) 170 | padvalue = -1.0 * np.array(self.mean) / np.array(self.std) 171 | for scale in scales: 172 | new_img = self.multi_scale_aug(image=image, 173 | rand_scale=scale, 174 | rand_crop=False) 175 | height, width = new_img.shape[:-1] 176 | 177 | if max(height, width) <= np.min(self.crop_size): 178 | new_img = self.pad_image(new_img, height, width, 179 | self.crop_size, padvalue) 180 | new_img = new_img.transpose((2, 0, 1)) 181 | new_img = np.expand_dims(new_img, axis=0) 182 | new_img = torch.from_numpy(new_img) 183 | preds = self.inference(model, new_img, flip) 184 | preds = preds[:, :, 0:height, 0:width] 185 | else: 186 | if height < self.crop_size[0] or width < self.crop_size[1]: 187 | new_img = self.pad_image(new_img, height, width, 188 | self.crop_size, padvalue) 189 | new_h, new_w = new_img.shape[:-1] 190 | rows = np.int(np.ceil(1.0 * (new_h - 191 | self.crop_size[0]) / stride_h)) + 1 192 | cols = np.int(np.ceil(1.0 * (new_w - 193 | self.crop_size[1]) / stride_w)) + 1 194 | preds = torch.zeros([1, self.num_classes, 195 | new_h,new_w]).to(device) 196 | count = torch.zeros([1,1, new_h, new_w]).to(device) 197 | 198 | for r in range(rows): 199 | for c in range(cols): 200 | h0 = r * stride_h 201 | w0 = c * stride_w 202 | h1 = min(h0 + self.crop_size[0], new_h) 203 | w1 = min(w0 + self.crop_size[1], new_w) 204 | crop_img = new_img[h0:h1, w0:w1, :] 205 | if h1 == new_h or w1 == new_w: 206 | crop_img = self.pad_image(crop_img, 207 | h1-h0, 208 | w1-w0, 209 | self.crop_size, 210 | padvalue) 211 | crop_img = crop_img.transpose((2, 0, 1)) 212 | crop_img = np.expand_dims(crop_img, axis=0) 213 | crop_img = torch.from_numpy(crop_img) 214 | pred = self.inference(model, crop_img, flip) 215 | 216 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0] 217 | count[:,:,h0:h1,w0:w1] += 1 218 | preds = preds / count 219 | preds = preds[:,:,:height,:width] 220 | preds = F.upsample(preds, (ori_height, ori_width), 221 | mode='bilinear') 222 | final_pred += preds 223 | return final_pred 224 | -------------------------------------------------------------------------------- /lib/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | from .base_dataset import BaseDataset 11 | 12 | class Cityscapes(BaseDataset): 13 | def __init__(self, 14 | root, 15 | list_path, 16 | num_samples=None, 17 | num_classes=19, 18 | multi_scale=True, 19 | flip=True, 20 | ignore_label=-1, 21 | base_size=2048, 22 | crop_size=(512, 1024), 23 | center_crop_test=False, 24 | downsample_rate=1, 25 | scale_factor=16, 26 | mean=[0.485, 0.456, 0.406], 27 | std=[0.229, 0.224, 0.225]): 28 | 29 | super(Cityscapes, self).__init__(ignore_label, base_size, 30 | crop_size, downsample_rate, scale_factor, mean, std,) 31 | 32 | self.root = root 33 | self.list_path = list_path 34 | self.num_classes = num_classes 35 | self.class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 36 | 1.0166, 0.9969, 0.9754, 1.0489, 37 | 0.8786, 1.0023, 0.9539, 0.9843, 38 | 1.1116, 0.9037, 1.0865, 1.0955, 39 | 1.0865, 1.1529, 1.0507]).cuda() 40 | 41 | 
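        # Per-class loss weights for Cityscapes (rarer classes weighted above
        # 1); presumably consumed by the cross-entropy criterion when
        # LOSS.CLASS_BALANCE is enabled.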
self.multi_scale = multi_scale 42 | self.flip = flip 43 | self.center_crop_test = center_crop_test 44 | 45 | self.img_list = [line.strip().split() for line in open(root+list_path)] 46 | 47 | self.files = self.read_files() 48 | if num_samples: 49 | self.files = self.files[:num_samples] 50 | 51 | self.label_mapping = {-1: ignore_label, 0: ignore_label, 52 | 1: ignore_label, 2: ignore_label, 53 | 3: ignore_label, 4: ignore_label, 54 | 5: ignore_label, 6: ignore_label, 55 | 7: 0, 8: 1, 9: ignore_label, 56 | 10: ignore_label, 11: 2, 12: 3, 57 | 13: 4, 14: ignore_label, 15: ignore_label, 58 | 16: ignore_label, 17: 5, 18: ignore_label, 59 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 60 | 25: 12, 26: 13, 27: 14, 28: 15, 61 | 29: ignore_label, 30: ignore_label, 62 | 31: 16, 32: 17, 33: 18} 63 | 64 | def read_files(self): 65 | files = [] 66 | if 'test' in self.list_path: 67 | for item in self.img_list: 68 | image_path = item 69 | name = os.path.splitext(os.path.basename(image_path[0]))[0] 70 | files.append({ 71 | "img": image_path[0], 72 | "name": name, 73 | }) 74 | else: 75 | for item in self.img_list: 76 | image_path, label_path = item 77 | name = os.path.splitext(os.path.basename(label_path))[0] 78 | files.append({ 79 | "img": image_path, 80 | "label": label_path, 81 | "name": name, 82 | "weight": 1 83 | }) 84 | return files 85 | 86 | def convert_label(self, label, inverse=False): 87 | temp = label.copy() 88 | if inverse: 89 | for v, k in self.label_mapping.items(): 90 | label[temp == k] = v 91 | else: 92 | for k, v in self.label_mapping.items(): 93 | label[temp == k] = v 94 | return label 95 | 96 | def __getitem__(self, index): 97 | item = self.files[index] 98 | name = item["name"] 99 | image = cv2.imread(os.path.join(self.root,'cityscapes',item["img"]), 100 | cv2.IMREAD_COLOR) 101 | size = image.shape 102 | 103 | if 'test' in self.list_path: 104 | image = self.input_transform(image) 105 | image = image.transpose((2, 0, 1)) 106 | 107 | return image.copy(), np.array(size), name 108 | 109 | label = cv2.imread(os.path.join(self.root,'cityscapes',item["label"]), 110 | cv2.IMREAD_GRAYSCALE) 111 | label = self.convert_label(label) 112 | 113 | image, label = self.gen_sample(image, label, 114 | self.multi_scale, self.flip, 115 | self.center_crop_test) 116 | 117 | return image.copy(), label.copy(), np.array(size), name 118 | 119 | def multi_scale_inference(self, model, image, scales=[1], flip=False): 120 | batch, _, ori_height, ori_width = image.size() 121 | assert batch == 1, "only supporting batchsize 1." 
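        # Variant of BaseDataset.multi_scale_inference with stride equal to
        # the full crop size, so tiles only overlap where the window is
        # clamped at the image borders.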
        image = image.numpy()[0].transpose((1,2,0)).copy()
        stride_h = np.int(self.crop_size[0] * 1.0)
        stride_w = np.int(self.crop_size[1] * 1.0)
        final_pred = torch.zeros([1, self.num_classes,
                                  ori_height, ori_width]).cuda()
        for scale in scales:
            new_img = self.multi_scale_aug(image=image,
                                           rand_scale=scale,
                                           rand_crop=False)
            height, width = new_img.shape[:-1]

            if scale <= 1.0:
                new_img = new_img.transpose((2, 0, 1))
                new_img = np.expand_dims(new_img, axis=0)
                new_img = torch.from_numpy(new_img)
                preds = self.inference(model, new_img, flip)
                preds = preds[:, :, 0:height, 0:width]
            else:
                new_h, new_w = new_img.shape[:-1]
                rows = np.int(np.ceil(1.0 * (new_h -
                                             self.crop_size[0]) / stride_h)) + 1
                cols = np.int(np.ceil(1.0 * (new_w -
                                             self.crop_size[1]) / stride_w)) + 1
                preds = torch.zeros([1, self.num_classes,
                                     new_h, new_w]).cuda()
                count = torch.zeros([1, 1, new_h, new_w]).cuda()

                for r in range(rows):
                    for c in range(cols):
                        h0 = r * stride_h
                        w0 = c * stride_w
                        h1 = min(h0 + self.crop_size[0], new_h)
                        w1 = min(w0 + self.crop_size[1], new_w)
                        h0 = max(int(h1 - self.crop_size[0]), 0)
                        w0 = max(int(w1 - self.crop_size[1]), 0)
                        crop_img = new_img[h0:h1, w0:w1, :]
                        crop_img = crop_img.transpose((2, 0, 1))
                        crop_img = np.expand_dims(crop_img, axis=0)
                        crop_img = torch.from_numpy(crop_img)
                        pred = self.inference(model, crop_img, flip)
                        preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
                        count[:,:,h0:h1,w0:w1] += 1
                preds = preds / count
                preds = preds[:,:,:height,:width]
            preds = F.upsample(preds, (ori_height, ori_width),
                               mode='bilinear')
            final_pred += preds
        return final_pred

    def get_palette(self, n):
        palette = [0] * (n * 3)
        for j in range(0, n):
            lab = j
            palette[j * 3 + 0] = 0
            palette[j * 3 + 1] = 0
            palette[j * 3 + 2] = 0
            i = 0
            while lab:
                palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
                palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
                palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
                i += 1
                lab >>= 3
        return palette

    def get_palette_cityscapes(self, n):
        palette = [0] * (n * 3)
        from cityscapesscripts.helpers.labels import labels
        trainId2color = {label.trainId: label.color for label in labels if (label.trainId != 255 and label.trainId != -1)}
        for trainId, color in trainId2color.items():
            palette[trainId*3] = color[0]
            palette[trainId*3 + 1] = color[1]
            palette[trainId*3 + 2] = color[2]

        return palette


    def save_pred(self, preds, sv_path, name):

        palette = self.get_palette_cityscapes(256)

        preds = preds.cpu().numpy().copy()
        preds = np.asarray(np.argmax(preds, axis=1), dtype=np.uint8)
        for i in range(preds.shape[0]):
            pred = preds[i]
            save_img = Image.fromarray(pred)
            save_img.putpalette(palette)
            save_img.save(os.path.join(sv_path, name[i]+'.png'))


    def save_ts(self, t, sv_path, name):
        palette = self.get_palette(256)
        t = t.cpu().numpy().copy()
        for i in range(t.shape[0]):
            # Map trainIds back to raw Cityscapes label ids; palette images
            # must be uint8.
            pred = self.convert_label(t[i], inverse=True).astype(np.uint8)
            save_img = Image.fromarray(pred)
            save_img.putpalette(palette)
            save_img.save(os.path.join(sv_path, name[i]+'.png'))


--------------------------------------------------------------------------------
/lib/models/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import models.model_anytime
--------------------------------------------------------------------------------
/lib/models/conv_mask.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import pdb, time
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F

class conv_mask_uniform(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, p=0.5, interpolate='none'):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.mask = None
        self.mask_built = False
        self.p = p

        self.interpolate = interpolate
        self.r = 7
        self.padding_interpolate = 3

        # Squared distances from each position of an r x r window to its
        # center; forward() turns these into RBF weights via
        # exp(-Lambda^2 * d^2).
        self.Lambda = nn.Parameter(torch.tensor(3.0))
        square_dis = np.zeros((self.r, self.r))
        center_point = (square_dis.shape[0]//2, square_dis.shape[1]//2)

        for i in range(square_dis.shape[0]):
            for j in range(square_dis.shape[1]):
                square_dis[i][j] = (i - center_point[0])**2 + (j - center_point[1])**2

        # Huge distance at the center so the kernel never relies on the
        # masked-out center pixel itself.
        square_dis[center_point[0]][center_point[1]] = 100000.0

        self.square_dis = nn.Parameter(torch.Tensor(square_dis), requires_grad=False)

    def build_mask(self, x):
        mask_p = x.new(x.shape[2:]).fill_(self.p)
        mask = torch.bernoulli(mask_p)
        self.mask = mask[None, None, :, :].float()
        self.mask_built = True

        if self.in_channels == 3:
            print('Mask sum:', torch.sum(self.mask))

    def build_mask_random(self, x):
        mask_p = x.new(size=(x.shape[0], *x.shape[2:])).fill_(self.p)
        mask = torch.bernoulli(mask_p)
        self.mask = mask[:, None, :, :].float()
        self.mask_built = True

    def set_mask(self, mask):
        self.mask = mask[:, None, :, :]
        self.mask_built = True

    def forward(self, x):
        y = super().forward(x)
        self.out_h, self.out_w = y.size(-2), y.size(-1)
        if not self.mask_built:
            self.build_mask_random(y)

        # Normalized RBF kernel applied depthwise: masked-out positions are
        # filled with a distance-weighted average of their computed neighbors.
        kernel = (-(self.Lambda**2) * self.square_dis.detach()).exp()
        kernel = kernel / (kernel.sum() + 10**(-5))
        kernel = kernel.expand((self.out_channels, 1, kernel.size(0), kernel.size(1)))
        interpolated = F.conv2d(y * self.mask, kernel, stride=1, padding=self.padding_interpolate, groups=self.out_channels)

        out = y * self.mask + interpolated * (1 - self.mask)
        self.mask_built = False

        return out

if __name__ == '__main__':
    # Smoke test: run a masked convolution on a random input.
    m = conv_mask_uniform(3, 8, kernel_size=3, padding=1, p=0.5)
    print(m(torch.randn(1, 3, 16, 32)).shape)

--------------------------------------------------------------------------------
/lib/models/model_anytime.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import functools

import numpy as np

import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F

from utils.utils import AverageMeter
import pdb, time
from .conv_mask import conv_mask_uniform
from functools import partial 20 | 21 | from utils.utils import get_rank 22 | 23 | BatchNorm2d = nn.BatchNorm2d 24 | BN_MOMENTUM = 0.01 25 | logger = logging.getLogger(__name__) 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | return used_conv(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = used_conv(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 71 | self.conv2 = used_conv(planes, planes, kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 74 | self.conv3 = used_conv(planes, planes * self.expansion, kernel_size=1, 75 | bias=False) 76 | self.bn3 = BatchNorm2d(planes * self.expansion, 77 | momentum=BN_MOMENTUM) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | 105 | class HighResolutionModule(nn.Module): 106 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels, 107 | num_channels, fuse_method, multi_scale_output=True): 108 | super(HighResolutionModule, self).__init__() 109 | self._check_branches( 110 | num_branches, blocks, num_blocks, num_inchannels, num_channels) 111 | 112 | self.num_inchannels = num_inchannels 113 | self.fuse_method = fuse_method 114 | self.num_branches = num_branches 115 | 116 | self.multi_scale_output = multi_scale_output 117 | 118 | self.branches = self._make_branches( 119 | num_branches, blocks, num_blocks, num_channels) 120 | self.fuse_layers = self._make_fuse_layers() 121 | self.relu = nn.ReLU(inplace=True) 122 | 123 | def _check_branches(self, num_branches, blocks, num_blocks, 124 | num_inchannels, num_channels): 125 | if num_branches != len(num_blocks): 126 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( 127 | num_branches, len(num_blocks)) 128 | logger.error(error_msg) 129 | raise ValueError(error_msg) 130 | 131 | if num_branches != len(num_channels): 132 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( 133 | num_branches, len(num_channels)) 134 | logger.error(error_msg) 135 | raise 
ValueError(error_msg) 136 | 137 | if num_branches != len(num_inchannels): 138 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( 139 | num_branches, len(num_inchannels)) 140 | logger.error(error_msg) 141 | raise ValueError(error_msg) 142 | 143 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels, 144 | stride=1): 145 | downsample = None 146 | if stride != 1 or \ 147 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: 148 | downsample = nn.Sequential( 149 | used_conv(self.num_inchannels[branch_index], 150 | num_channels[branch_index] * block.expansion, 151 | kernel_size=1, stride=stride, bias=False), 152 | BatchNorm2d(num_channels[branch_index] * block.expansion, 153 | momentum=BN_MOMENTUM), 154 | ) 155 | 156 | layers = [] 157 | layers.append(block(self.num_inchannels[branch_index], 158 | num_channels[branch_index], stride, downsample)) 159 | self.num_inchannels[branch_index] = \ 160 | num_channels[branch_index] * block.expansion 161 | for i in range(1, num_blocks[branch_index]): 162 | layers.append(block(self.num_inchannels[branch_index], 163 | num_channels[branch_index])) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def _make_branches(self, num_branches, block, num_blocks, num_channels): 168 | branches = [] 169 | 170 | for i in range(num_branches): 171 | branches.append( 172 | self._make_one_branch(i, block, num_blocks, num_channels)) 173 | 174 | return nn.ModuleList(branches) 175 | 176 | def _make_fuse_layers(self): 177 | if self.num_branches == 1: 178 | return None 179 | 180 | num_branches = self.num_branches 181 | num_inchannels = self.num_inchannels 182 | fuse_layers = [] 183 | for i in range(num_branches if self.multi_scale_output else 1): 184 | fuse_layer = [] 185 | for j in range(num_branches): 186 | if j > i: 187 | fuse_layer.append(nn.Sequential( 188 | used_conv(num_inchannels[j], 189 | num_inchannels[i], 190 | 1, 191 | 1, 192 | 0, 193 | bias=False), 194 | BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) 195 | elif j == i: 196 | fuse_layer.append(None) 197 | else: 198 | conv3x3s = [] 199 | for k in range(i-j): 200 | if k == i - j - 1: 201 | num_outchannels_conv3x3 = num_inchannels[i] 202 | conv3x3s.append(nn.Sequential( 203 | used_conv(num_inchannels[j], 204 | num_outchannels_conv3x3, 205 | 3, 2, 1, bias=False), 206 | BatchNorm2d(num_outchannels_conv3x3, 207 | momentum=BN_MOMENTUM))) 208 | else: 209 | num_outchannels_conv3x3 = num_inchannels[j] 210 | conv3x3s.append(nn.Sequential( 211 | used_conv(num_inchannels[j], 212 | num_outchannels_conv3x3, 213 | 3, 2, 1, bias=False), 214 | BatchNorm2d(num_outchannels_conv3x3, 215 | momentum=BN_MOMENTUM), 216 | nn.ReLU(inplace=True))) 217 | fuse_layer.append(nn.Sequential(*conv3x3s)) 218 | fuse_layers.append(nn.ModuleList(fuse_layer)) 219 | 220 | return nn.ModuleList(fuse_layers) 221 | 222 | def get_num_inchannels(self): 223 | return self.num_inchannels 224 | 225 | def forward(self, x): 226 | if self.num_branches == 1: 227 | return [self.branches[0](x[0])] 228 | 229 | for i in range(self.num_branches): 230 | x[i] = self.branches[i](x[i]) 231 | 232 | x_fuse = [] 233 | for i in range(len(self.fuse_layers)): 234 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 235 | for j in range(1, self.num_branches): 236 | if i == j: 237 | y = y + x[j] 238 | elif j > i: 239 | width_output = x[i].shape[-1] 240 | height_output = x[i].shape[-2] 241 | y = y + F.interpolate( 242 | self.fuse_layers[i][j](x[j]), 243 | size=[height_output, width_output], 244 | mode='bilinear') 
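                # j < i: x[j] is higher-resolution than x[i]; fuse_layers[i][j]
                # downsamples it with a chain of stride-2 3x3 convs before the sum.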
245 | else: 246 | y = y + self.fuse_layers[i][j](x[j]) 247 | x_fuse.append(self.relu(y)) 248 | 249 | return x_fuse 250 | 251 | 252 | blocks_dict = { 253 | 'BASIC': BasicBlock, 254 | 'BOTTLENECK': Bottleneck 255 | } 256 | 257 | 258 | 259 | class HighResolutionNet(nn.Module): 260 | def _make_transition_layer( 261 | self, num_channels_pre_layer, num_channels_cur_layer): 262 | num_branches_cur = len(num_channels_cur_layer) 263 | num_branches_pre = len(num_channels_pre_layer) 264 | 265 | transition_layers = [] 266 | for i in range(num_branches_cur): 267 | if i < num_branches_pre: 268 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 269 | transition_layers.append(nn.Sequential( 270 | used_conv(num_channels_pre_layer[i], 271 | num_channels_cur_layer[i], 272 | 3, 273 | 1, 274 | 1, 275 | bias=False), 276 | BatchNorm2d( 277 | num_channels_cur_layer[i], momentum=BN_MOMENTUM), 278 | nn.ReLU(inplace=True))) 279 | else: 280 | transition_layers.append(None) 281 | else: 282 | conv3x3s = [] 283 | for j in range(i+1-num_branches_pre): 284 | inchannels = num_channels_pre_layer[-1] 285 | outchannels = num_channels_cur_layer[i] \ 286 | if j == i-num_branches_pre else inchannels 287 | conv3x3s.append(nn.Sequential( 288 | used_conv( 289 | inchannels, outchannels, 3, 2, 1, bias=False), 290 | BatchNorm2d(outchannels, momentum=BN_MOMENTUM), 291 | nn.ReLU(inplace=True))) 292 | transition_layers.append(nn.Sequential(*conv3x3s)) 293 | 294 | return nn.ModuleList(transition_layers) 295 | 296 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 297 | downsample = None 298 | if stride != 1 or inplanes != planes * block.expansion: 299 | downsample = nn.Sequential( 300 | used_conv(inplanes, planes * block.expansion, 301 | kernel_size=1, stride=stride, bias=False), 302 | BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 303 | ) 304 | 305 | layers = [] 306 | layers.append(block(inplanes, planes, stride, downsample)) 307 | inplanes = planes * block.expansion 308 | for i in range(1, blocks): 309 | layers.append(block(inplanes, planes)) 310 | 311 | return nn.Sequential(*layers) 312 | 313 | def _make_stage(self, layer_config, num_inchannels, 314 | multi_scale_output=True): 315 | num_modules = layer_config['NUM_MODULES'] 316 | num_branches = layer_config['NUM_BRANCHES'] 317 | num_blocks = layer_config['NUM_BLOCKS'] 318 | num_channels = layer_config['NUM_CHANNELS'] 319 | block = blocks_dict[layer_config['BLOCK']] 320 | fuse_method = layer_config['FUSE_METHOD'] 321 | 322 | modules = [] 323 | for i in range(num_modules): 324 | if not multi_scale_output and i == num_modules - 1: 325 | reset_multi_scale_output = False 326 | else: 327 | reset_multi_scale_output = True 328 | modules.append( 329 | HighResolutionModule(num_branches, 330 | block, 331 | num_blocks, 332 | num_inchannels, 333 | num_channels, 334 | fuse_method, 335 | reset_multi_scale_output) 336 | ) 337 | num_inchannels = modules[-1].get_num_inchannels() 338 | 339 | return nn.Sequential(*modules), num_inchannels 340 | 341 | def __init__(self, config, **kwargs): 342 | 343 | 344 | super(HighResolutionNet, self).__init__() 345 | extra = config.MODEL.EXTRA 346 | self.extra = extra 347 | self.mask_cfg = config.MASK 348 | 349 | global mask_conv, mask_conv_no_interpolate 350 | mask_conv = partial(conv_mask_uniform, p=self.mask_cfg.P, interpolate=self.mask_cfg.INTERPOLATION) 351 | mask_conv_no_interpolate = partial(conv_mask_uniform, p=self.mask_cfg.P, interpolate='none') 352 | global used_conv 353 | used_conv = nn.Conv2d 354 | 355 | 
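        # used_conv is the module-level factory for every convolution created
        # below: the stem and stage 1 are always built with plain nn.Conv2d, and
        # it is re-bound to the masked variant (mask_conv) before stages 2-4
        # whenever config.MASK.USE is set, so those convs confine fresh
        # computation to low-confidence pixels.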
self.num_exits = len(extra.EE_WEIGHTS) 356 | self.num_classes = config.DATASET.NUM_CLASSES 357 | if 'profiling_cpu' in kwargs or 'profiling_gpu' in kwargs: 358 | self.profiling_meters = [AverageMeter() for i in range(self.num_exits)] 359 | self.profiling_gpu = 'profiling_gpu' in kwargs 360 | self.profiling_cpu = 'profiling_cpu' in kwargs 361 | self.forward_count = 0 362 | else: 363 | self.profiling_gpu, self.profiling_cpu = False, False 364 | 365 | self.conv1 = used_conv(3, 64, kernel_size=3, stride=2, padding=1, 366 | bias=False) 367 | self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) 368 | self.conv2 = used_conv(64, 64, kernel_size=3, stride=2, padding=1, 369 | bias=False) 370 | self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) 371 | self.relu = nn.ReLU(inplace=True) 372 | 373 | self.stage1_cfg = extra['STAGE1'] 374 | num_channels = self.stage1_cfg['NUM_CHANNELS'][0] 375 | block = blocks_dict[self.stage1_cfg['BLOCK']] 376 | num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] 377 | self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) 378 | stage1_out_channel = block.expansion*num_channels 379 | self.exit1 = self.get_exit_layer(stage1_out_channel, config, exit_number=1) 380 | 381 | if self.mask_cfg.USE: 382 | used_conv = mask_conv 383 | else: 384 | used_conv = nn.Conv2d 385 | 386 | self.stage2_cfg = extra['STAGE2'] 387 | num_channels = self.stage2_cfg['NUM_CHANNELS'] 388 | block = blocks_dict[self.stage2_cfg['BLOCK']] 389 | num_channels = [ 390 | num_channels[i] * block.expansion for i in range(len(num_channels))] 391 | self.transition1 = self._make_transition_layer( 392 | [stage1_out_channel], num_channels) 393 | self.stage2, pre_stage_channels = self._make_stage( 394 | self.stage2_cfg, num_channels) 395 | self.exit2 = self.get_exit_layer(np.int(np.sum(pre_stage_channels)), config, exit_number=2) 396 | 397 | if self.mask_cfg.USE: 398 | used_conv = mask_conv 399 | else: 400 | used_conv = nn.Conv2d 401 | 402 | 403 | self.stage3_cfg = extra['STAGE3'] 404 | num_channels = self.stage3_cfg['NUM_CHANNELS'] 405 | block = blocks_dict[self.stage3_cfg['BLOCK']] 406 | num_channels = [ 407 | num_channels[i] * block.expansion for i in range(len(num_channels))] 408 | self.transition2 = self._make_transition_layer( 409 | pre_stage_channels, num_channels) 410 | self.stage3, pre_stage_channels = self._make_stage( 411 | self.stage3_cfg, num_channels) 412 | self.exit3 = self.get_exit_layer(np.int(np.sum(pre_stage_channels)), config, exit_number=3) 413 | 414 | self.stage4_cfg = extra['STAGE4'] 415 | num_channels = self.stage4_cfg['NUM_CHANNELS'] 416 | block = blocks_dict[self.stage4_cfg['BLOCK']] 417 | num_channels = [ 418 | num_channels[i] * block.expansion for i in range(len(num_channels))] 419 | self.transition3 = self._make_transition_layer( 420 | pre_stage_channels, num_channels) 421 | self.stage4, pre_stage_channels = self._make_stage( 422 | self.stage4_cfg, num_channels, multi_scale_output=True) 423 | 424 | last_inp_channels = np.int(np.sum(pre_stage_channels)) 425 | self.last_layer = self.get_exit_layer(last_inp_channels, config, last=True) 426 | 427 | print(sum(p.numel() for p in self.parameters() if p.requires_grad)) 428 | print(sum(p.numel() for p in self.parameters())) 429 | 430 | 431 | 432 | def profile(self, out, index): 433 | if not (self.profiling_cpu or self.profiling_gpu): 434 | return 435 | self.forward_count += 1 436 | print(self.forward_count) 437 | start_count = 25 * 4 438 | if self.forward_count < start_count: 439 | return 440 | 441 | if self.profiling_cpu: 442 | 
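            # CPU timing: wall-clock latency from the start of forward() to this exit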
self.profiling_meters[index].update(time.time() - self.start) 443 | elif self.profiling_gpu: 444 | tmp_out = out.cpu() 445 | torch.cuda.synchronize() 446 | self.profiling_meters[index].update(time.time() - self.start) 447 | else: 448 | return 449 | if index == self.num_exits - 1 and (self.forward_count > start_count + 10): 450 | times = [self.profiling_meters[i].average() for i in range(self.num_exits)] 451 | times.append(np.mean(times)) 452 | print('\t'.join(['{:.3f}'.format(x) for x in times])) 453 | 454 | def get_points_from_confs(self, confs, ratio): 455 | bs, h, w = confs.size(0), confs.size(2), confs.size(3) 456 | idx = torch.arange(h * w, device=confs.device) 457 | h_pos = idx // w 458 | w_pos = idx % w 459 | point_coords_int = torch.cat((h_pos.unsqueeze(1), w_pos.unsqueeze(1)), dim=1) 460 | point_coords_int = point_coords_int.unsqueeze(0).repeat(bs, 1, 1) 461 | num_sampled = point_coords_int.size(1) 462 | 463 | num_certain_points = int(ratio * h * w) 464 | point_certainties = confs.view(bs, 1, -1) 465 | values, idx = torch.topk(point_certainties[:, 0, :], k=num_certain_points, dim=1) 466 | shift = num_sampled * torch.arange(bs, dtype=torch.long, device=confs.device) 467 | idx += shift[:, None] 468 | point_coords_selected_int = point_coords_int.view(-1, 2)[idx.view(-1), :].view( 469 | bs, num_certain_points, 2 470 | ) 471 | point_coords_selected_frac = torch.cat(( (point_coords_selected_int[:, :, 0:1] + 0.5)/float(h), (point_coords_selected_int[:, :, 1:2] + 0.5)/float(w)), dim=2) 472 | return point_coords_selected_int, point_coords_selected_frac 473 | 474 | def get_resized_mask_from_logits(self, logits, h, w,criterion): 475 | if criterion == 'conf_thre': 476 | resized_logits = F.interpolate(logits, size=(h, w)) 477 | resized_probs = F.softmax(resized_logits, dim=1) 478 | resized_confs, _ = resized_probs.max(dim=1, keepdim=True) 479 | mask = (resized_confs <= self.mask_cfg.CONF_THRE).float().view(logits.size(0), h, w) 480 | elif criterion == 'entropy_thre': 481 | resized_logits = F.interpolate(logits, size=(h, w)) 482 | resized_probs = F.softmax(resized_logits, dim=1) 483 | resized_confs = torch.sum( - resized_probs * torch.log(resized_probs), dim=1, keepdim=True) # 484 | mask = (resized_confs >= self.mask_cfg.ENTROPY_THRE).float().view(logits.size(0), h, w) 485 | return mask 486 | 487 | def generate_grid_priors(self): 488 | if hasattr(self, 'mask_grid_prior_dict') and len(self.mask_grid_prior_dict) > 0: 489 | return 490 | self.mask_grid_prior_dict = {} 491 | 492 | for m in self.modules(): 493 | if isinstance(m, conv_mask_uniform): 494 | try: 495 | h,w = m.out_h, m.out_w 496 | except: 497 | logger.info("First forwarding, collecting output size, quit generating grid priors") 498 | break 499 | 500 | if (h,w) in self.mask_grid_prior_dict: 501 | continue 502 | logger.info(f"generating grid priors for size {(h,w)}") 503 | res = torch.zeros((h, w), device=m.weight.device) 504 | stride = self.mask_cfg.GRID_STRIDE 505 | start = (stride - 1) // 2 506 | 507 | for i in range(start, res.size(0), stride): 508 | for j in range(start, res.size(1), stride): 509 | res[i][j] = 1. 
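                # (equivalent, without the Python loops: res[start::stride, start::stride] = 1.)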
510 | 511 | self.mask_grid_prior_dict[(h, w)] = res 512 | 513 | def set_masks(self, logits): 514 | self.mask_dict = {} 515 | for m in self.modules(): 516 | if isinstance(m, conv_mask_uniform): 517 | try: 518 | h,w = m.out_h, m.out_w 519 | except: 520 | logger.info("First forwarding, collecting output size, quit setting masks") 521 | break 522 | 523 | if (h,w) in self.mask_dict: 524 | m.set_mask(self.mask_dict[(h,w)]) 525 | else: 526 | self.mask_dict[(h,w)] = self.get_resized_mask_from_logits(logits, h, w, criterion=self.mask_cfg.CRIT) 527 | 528 | if self.mask_cfg.GRID_PRIOR: 529 | self.mask_dict[(h,w)] = torch.max(self.mask_dict[(h,w)], self.mask_grid_prior_dict[(h, w)]) 530 | m.set_mask(self.mask_dict[(h,w)]) 531 | 532 | 533 | def set_part_masks(self, logits, ref_name, masked_modules): 534 | start = time.time() 535 | self.part_mask_dicts[ref_name] = {} 536 | for module in masked_modules: 537 | for m in module.modules(): 538 | if isinstance(m, conv_mask_uniform): 539 | try: 540 | h,w = m.out_h, m.out_w 541 | except: 542 | logger.info("First forwarding, collecting output size, quit setting masks") 543 | break 544 | if (h,w) in self.part_mask_dicts[ref_name]: 545 | m.set_mask(self.part_mask_dicts[ref_name][(h,w)]) 546 | else: 547 | self.part_mask_dicts[ref_name][(h,w)] = self.get_resized_mask_from_logits(logits, h, w, criterion=self.mask_cfg.CRIT) 548 | m.set_mask(self.part_mask_dicts[ref_name][(h,w)]) 549 | 550 | def forward(self, x): 551 | self.part_mask_dicts = {} 552 | 553 | if self.profiling_gpu: 554 | torch.cuda.synchronize() 555 | if self.profiling_cpu or self.profiling_gpu: 556 | self.start = time.time() 557 | 558 | x = self.conv1(x) 559 | x = self.bn1(x) 560 | x = self.relu(x) 561 | x = self.conv2(x) 562 | x = self.bn2(x) 563 | x = self.relu(x) 564 | x = self.layer1(x) 565 | out1_feat = self.get_exit_input([x], detach=self.extra.EARLY_DETACH) 566 | out1 = self.exit1(out1_feat) # logits of exit 1 567 | out_size = (out1.size(-2), out1.size(-1)) 568 | 569 | # Set mask for all conv_mask modules between exit 1 and exit 2 570 | if self.mask_cfg.USE: 571 | self.set_part_masks(out1, 'out1', [self.transition1, self.stage2, self.exit2]) 572 | if hasattr(self, "stop1"): 573 | return out1 574 | 575 | x_list = [] 576 | for i in range(self.stage2_cfg['NUM_BRANCHES']): 577 | if self.transition1[i] is not None: 578 | x_list.append(self.transition1[i](x)) 579 | else: 580 | x_list.append(x) 581 | y_list = self.stage2(x_list) 582 | out2_feat = self.get_exit_input(y_list, detach=self.extra.EARLY_DETACH) 583 | out2 = self.exit2(out2_feat) 584 | 585 | if self.mask_cfg.USE: 586 | # Compute logits, aggregate results from the previous exit 587 | if self.mask_cfg.AGGR == 'copy' and len(self.part_mask_dicts['out1']) > 0: 588 | result_mask = self.part_mask_dicts['out1'][out_size][:, None, :, :] 589 | out2 = out1 * (1-result_mask) + out2 * result_mask 590 | # Set mask for all conv_mask modules between exit 2 and exit 3 591 | self.set_part_masks(out2, 'out2', [self.transition2, self.stage3, self.exit3]) 592 | if hasattr(self, "stop2"): 593 | return out2 594 | 595 | x_list = [] 596 | for i in range(self.stage3_cfg['NUM_BRANCHES']): 597 | if self.transition2[i] is not None: 598 | x_list.append(self.transition2[i](y_list[-1])) 599 | else: 600 | x_list.append(y_list[i]) 601 | y_list = self.stage3(x_list) 602 | out3_feat = self.get_exit_input(y_list, detach=self.extra.EARLY_DETACH) 603 | out3 = self.exit3(out3_feat) 604 | 605 | if self.mask_cfg.USE: 606 | # Compute logits, aggregate results from the previous exit 607 
| if self.mask_cfg.AGGR == 'copy' and len(self.part_mask_dicts['out2']) > 0: 608 | result_mask = self.part_mask_dicts['out2'][out_size][:, None, :, :] 609 | out3 = out2 * (1-result_mask) + out3 * result_mask 610 | # Set mask for all conv_mask module between exit 3 and exit 4 611 | self.set_part_masks(out3, 'out3', [self.transition3, self.stage4, self.last_layer]) 612 | if hasattr(self, "stop3"): 613 | return out3 614 | 615 | x_list = [] 616 | for i in range(self.stage4_cfg['NUM_BRANCHES']): 617 | if self.transition3[i] is not None: 618 | x_list.append(self.transition3[i](y_list[-1])) 619 | else: 620 | x_list.append(y_list[i]) 621 | 622 | y_list = self.stage4(x_list) 623 | out4_feat = self.get_exit_input(y_list, detach=False) 624 | out4 = self.last_layer(out4_feat) 625 | 626 | if self.mask_cfg.USE: 627 | if self.mask_cfg.AGGR == 'copy' and len(self.part_mask_dicts['out3']) > 0: 628 | result_mask = self.part_mask_dicts['out3'][out_size][:, None, :, :] 629 | out4 = out3 * (1-result_mask) + out4 * result_mask 630 | 631 | self.profile(out4, 3) 632 | if hasattr(self, "stop4"): 633 | return out4 634 | 635 | outs = [out1, out2, out3, out4] 636 | 637 | return outs 638 | 639 | 640 | def get_exit_layer(self, num_channels, config, last=False, exit_number=0): 641 | print(f'EXIT num_channels:{num_channels}') 642 | extra = config.MODEL.EXTRA 643 | layer_type = config.EXIT.TYPE if (not last) else 'original' 644 | 645 | inter_channel = int(num_channels) 646 | 647 | if layer_type == 'flex': 648 | assert exit_number in [1,2,3] 649 | type_map = {1: 'downup_pool_1x1_inter_triple', 2: 'downup_pool_1x1_inter_double', 3: 'downup_pool_1x1_inter'} 650 | layer_type = type_map[exit_number] 651 | inter_channel = config.EXIT.INTER_CHANNEL 652 | 653 | if self.mask_cfg.USE: 654 | exit_conv = used_conv 655 | else: 656 | exit_conv = nn.Conv2d 657 | 658 | norm_layer = BatchNorm2d(num_channels, momentum=BN_MOMENTUM) 659 | 660 | if layer_type == 'original': 661 | exit_layer = [ 662 | exit_conv( 663 | in_channels=num_channels, 664 | out_channels=num_channels, 665 | kernel_size=1, 666 | stride=1, 667 | padding=0, 668 | bias=True), 669 | 670 | norm_layer, 671 | nn.ReLU(inplace=True), 672 | exit_conv( 673 | in_channels=num_channels, 674 | out_channels=config.DATASET.NUM_CLASSES, 675 | kernel_size=config.EXIT.FINAL_CONV_KERNEL, 676 | stride=1, 677 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0, 678 | bias=True, 679 | ) 680 | ] 681 | 682 | elif layer_type == 'downup_pool_1x1_inter': 683 | exit_layer = [ 684 | nn.AvgPool2d(2, 2), 685 | exit_conv( 686 | in_channels=num_channels, 687 | out_channels=inter_channel, 688 | kernel_size=1, 689 | stride=1, 690 | padding=0, 691 | bias=True), 692 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 693 | nn.ReLU(inplace=True), 694 | nn.Upsample(scale_factor=2, mode='bilinear'), 695 | exit_conv( 696 | in_channels=inter_channel, 697 | out_channels=config.DATASET.NUM_CLASSES, 698 | kernel_size=config.EXIT.FINAL_CONV_KERNEL, 699 | stride=1, 700 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0, 701 | bias=True, 702 | ) 703 | ] 704 | 705 | 706 | 707 | elif layer_type == 'downup_pool_1x1_inter_double': 708 | exit_layer = [ 709 | nn.AvgPool2d(2, 2), 710 | exit_conv( 711 | in_channels=num_channels, 712 | out_channels=inter_channel, 713 | kernel_size=1, 714 | stride=1, 715 | padding=0, 716 | bias=True), 717 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 718 | nn.ReLU(inplace=True), 719 | 720 | nn.AvgPool2d(2, 2), 721 | exit_conv( 722 | in_channels=inter_channel, 723 | 
out_channels=inter_channel, 724 | kernel_size=1, 725 | stride=1, 726 | padding=0, 727 | bias=True 728 | ), 729 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 730 | nn.ReLU(inplace=True), 731 | 732 | nn.Upsample(scale_factor=2, mode='bilinear'), 733 | exit_conv( 734 | in_channels=inter_channel, 735 | out_channels=inter_channel, 736 | kernel_size=1, 737 | stride=1, 738 | padding=0, 739 | bias=True 740 | ), 741 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 742 | nn.ReLU(inplace=True), 743 | 744 | nn.Upsample(scale_factor=2, mode='bilinear'), 745 | exit_conv( 746 | in_channels=inter_channel, 747 | out_channels=config.DATASET.NUM_CLASSES, 748 | kernel_size=config.EXIT.FINAL_CONV_KERNEL, 749 | stride=1, 750 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0, 751 | bias=True, 752 | ) 753 | ] 754 | 755 | 756 | elif layer_type == 'downup_pool_1x1_inter_triple': 757 | exit_layer = [ 758 | nn.AvgPool2d(2, 2), 759 | exit_conv( 760 | in_channels=num_channels, 761 | out_channels=inter_channel, 762 | kernel_size=1, 763 | stride=1, 764 | padding=0, 765 | bias=True), 766 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 767 | nn.ReLU(inplace=True), 768 | 769 | nn.AvgPool2d(2, 2), 770 | exit_conv( 771 | in_channels=inter_channel, 772 | out_channels=inter_channel, 773 | kernel_size=1, 774 | stride=1, 775 | padding=0, 776 | bias=True 777 | ), 778 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 779 | nn.ReLU(inplace=True), 780 | 781 | nn.AvgPool2d(2, 2), 782 | exit_conv( 783 | in_channels=inter_channel, 784 | out_channels=inter_channel, 785 | kernel_size=1, 786 | stride=1, 787 | padding=0, 788 | bias=True 789 | ), 790 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 791 | nn.ReLU(inplace=True), 792 | 793 | nn.Upsample(scale_factor=2, mode='bilinear'), 794 | exit_conv( 795 | in_channels=inter_channel, 796 | out_channels=inter_channel, 797 | kernel_size=1, 798 | stride=1, 799 | padding=0, 800 | bias=True 801 | ), 802 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 803 | nn.ReLU(inplace=True), 804 | 805 | nn.Upsample(scale_factor=2, mode='bilinear'), 806 | exit_conv( 807 | in_channels=inter_channel, 808 | out_channels=inter_channel, 809 | kernel_size=1, 810 | stride=1, 811 | padding=0, 812 | bias=True 813 | ), 814 | BatchNorm2d(inter_channel, momentum=BN_MOMENTUM), 815 | nn.ReLU(inplace=True), 816 | 817 | nn.Upsample(scale_factor=2, mode='bilinear'), 818 | exit_conv( 819 | in_channels=inter_channel, 820 | out_channels=config.DATASET.NUM_CLASSES, 821 | kernel_size=config.EXIT.FINAL_CONV_KERNEL, 822 | stride=1, 823 | padding=1 if config.EXIT.FINAL_CONV_KERNEL == 3 else 0, 824 | bias=True, 825 | ) 826 | ] 827 | 828 | exit_layer = nn.Sequential(*exit_layer) 829 | 830 | return exit_layer 831 | 832 | def get_exit_input(self, x, detach=True): 833 | interpolated_list = [x[0]] 834 | x0_h, x0_w = x[0].size(2), x[0].size(3) 835 | 836 | for i in range(1, len(x)): 837 | interpolated_list.append(F.upsample(x[i], size=(x0_h, x0_w), mode='bilinear')) 838 | 839 | ret = torch.cat(interpolated_list, 1) 840 | 841 | return ret.detach() if detach else ret 842 | 843 | 844 | 845 | def init_weights(self, pretrained='', load_stage=1): 846 | logger.info('=> init weights from normal distribution') 847 | for m in self.modules(): 848 | if isinstance(m, nn.Conv2d): 849 | nn.init.normal_(m.weight, std=0.001) 850 | elif isinstance(m, nn.BatchNorm2d): 851 | nn.init.constant_(m.weight, 1) 852 | nn.init.constant_(m.bias, 0) 853 | 854 | if os.path.isfile(pretrained) and load_stage == 0: 855 | pretrained_dict = 
torch.load(pretrained) 856 | logger.info('=> loading pretrained model {}'.format(pretrained)) 857 | model_dict = self.state_dict() 858 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 859 | elif os.path.isfile(pretrained) and load_stage == 1: 860 | pretrained_dict = torch.load(pretrained) 861 | logger.info('=> loading pretrained model {}'.format(pretrained)) 862 | model_dict = self.state_dict() 863 | pretrained_dict = {k[len('model.'):]: v for k, v in pretrained_dict.items() if k[len('model.'):] in model_dict.keys()} 864 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith('exit')} 865 | 866 | elif os.path.isfile(pretrained) and load_stage == 2: 867 | pretrained_dict = torch.load(pretrained) 868 | logger.info('=> loading pretrained model {}'.format(pretrained)) 869 | model_dict = self.state_dict() 870 | pretrained_dict = {k[len('model.'):]: v for k, v in pretrained_dict.items() if k[len('model.'):] in model_dict.keys()} 871 | 872 | logger.info('loading stage: {}, loading {} dict keys'.format(load_stage, len(pretrained_dict))) 873 | model_dict.update(pretrained_dict) 874 | self.load_state_dict(model_dict) 875 | 876 | 877 | class L2Norm(nn.Module): 878 | def __init__(self): 879 | super(L2Norm, self).__init__() 880 | def forward(self, x): 881 | return F.normalize(x, p=2, dim=0) 882 | 883 | 884 | class TemperatureScaling(nn.Module): 885 | def __init__(self, channel_wise, location_wise): 886 | super(TemperatureScaling, self).__init__() 887 | 888 | self.channel_wise = channel_wise 889 | self.location_wise = location_wise 890 | self.shift = 0.5413 891 | 892 | def forward(self, x): 893 | pass 894 | 895 | class TemperatureScalingFixed(TemperatureScaling): 896 | def __init__(self, channel_wise=False, location_wise=False, num_channels=0): 897 | 898 | super(TemperatureScalingFixed, self).__init__(channel_wise=channel_wise, location_wise=location_wise) 899 | self.num_channels = num_channels 900 | 901 | assert (not self.location_wise) 902 | 903 | if channel_wise: 904 | self.t_vector = nn.Parameter(torch.zeros(num_channels), requires_grad=True) 905 | else: 906 | self.t = nn.Parameter(torch.zeros(1), requires_grad=True) 907 | 908 | def forward(self, x): 909 | 910 | if self.channel_wise: 911 | positive_t_vector = F.softplus(self.t_vector + self.shift) 912 | out = x * positive_t_vector[None, :, None, None] 913 | else: 914 | positive_t = F.softplus(self.t + self.shift) 915 | out = x * positive_t 916 | return out 917 | 918 | class TemperatureScalingPredicted(TemperatureScaling): 919 | def __init__(self, channel_wise=False, location_wise=False, in_channels=0, layer_type='conv1'): 920 | super(TemperatureScalingPredicted, self).__init__(channel_wise=channel_wise, location_wise=location_wise) 921 | assert self.location_wise 922 | 923 | self.in_channels = in_channels 924 | self.layer_type = layer_type 925 | 926 | if self.layer_type == 'conv1': 927 | self.layer = used_conv(in_channels, 1, kernel_size=1, padding=0) 928 | elif self.layer_type == 'conv3': 929 | self.layer = used_conv(in_channels, 1, kernel_size=3, padding=1) 930 | elif self.layer_type == 'default_exit': 931 | self.layer = nn.Sequential( 932 | used_conv( 933 | in_channels=in_channels, 934 | out_channels=in_channels, 935 | kernel_size=1, 936 | stride=1, 937 | padding=0), 938 | BatchNorm2d(in_channels, momentum=BN_MOMENTUM), 939 | nn.ReLU(inplace=True), 940 | used_conv( 941 | in_channels=in_channels, 942 | out_channels=1, 943 | kernel_size=1, 944 | stride=1, 945 | padding=0) 946 | ) 947 
| else: 948 | raise NotImplementedError('TemperatureScalingPredicted layer type {} not implemented!'.format(self.layer_type)) 949 | 950 | def forward(self, x): 951 | logits = x[0] 952 | features = x[1] 953 | self.t_map = self.layer(features) * 1.0 954 | self.positive_t_map = F.softplus(self.t_map + self.shift) 955 | return logits * self.positive_t_map 956 | 957 | 958 | def get_seg_model(cfg, **kwargs): 959 | model = HighResolutionNet(cfg, **kwargs) 960 | model.init_weights(cfg.MODEL.PRETRAINED, cfg.MODEL.LOAD_STAGE) 961 | 962 | return model 963 | 964 | if __name__ == '__main__': 965 | from config import config 966 | from config import update_config 967 | import argparse 968 | import torch.backends.cudnn as cudnn 969 | 970 | def parse_args(): 971 | parser = argparse.ArgumentParser(description='Train segmentation network') 972 | 973 | parser.add_argument('--cfg', 974 | help='experiment configure file name', 975 | type=str, default='experiments/cityscapes/seg_hrnet_ee_0715_mask.yaml') 976 | parser.add_argument('opts', 977 | help="Modify config options using the command-line", 978 | default=None, 979 | nargs=argparse.REMAINDER) 980 | args = parser.parse_args() 981 | update_config(config, args) 982 | return args 983 | 984 | args = parse_args() 985 | cudnn.benchmark = config.CUDNN.BENCHMARK 986 | cudnn.deterministic = config.CUDNN.DETERMINISTIC 987 | cudnn.enabled = config.CUDNN.ENABLED 988 | 989 | model = eval('get_seg_model')(config) 990 | model = nn.DataParallel(model, device_ids=[0]).cuda() 991 | 992 | for i in range(20): 993 | print(i) 994 | dump_input = torch.rand( 995 | (1, 3, config.TRAIN.IMAGE_SIZE[1]//4, config.TRAIN.IMAGE_SIZE[0]//4) 996 | ) 997 | out = model(dump_input) 998 | 999 | def count_parameters(model): 1000 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 1001 | print(count_parameters(model)) 1002 | 1003 | -------------------------------------------------------------------------------- /lib/models/sync_bn/__init__.py: -------------------------------------------------------------------------------- 1 | from .inplace_abn import bn -------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNSync 2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE 3 | -------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/bn.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | 6 | try: 7 | from queue import Queue 8 | except ImportError: 9 | from Queue import Queue 10 | 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(BASE_DIR) 13 | sys.path.append(os.path.join(BASE_DIR, '../src')) 14 | from functions import * 15 | 16 | 17 | class ABN(nn.Module): 18 | """Activated Batch Normalization 19 | 20 | This gathers a `BatchNorm2d` and an activation function in a single module 21 | """ 22 | 23 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 24 | """Creates an Activated Batch Normalization module 25 | 26 | Parameters 27 | ---------- 28 | num_features : int 29 | Number of feature channels in the input and output. 30 | eps : float 31 | Small constant to prevent numerical issues. 
32 | momentum : float 33 | Momentum factor applied to compute running statistics as. 34 | affine : bool 35 | If `True` apply learned scale and shift transformation after normalization. 36 | activation : str 37 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 38 | slope : float 39 | Negative slope for the `leaky_relu` activation. 40 | """ 41 | super(ABN, self).__init__() 42 | self.num_features = num_features 43 | self.affine = affine 44 | self.eps = eps 45 | self.momentum = momentum 46 | self.activation = activation 47 | self.slope = slope 48 | if self.affine: 49 | self.weight = nn.Parameter(torch.ones(num_features)) 50 | self.bias = nn.Parameter(torch.zeros(num_features)) 51 | else: 52 | self.register_parameter('weight', None) 53 | self.register_parameter('bias', None) 54 | self.register_buffer('running_mean', torch.zeros(num_features)) 55 | self.register_buffer('running_var', torch.ones(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.constant_(self.running_mean, 0) 60 | nn.init.constant_(self.running_var, 1) 61 | if self.affine: 62 | nn.init.constant_(self.weight, 1) 63 | nn.init.constant_(self.bias, 0) 64 | 65 | def forward(self, x): 66 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 67 | self.training, self.momentum, self.eps) 68 | 69 | if self.activation == ACT_RELU: 70 | return functional.relu(x, inplace=True) 71 | elif self.activation == ACT_LEAKY_RELU: 72 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) 73 | elif self.activation == ACT_ELU: 74 | return functional.elu(x, inplace=True) 75 | else: 76 | return x 77 | 78 | def __repr__(self): 79 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 80 | ' affine={affine}, activation={activation}' 81 | if self.activation == "leaky_relu": 82 | rep += ', slope={slope})' 83 | else: 84 | rep += ')' 85 | return rep.format(name=self.__class__.__name__, **self.__dict__) 86 | 87 | 88 | class InPlaceABN(ABN): 89 | """InPlace Activated Batch Normalization""" 90 | 91 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 92 | """Creates an InPlace Activated Batch Normalization module 93 | 94 | Parameters 95 | ---------- 96 | num_features : int 97 | Number of feature channels in the input and output. 98 | eps : float 99 | Small constant to prevent numerical issues. 100 | momentum : float 101 | Momentum factor applied to compute running statistics as. 102 | affine : bool 103 | If `True` apply learned scale and shift transformation after normalization. 104 | activation : str 105 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 106 | slope : float 107 | Negative slope for the `leaky_relu` activation. 108 | """ 109 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) 110 | 111 | def forward(self, x): 112 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, 113 | self.training, self.momentum, self.eps, self.activation, self.slope) 114 | 115 | 116 | class InPlaceABNSync(ABN): 117 | """InPlace Activated Batch Normalization with cross-GPU synchronization 118 | 119 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`. 
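    Per-replica batch statistics are reduced on the master device through the
    queues created below and then broadcast back, so every replica normalizes
    with the same global mean and variance.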
120 | """ 121 | 122 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", 123 | slope=0.01): 124 | """Creates a synchronized, InPlace Activated Batch Normalization module 125 | 126 | Parameters 127 | ---------- 128 | num_features : int 129 | Number of feature channels in the input and output. 130 | devices : list of int or None 131 | IDs of the GPUs that will run the replicas of this module. 132 | eps : float 133 | Small constant to prevent numerical issues. 134 | momentum : float 135 | Momentum factor applied to compute running statistics as. 136 | affine : bool 137 | If `True` apply learned scale and shift transformation after normalization. 138 | activation : str 139 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 140 | slope : float 141 | Negative slope for the `leaky_relu` activation. 142 | """ 143 | super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, slope) 144 | self.devices = devices if devices else list(range(torch.cuda.device_count())) 145 | 146 | # Initialize queues 147 | self.worker_ids = self.devices[1:] 148 | self.master_queue = Queue(len(self.worker_ids)) 149 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 150 | 151 | def forward(self, x): 152 | if x.get_device() == self.devices[0]: 153 | # Master mode 154 | extra = { 155 | "is_master": True, 156 | "master_queue": self.master_queue, 157 | "worker_queues": self.worker_queues, 158 | "worker_ids": self.worker_ids 159 | } 160 | else: 161 | # Worker mode 162 | extra = { 163 | "is_master": False, 164 | "master_queue": self.master_queue, 165 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 166 | } 167 | 168 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, 169 | extra, self.training, self.momentum, self.eps, self.activation, self.slope) 170 | 171 | def __repr__(self): 172 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 173 | ' affine={affine}, devices={devices}, activation={activation}' 174 | if self.activation == "leaky_relu": 175 | rep += ', slope={slope})' 176 | else: 177 | rep += ')' 178 | return rep.format(name=self.__class__.__name__, **self.__dict__) 179 | -------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/functions.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import torch.autograd as autograd 4 | import torch.cuda.comm as comm 5 | from torch.autograd.function import once_differentiable 6 | from torch.utils.cpp_extension import load 7 | 8 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src") 9 | _backend = load(name="inplace_abn", 10 | extra_cflags=["-O3"], 11 | sources=[path.join(_src_path, f) for f in [ 12 | "inplace_abn.cpp", 13 | "inplace_abn_cpu.cpp", 14 | "inplace_abn_cuda.cu" 15 | ]], 16 | extra_cuda_cflags=["--expt-extended-lambda"]) 17 | 18 | # Activation names 19 | ACT_RELU = "relu" 20 | ACT_LEAKY_RELU = "leaky_relu" 21 | ACT_ELU = "elu" 22 | ACT_NONE = "none" 23 | 24 | 25 | def _check(fn, *args, **kwargs): 26 | success = fn(*args, **kwargs) 27 | if not success: 28 | raise RuntimeError("CUDA Error encountered in {}".format(fn)) 29 | 30 | 31 | def _broadcast_shape(x): 32 | out_size = [] 33 | for i, s in enumerate(x.size()): 34 | if i != 1: 35 | out_size.append(1) 36 | else: 37 | out_size.append(s) 38 | return out_size 39 | 40 | 41 | def _reduce(x): 42 | if 
len(x.size()) == 2: 43 | return x.sum(dim=0) 44 | else: 45 | n, c = x.size()[0:2] 46 | return x.contiguous().view((n, c, -1)).sum(2).sum(0) 47 | 48 | 49 | def _count_samples(x): 50 | count = 1 51 | for i, s in enumerate(x.size()): 52 | if i != 1: 53 | count *= s 54 | return count 55 | 56 | 57 | def _act_forward(ctx, x): 58 | if ctx.activation == ACT_LEAKY_RELU: 59 | _backend.leaky_relu_forward(x, ctx.slope) 60 | elif ctx.activation == ACT_ELU: 61 | _backend.elu_forward(x) 62 | elif ctx.activation == ACT_NONE: 63 | pass 64 | 65 | 66 | def _act_backward(ctx, x, dx): 67 | if ctx.activation == ACT_LEAKY_RELU: 68 | _backend.leaky_relu_backward(x, dx, ctx.slope) 69 | elif ctx.activation == ACT_ELU: 70 | _backend.elu_backward(x, dx) 71 | elif ctx.activation == ACT_NONE: 72 | pass 73 | 74 | 75 | class InPlaceABN(autograd.Function): 76 | @staticmethod 77 | def forward(ctx, x, weight, bias, running_mean, running_var, 78 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 79 | ctx.training = training 80 | ctx.momentum = momentum 81 | ctx.eps = eps 82 | ctx.activation = activation 83 | ctx.slope = slope 84 | ctx.affine = weight is not None and bias is not None 85 | 86 | count = _count_samples(x) 87 | x = x.contiguous() 88 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 89 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 90 | 91 | if ctx.training: 92 | mean, var = _backend.mean_var(x) 93 | 94 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 95 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 96 | 97 | ctx.mark_dirty(x, running_mean, running_var) 98 | else: 99 | mean, var = running_mean.contiguous(), running_var.contiguous() 100 | ctx.mark_dirty(x) 101 | 102 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 103 | _act_forward(ctx, x) 104 | 105 | ctx.var = var 106 | ctx.save_for_backward(x, var, weight, bias) 107 | return x 108 | 109 | @staticmethod 110 | @once_differentiable 111 | def backward(ctx, dz): 112 | z, var, weight, bias = ctx.saved_tensors 113 | dz = dz.contiguous() 114 | 115 | _act_backward(ctx, z, dz) 116 | 117 | if ctx.training: 118 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 119 | else: 120 | edz = dz.new_zeros(dz.size(1)) 121 | eydz = dz.new_zeros(dz.size(1)) 122 | 123 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 124 | dweight = dweight if ctx.affine else None 125 | dbias = dbias if ctx.affine else None 126 | 127 | return dx, dweight, dbias, None, None, None, None, None, None, None 128 | 129 | 130 | class InPlaceABNSync(autograd.Function): 131 | @classmethod 132 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 133 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 134 | cls._parse_extra(ctx, extra) 135 | ctx.training = training 136 | ctx.momentum = momentum 137 | ctx.eps = eps 138 | ctx.activation = activation 139 | ctx.slope = slope 140 | ctx.affine = weight is not None and bias is not None 141 | 142 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1) 143 | x = x.contiguous() 144 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 145 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 146 | 147 | if ctx.training: 148 | mean, var = _backend.mean_var(x) 149 | 150 | if ctx.is_master: 151 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)] 152 | for _ in range(ctx.master_queue.maxsize): 
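                    # drain one (mean, var) pair from each worker replica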
153 | mean_w, var_w = ctx.master_queue.get() 154 | ctx.master_queue.task_done() 155 | means.append(mean_w.unsqueeze(0)) 156 | vars.append(var_w.unsqueeze(0)) 157 | 158 | means = comm.gather(means) 159 | vars = comm.gather(vars) 160 | 161 | mean = means.mean(0) 162 | var = (vars + (mean - means) ** 2).mean(0) 163 | 164 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids) 165 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 166 | queue.put(ts) 167 | else: 168 | ctx.master_queue.put((mean, var)) 169 | mean, var = ctx.worker_queue.get() 170 | ctx.worker_queue.task_done() 171 | 172 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 173 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 174 | 175 | ctx.mark_dirty(x, running_mean, running_var) 176 | else: 177 | mean, var = running_mean.contiguous(), running_var.contiguous() 178 | ctx.mark_dirty(x) 179 | 180 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 181 | _act_forward(ctx, x) 182 | 183 | ctx.var = var 184 | ctx.save_for_backward(x, var, weight, bias) 185 | return x 186 | 187 | @staticmethod 188 | @once_differentiable 189 | def backward(ctx, dz): 190 | z, var, weight, bias = ctx.saved_tensors 191 | dz = dz.contiguous() 192 | 193 | _act_backward(ctx, z, dz) 194 | 195 | if ctx.training: 196 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 197 | 198 | if ctx.is_master: 199 | edzs, eydzs = [edz], [eydz] 200 | for _ in range(len(ctx.worker_queues)): 201 | edz_w, eydz_w = ctx.master_queue.get() 202 | ctx.master_queue.task_done() 203 | edzs.append(edz_w) 204 | eydzs.append(eydz_w) 205 | 206 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1) 207 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1) 208 | 209 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids) 210 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 211 | queue.put(ts) 212 | else: 213 | ctx.master_queue.put((edz, eydz)) 214 | edz, eydz = ctx.worker_queue.get() 215 | ctx.worker_queue.task_done() 216 | else: 217 | edz = dz.new_zeros(dz.size(1)) 218 | eydz = dz.new_zeros(dz.size(1)) 219 | 220 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 221 | dweight = dweight if ctx.affine else None 222 | dbias = dbias if ctx.affine else None 223 | 224 | return dx, dweight, dbias, None, None, None, None, None, None, None, None 225 | 226 | @staticmethod 227 | def _parse_extra(ctx, extra): 228 | ctx.is_master = extra["is_master"] 229 | if ctx.is_master: 230 | ctx.master_queue = extra["master_queue"] 231 | ctx.worker_queues = extra["worker_queues"] 232 | ctx.worker_ids = extra["worker_ids"] 233 | else: 234 | ctx.master_queue = extra["master_queue"] 235 | ctx.worker_queue = extra["worker_queue"] 236 | 237 | 238 | inplace_abn = InPlaceABN.apply 239 | inplace_abn_sync = InPlaceABNSync.apply 240 | 241 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"] 242 | -------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/src/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | const int WARP_SIZE = 32; 6 | const int MAX_BLOCK_SIZE = 512; 7 | 8 | template 9 | struct Pair { 10 | T v1, v2; 11 | __device__ Pair() {} 12 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 13 | __device__ 
Pair(T v) : v1(v), v2(v) {} 14 | __device__ Pair(int v) : v1(v), v2(v) {} 15 | __device__ Pair &operator+=(const Pair &a) { 16 | v1 += a.v1; 17 | v2 += a.v2; 18 | return *this; 19 | } 20 | }; 21 | 22 | template 23 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, 24 | unsigned int mask = 0xffffffff) { 25 | #if CUDART_VERSION >= 9000 26 | return __shfl_xor_sync(mask, value, laneMask, width); 27 | #else 28 | return __shfl_xor(value, laneMask, width); 29 | #endif 30 | } 31 | 32 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 33 | 34 | static int getNumThreads(int nElem) { 35 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 36 | for (int i = 0; i != 5; ++i) { 37 | if (nElem <= threadSizes[i]) { 38 | return threadSizes[i]; 39 | } 40 | } 41 | return MAX_BLOCK_SIZE; 42 | } 43 | 44 | template 45 | static __device__ __forceinline__ T warpSum(T val) { 46 | #if __CUDA_ARCH__ >= 300 47 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 48 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 49 | } 50 | #else 51 | __shared__ T values[MAX_BLOCK_SIZE]; 52 | values[threadIdx.x] = val; 53 | __threadfence_block(); 54 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 55 | for (int i = 1; i < WARP_SIZE; i++) { 56 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 57 | } 58 | #endif 59 | return val; 60 | } 61 | 62 | template 63 | static __device__ __forceinline__ Pair warpSum(Pair value) { 64 | value.v1 = warpSum(value.v1); 65 | value.v2 = warpSum(value.v2); 66 | return value; 67 | } 68 | 69 | template 70 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 71 | T sum = (T)0; 72 | for (int batch = 0; batch < N; ++batch) { 73 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 74 | sum += op(batch, plane, x); 75 | } 76 | } 77 | 78 | sum = warpSum(sum); 79 | 80 | __shared__ T shared[32]; 81 | __syncthreads(); 82 | if (threadIdx.x % WARP_SIZE == 0) { 83 | shared[threadIdx.x / WARP_SIZE] = sum; 84 | } 85 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 86 | shared[threadIdx.x] = (T)0; 87 | } 88 | __syncthreads(); 89 | if (threadIdx.x / WARP_SIZE == 0) { 90 | sum = warpSum(shared[threadIdx.x]); 91 | if (threadIdx.x == 0) { 92 | shared[0] = sum; 93 | } 94 | } 95 | __syncthreads(); 96 | 97 | return shared[0]; 98 | } -------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/src/inplace_abn.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "inplace_abn.h" 6 | 7 | std::vector mean_var(at::Tensor x) { 8 | if (x.is_cuda()) { 9 | return mean_var_cuda(x); 10 | } else { 11 | return mean_var_cpu(x); 12 | } 13 | } 14 | 15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 16 | bool affine, float eps) { 17 | if (x.is_cuda()) { 18 | return forward_cuda(x, mean, var, weight, bias, affine, eps); 19 | } else { 20 | return forward_cpu(x, mean, var, weight, bias, affine, eps); 21 | } 22 | } 23 | 24 | std::vector edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 25 | bool affine, float eps) { 26 | if (z.is_cuda()) { 27 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps); 28 | } else { 29 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps); 30 | } 31 | } 32 | 33 | std::vector backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 34 | at::Tensor edz, at::Tensor 
eydz, bool affine, float eps) { 35 | if (z.is_cuda()) { 36 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps); 37 | } else { 38 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps); 39 | } 40 | } 41 | 42 | void leaky_relu_forward(at::Tensor z, float slope) { 43 | at::leaky_relu_(z, slope); 44 | } 45 | 46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) { 47 | if (z.is_cuda()) { 48 | return leaky_relu_backward_cuda(z, dz, slope); 49 | } else { 50 | return leaky_relu_backward_cpu(z, dz, slope); 51 | } 52 | } 53 | 54 | void elu_forward(at::Tensor z) { 55 | at::elu_(z); 56 | } 57 | 58 | void elu_backward(at::Tensor z, at::Tensor dz) { 59 | if (z.is_cuda()) { 60 | return elu_backward_cuda(z, dz); 61 | } else { 62 | return elu_backward_cpu(z, dz); 63 | } 64 | } 65 | 66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 67 | m.def("mean_var", &mean_var, "Mean and variance computation"); 68 | m.def("forward", &forward, "In-place forward computation"); 69 | m.def("edz_eydz", &edz_eydz, "First part of backward computation"); 70 | m.def("backward", &backward, "Second part of backward computation"); 71 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation"); 72 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion"); 73 | m.def("elu_forward", &elu_forward, "Elu forward computation"); 74 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion"); 75 | } -------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/src/inplace_abn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | std::vector mean_var_cpu(at::Tensor x); 8 | std::vector mean_var_cuda(at::Tensor x); 9 | 10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 11 | bool affine, float eps); 12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 13 | bool affine, float eps); 14 | 15 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 16 | bool affine, float eps); 17 | std::vector edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 18 | bool affine, float eps); 19 | 20 | std::vector backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 21 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 22 | std::vector backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 23 | at::Tensor edz, at::Tensor eydz, bool affine, float eps); 24 | 25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope); 26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope); 27 | 28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz); 29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz); -------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/src/inplace_abn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "inplace_abn.h" 6 | 7 | at::Tensor reduce_sum(at::Tensor x) { 8 | if (x.ndimension() == 2) { 9 | return x.sum(0); 10 | } else { 11 | auto x_view = x.view({x.size(0), x.size(1), -1}); 12 | return x_view.sum(-1).sum(0); 13 | } 14 | } 15 | 16 
| at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 17 | if (x.ndimension() == 2) { 18 | return v; 19 | } else { 20 | std::vector broadcast_size = {1, -1}; 21 | for (int64_t i = 2; i < x.ndimension(); ++i) 22 | broadcast_size.push_back(1); 23 | 24 | return v.view(broadcast_size); 25 | } 26 | } 27 | 28 | int64_t count(at::Tensor x) { 29 | int64_t count = x.size(0); 30 | for (int64_t i = 2; i < x.ndimension(); ++i) 31 | count *= x.size(i); 32 | 33 | return count; 34 | } 35 | 36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) { 37 | if (affine) { 38 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z); 39 | } else { 40 | return z; 41 | } 42 | } 43 | 44 | std::vector mean_var_cpu(at::Tensor x) { 45 | auto num = count(x); 46 | auto mean = reduce_sum(x) / num; 47 | auto diff = x - broadcast_to(mean, x); 48 | auto var = reduce_sum(diff.pow(2)) / num; 49 | 50 | return {mean, var}; 51 | } 52 | 53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 54 | bool affine, float eps) { 55 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var); 56 | auto mul = at::rsqrt(var + eps) * gamma; 57 | 58 | x.sub_(broadcast_to(mean, x)); 59 | x.mul_(broadcast_to(mul, x)); 60 | if (affine) x.add_(broadcast_to(bias, x)); 61 | 62 | return x; 63 | } 64 | 65 | std::vector edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 66 | bool affine, float eps) { 67 | auto edz = reduce_sum(dz); 68 | auto y = invert_affine(z, weight, bias, affine, eps); 69 | auto eydz = reduce_sum(y * dz); 70 | 71 | return {edz, eydz}; 72 | } 73 | 74 | std::vector backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 75 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 76 | auto y = invert_affine(z, weight, bias, affine, eps); 77 | auto mul = affine ? 
-------------------------------------------------------------------------------- /lib/models/sync_bn/inplace_abn/src/inplace_abn_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | 3 | #include <thrust/device_ptr.h> 4 | #include <thrust/transform.h> 5 | 6 | #include <vector> 7 | 8 | #include "common.h" 9 | #include "inplace_abn.h" 10 | 11 | // Checks 12 | #ifndef AT_CHECK 13 | #define AT_CHECK AT_ASSERT 14 | #endif 15 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 18 | 19 | // Utilities 20 | void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) { 21 | num = x.size(0); 22 | chn = x.size(1); 23 | sp = 1; 24 | for (int64_t i = 2; i < x.ndimension(); ++i) 25 | sp *= x.size(i); 26 | } 27 | 28 | // Operations for reduce 29 | template <typename T> 30 | struct SumOp { 31 | __device__ SumOp(const T *t, int c, int s) 32 | : tensor(t), chn(c), sp(s) {} 33 | __device__ __forceinline__ T operator()(int batch, int plane, int n) { 34 | return tensor[(batch * chn + plane) * sp + n]; 35 | } 36 | const T *tensor; 37 | const int chn; 38 | const int sp; 39 | }; 40 | 41 | template <typename T> 42 | struct VarOp { 43 | __device__ VarOp(T m, const T *t, int c, int s) 44 | : mean(m), tensor(t), chn(c), sp(s) {} 45 | __device__ __forceinline__ T operator()(int batch, int plane, int n) { 46 | T val = tensor[(batch * chn + plane) * sp + n]; 47 | return (val - mean) * (val - mean); 48 | } 49 | const T mean; 50 | const T *tensor; 51 | const int chn; 52 | const int sp; 53 | }; 54 | 55 | template <typename T> 56 | struct GradOp { 57 | __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s) 58 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {} 59 | __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) { 60 | T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight; 61 | T _dz = dz[(batch * chn + plane) * sp + n]; 62 | return Pair<T>(_dz, _y * _dz); 63 | } 64 | const T weight; 65 | const T bias; 66 | const T *z; 67 | const T *dz; 68 | const int chn; 69 | const int sp; 70 | }; 71 |
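// SumOp, VarOp and GradOp are indexing functors consumed by the block-wide
// reduce<>() helper (defined in common.h alongside Pair and getNumThreads):
// SumOp accumulates raw values for the mean, VarOp accumulates squared
// deviations for the variance, and GradOp accumulates the pair (dz, y * dz)
// needed by the gradient, recovering the normalized activation y on the fly
// from the stored output z.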
72 | /*********** 73 | * mean_var 74 | ***********/ 75 | 76 | template <typename T> 77 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) { 78 | int plane = blockIdx.x; 79 | T norm = T(1) / T(num * sp); 80 | 81 | T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, chn, sp) * norm; 82 | __syncthreads(); 83 | T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, chn, sp) * norm; 84 | 85 | if (threadIdx.x == 0) { 86 | mean[plane] = _mean; 87 | var[plane] = _var; 88 | } 89 | } 90 | 91 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x) { 92 | CHECK_INPUT(x); 93 | 94 | // Extract dimensions 95 | int64_t num, chn, sp; 96 | get_dims(x, num, chn, sp); 97 | 98 | // Prepare output tensors 99 | auto mean = at::empty(x.type(), {chn}); 100 | auto var = at::empty(x.type(), {chn}); 101 | 102 | // Run kernel 103 | dim3 blocks(chn); 104 | dim3 threads(getNumThreads(sp)); 105 | AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] { 106 | mean_var_kernel<scalar_t><<<blocks, threads>>>( 107 | x.data<scalar_t>(), 108 | mean.data<scalar_t>(), 109 | var.data<scalar_t>(), 110 | num, chn, sp); 111 | })); 112 | 113 | return {mean, var}; 114 | } 115 | 116 | /********** 117 | * forward 118 | **********/ 119 | 120 | template <typename T> 121 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias, 122 | bool affine, float eps, int num, int chn, int sp) { 123 | int plane = blockIdx.x; 124 | 125 | T _mean = mean[plane]; 126 | T _var = var[plane]; 127 | T _weight = affine ? abs(weight[plane]) + eps : T(1); 128 | T _bias = affine ? bias[plane] : T(0); 129 | 130 | T mul = rsqrt(_var + eps) * _weight; 131 | 132 | for (int batch = 0; batch < num; ++batch) { 133 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 134 | T _x = x[(batch * chn + plane) * sp + n]; 135 | T _y = (_x - _mean) * mul + _bias; 136 | 137 | x[(batch * chn + plane) * sp + n] = _y; 138 | } 139 | } 140 | } 141 | 142 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias, 143 | bool affine, float eps) { 144 | CHECK_INPUT(x); 145 | CHECK_INPUT(mean); 146 | CHECK_INPUT(var); 147 | CHECK_INPUT(weight); 148 | CHECK_INPUT(bias); 149 | 150 | // Extract dimensions 151 | int64_t num, chn, sp; 152 | get_dims(x, num, chn, sp); 153 | 154 | // Run kernel 155 | dim3 blocks(chn); 156 | dim3 threads(getNumThreads(sp)); 157 | AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] { 158 | forward_kernel<scalar_t><<<blocks, threads>>>( 159 | x.data<scalar_t>(), 160 | mean.data<scalar_t>(), 161 | var.data<scalar_t>(), 162 | weight.data<scalar_t>(), 163 | bias.data<scalar_t>(), 164 | affine, eps, num, chn, sp); 165 | })); 166 | 167 | return x; 168 | }
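// Launch configuration used throughout this file: one thread block per channel
// (dim3 blocks(chn)), with that block's threads striding over the channel's
// flattened batch/spatial positions (getNumThreads(sp) threads). Keeping each
// channel inside a single block lets the per-channel statistics be computed
// with block-level reductions only.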
169 | 170 | /*********** 171 | * edz_eydz 172 | ***********/ 173 | 174 | template <typename T> 175 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias, 176 | T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) { 177 | int plane = blockIdx.x; 178 | 179 | T _weight = affine ? abs(weight[plane]) + eps : 1.f; 180 | T _bias = affine ? bias[plane] : 0.f; 181 | 182 | Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, chn, sp); 183 | __syncthreads(); 184 | 185 | if (threadIdx.x == 0) { 186 | edz[plane] = res.v1; 187 | eydz[plane] = res.v2; 188 | } 189 | } 190 | 191 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias, 192 | bool affine, float eps) { 193 | CHECK_INPUT(z); 194 | CHECK_INPUT(dz); 195 | CHECK_INPUT(weight); 196 | CHECK_INPUT(bias); 197 | 198 | // Extract dimensions 199 | int64_t num, chn, sp; 200 | get_dims(z, num, chn, sp); 201 | 202 | auto edz = at::empty(z.type(), {chn}); 203 | auto eydz = at::empty(z.type(), {chn}); 204 | 205 | // Run kernel 206 | dim3 blocks(chn); 207 | dim3 threads(getNumThreads(sp)); 208 | AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] { 209 | edz_eydz_kernel<scalar_t><<<blocks, threads>>>( 210 | z.data<scalar_t>(), 211 | dz.data<scalar_t>(), 212 | weight.data<scalar_t>(), 213 | bias.data<scalar_t>(), 214 | edz.data<scalar_t>(), 215 | eydz.data<scalar_t>(), 216 | affine, eps, num, chn, sp); 217 | })); 218 | 219 | return {edz, eydz}; 220 | }
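// The backward pass is split into two phases: edz_eydz_kernel reduces the
// per-channel sums edz = sum(dz) and eydz = sum(y * dz), which the Python
// sync-BN wrapper can all-reduce across processes before backward_kernel
// below combines them into dx, dweight and dbias.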
221 | 222 | /*********** 223 | * backward 224 | ***********/ 225 | 226 | template <typename T> 227 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz, 228 | const T *eydz, T *dx, T *dweight, T *dbias, 229 | bool affine, float eps, int num, int chn, int sp) { 230 | int plane = blockIdx.x; 231 | 232 | T _weight = affine ? abs(weight[plane]) + eps : 1.f; 233 | T _bias = affine ? bias[plane] : 0.f; 234 | T _var = var[plane]; 235 | T _edz = edz[plane]; 236 | T _eydz = eydz[plane]; 237 | 238 | T _mul = _weight * rsqrt(_var + eps); 239 | T count = T(num * sp); 240 | 241 | for (int batch = 0; batch < num; ++batch) { 242 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 243 | T _dz = dz[(batch * chn + plane) * sp + n]; 244 | T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight; 245 | 246 | dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul; 247 | } 248 | } 249 | 250 | if (threadIdx.x == 0) { 251 | if (affine) { 252 | dweight[plane] = weight[plane] > 0 ? _eydz : -_eydz; 253 | dbias[plane] = _edz; 254 | } 255 | } 256 | } 257 | 258 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias, 259 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) { 260 | CHECK_INPUT(z); 261 | CHECK_INPUT(dz); 262 | CHECK_INPUT(var); 263 | CHECK_INPUT(weight); 264 | CHECK_INPUT(bias); 265 | CHECK_INPUT(edz); 266 | CHECK_INPUT(eydz); 267 | 268 | // Extract dimensions 269 | int64_t num, chn, sp; 270 | get_dims(z, num, chn, sp); 271 | 272 | auto dx = at::zeros_like(z); 273 | auto dweight = at::zeros_like(weight); 274 | auto dbias = at::zeros_like(bias); 275 | 276 | // Run kernel 277 | dim3 blocks(chn); 278 | dim3 threads(getNumThreads(sp)); 279 | AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] { 280 | backward_kernel<scalar_t><<<blocks, threads>>>( 281 | z.data<scalar_t>(), 282 | dz.data<scalar_t>(), 283 | var.data<scalar_t>(), 284 | weight.data<scalar_t>(), 285 | bias.data<scalar_t>(), 286 | edz.data<scalar_t>(), 287 | eydz.data<scalar_t>(), 288 | dx.data<scalar_t>(), 289 | dweight.data<scalar_t>(), 290 | dbias.data<scalar_t>(), 291 | affine, eps, num, chn, sp); 292 | })); 293 | 294 | return {dx, dweight, dbias}; 295 | } 296 | 297 | /************** 298 | * activations 299 | **************/ 300 | 301 | template <typename T> 302 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) { 303 | // Create thrust pointers 304 | thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z); 305 | thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz); 306 | 307 | thrust::transform_if(th_dz, th_dz + count, th_z, th_dz, 308 | [slope] __device__ (const T& dz) { return dz * slope; }, 309 | [] __device__ (const T& z) { return z < 0; }); 310 | thrust::transform_if(th_z, th_z + count, th_z, 311 | [slope] __device__ (const T& z) { return z / slope; }, 312 | [] __device__ (const T& z) { return z < 0; }); 313 | } 314 | 315 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) { 316 | CHECK_INPUT(z); 317 | CHECK_INPUT(dz); 318 | 319 | int64_t count = z.numel(); 320 | 321 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] { 322 | leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count); 323 | })); 324 | } 325 | 326 | template <typename T> 327 | inline void elu_backward_impl(T *z, T *dz, int64_t count) { 328 | // Create thrust pointers 329 | thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z); 330 | thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz); 331 | 332 | thrust::transform_if(th_dz, th_dz + count, th_z, th_z, th_dz, 333 | [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); }, 334 | [] __device__ (const T& z) { return z < 0; }); 335 | thrust::transform_if(th_z, th_z + count, th_z, 336 | [] __device__ (const T& z) { return log1p(z); }, 337 | [] __device__ (const T& z) { return z < 0; }); 338 | } 339 | 340 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) { 341 | CHECK_INPUT(z); 342 | CHECK_INPUT(dz); 343 | 344 | int64_t count = z.numel(); 345 | 346 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] { 347 | elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count); 348 | })); 349 | } 350 |
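The activation kernels above are the other half of the in-place trick: leaky ReLU and ELU are invertible on their negative range, so the backward pass can reconstruct the pre-activation input from the stored output while rescaling the incoming gradient. A minimal NumPy sketch of the leaky-ReLU case (illustrative, not repository code):

```
# Illustrative mirror of leaky_relu_backward_*: invert the activation in place
# and scale the gradient, so the pre-activation tensor never has to be stored.
import numpy as np

def leaky_relu_backward_inplace(z, dz, slope=0.01):
    neg = z < 0        # on the negative side the forward pass stored z = slope * x
    dz[neg] *= slope   # dL/dx = dL/dz * slope
    z[neg] /= slope    # recover x = z / slope, in place
    return z, dz
```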
-------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuzhuang13/anytime/d2b58de8b7f99c14550e2ae7a715ad68736f846d/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def _fast_hist(label_true, label_pred, n_class): 5 | mask = (label_true >= 0) & (label_true < n_class) 6 | hist = np.bincount( 7 | n_class * label_true[mask].astype(int) + label_pred[mask], 8 | minlength=n_class ** 2, 9 | ).reshape(n_class, n_class) 10 | return hist 11 | 12 | 13 | def scores(label_trues, label_preds, n_class): 14 | hist = np.zeros((n_class, n_class)) 15 | for lt, lp in zip(label_trues, label_preds): 16 | hist += _fast_hist(lt.flatten(), lp.flatten(), n_class) 17 | acc = np.diag(hist).sum() / hist.sum() 18 | acc_cls = np.diag(hist) / hist.sum(axis=1) 19 | acc_cls = np.nanmean(acc_cls) 20 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 21 | valid = hist.sum(axis=1) > 0 22 | mean_iu = np.nanmean(iu[valid]) 23 | freq = hist.sum(axis=1) / hist.sum() 24 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 25 | cls_iu = dict(zip(range(n_class), iu)) 26 | 27 | return { 28 | "Overall Acc": acc, 29 | "Mean Acc": acc_cls, 30 | "FreqW Acc": fwavacc, 31 | "Mean IoU": mean_iu, 32 | "Class IoU": cls_iu, 33 | } 34 | 35 | 36 | def batch_pix_accuracy(output, target): 37 | _, predict = torch.max(output, 1) 38 | 39 | predict = predict.cpu().numpy().astype('int64') + 1 40 | target = target.cpu().numpy().astype('int64') + 1 41 | 42 | pixel_labeled = np.sum(target > 0) 43 | pixel_correct = np.sum((predict == target)*(target > 0)) 44 | assert pixel_correct <= pixel_labeled, \ 45 | "Correct area should be smaller than Labeled" 46 | return pixel_correct, pixel_labeled 47 | 48 | 49 | def batch_intersection_union(output, target, nclass): 50 | _, predict = torch.max(output, 1) 51 | mini = 1 52 | maxi = nclass 53 | nbins = nclass 54 | predict = predict.cpu().numpy().astype('int64') + 1 55 | target = target.cpu().numpy().astype('int64') + 1 56 | 57 | predict = predict * (target > 0).astype(predict.dtype) 58 | intersection = predict * (predict == target) 59 | area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi)) 60 | area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi)) 61 | area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi)) 62 | area_union = area_pred + area_lab - area_inter 63 | assert (area_inter <= area_union).all(), \ 64 | "Intersection area should be smaller than Union area" 65 | return area_inter, area_union 66 | 67 | 68 | def pixel_accuracy(im_pred, im_lab): 69 | im_pred = np.asarray(im_pred) 70 | im_lab = np.asarray(im_lab) 71 | pixel_labeled = np.sum(im_lab > 0) 72 | pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0)) 73 | return pixel_correct, pixel_labeled 74 | 75 | 76 | def intersection_and_union(im_pred, im_lab, num_class): 77 | im_pred = np.asarray(im_pred) 78 | im_lab = np.asarray(im_lab) 79 | im_pred = im_pred * (im_lab > 0) 80 | intersection = im_pred * (im_pred == im_lab) 81 | area_inter, _ = np.histogram(intersection, bins=num_class-1, 82 | range=(1, num_class - 1)) 83 | area_pred, _ = np.histogram(im_pred, bins=num_class-1, 84 | range=(1, num_class - 1)) 85 | area_lab, _ = np.histogram(im_lab, bins=num_class-1, 86 | range=(1, num_class - 1)) 87 | area_union = area_pred + area_lab - area_inter 88 | return area_inter, area_union 89 | -------------------------------------------------------------------------------- /lib/utils/modelsummary.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import
division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | from collections import namedtuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | def get_model_summary(model, *input_tensors, item_length=26, verbose=False): 13 | summary = [] 14 | 15 | ModuleDetails = namedtuple( 16 | "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds", "num_param_counts"]) 17 | hooks = [] 18 | layer_instances = {} 19 | 20 | def add_hooks(module): 21 | 22 | def hook(module, input, output): 23 | class_name = str(module.__class__.__name__) 24 | 25 | instance_index = 1 26 | if class_name not in layer_instances: 27 | layer_instances[class_name] = instance_index 28 | else: 29 | instance_index = layer_instances[class_name] + 1 30 | layer_instances[class_name] = instance_index 31 | 32 | layer_name = class_name + "_" + str(instance_index) 33 | 34 | params = 0 35 | counts = 0 36 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \ 37 | class_name.find("Linear") != -1 or class_name.find('conv') != -1 or class_name.find('Temp') != -1: 38 | for param_ in module.parameters(): 39 | if param_.requires_grad: 40 | params += param_.view(-1).size(0) 41 | counts += 1 42 | 43 | flops = "Not Available" 44 | 45 | if (class_name.find("ConvTranspose2d") != -1) and hasattr(module, "weight"): 46 | flops = ( 47 | torch.prod( 48 | torch.LongTensor(list(module.weight.data.size()))) * 49 | torch.prod( 50 | torch.LongTensor(list(input[0].size())[2:]))).item() 51 | 52 | 53 | elif (class_name.find("Conv") != -1) and hasattr(module, "weight"): 54 | flops = ( 55 | torch.prod( 56 | torch.LongTensor(list(module.weight.data.size()))) * 57 | torch.prod( 58 | torch.LongTensor(list(output.size())[2:]))).item() 59 | 60 | 61 | elif class_name.find("conv")!= -1: 62 | flops = ( 63 | torch.prod( 64 | torch.LongTensor(list(module.weight.data.size()))) * 65 | torch.prod( 66 | torch.LongTensor(list(output.size())[2:]))).item() 67 | p = module.mask.mean().item() 68 | p2 = module.p 69 | flops = int(flops * p) 70 | if module.interpolate == 'rbf' or module.interpolate == 'pooling' or module.interpolate == 'conv': 71 | flops += int((1 - p) * p * output.size(1) * (module.r * module.r) * output.size(2) * output.size(3)) 72 | elif isinstance(module, nn.Linear): 73 | flops = (torch.prod(torch.LongTensor(list(output.size()))) \ 74 | * input[0].size(1)).item() 75 | 76 | if isinstance(input[0], list): 77 | input = input[0] 78 | if isinstance(output, list): 79 | output = output[0] 80 | 81 | summary.append( 82 | ModuleDetails( 83 | name=layer_name, 84 | input_size=list(input[0].size()), 85 | output_size=list(output.size()), 86 | num_parameters=params, 87 | multiply_adds=flops, 88 | num_param_counts=counts,) 89 | ) 90 | 91 | if not isinstance(module, nn.ModuleList) \ 92 | and not isinstance(module, nn.Sequential) \ 93 | and module != model: 94 | hooks.append(module.register_forward_hook(hook)) 95 | 96 | model.eval() 97 | model.apply(add_hooks) 98 | 99 | space_len = item_length 100 | 101 | model(*input_tensors) 102 | for hook in hooks: 103 | hook.remove() 104 | 105 | details = '' 106 | if verbose: 107 | details = "Model Summary" + \ 108 | os.linesep + \ 109 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format( 110 | ' ' * (space_len - len("Name")), 111 | ' ' * (space_len - len("Input Size")), 112 | ' ' * (space_len - len("Output Size")), 113 | ' ' * (space_len - len("Parameters")), 114 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \ 115 | + os.linesep 
+ '-' * space_len * 5 + os.linesep 116 | 117 | params_sum = 0 118 | flops_sum = 0 119 | counts_sum = 0 120 | for layer in summary: 121 | params_sum += layer.num_parameters 122 | if layer.multiply_adds != "Not Available": 123 | flops_sum += layer.multiply_adds 124 | counts_sum += layer.num_param_counts 125 | shown_flops = layer.multiply_adds/(1024**3) if layer.multiply_adds != 'Not Available' else 0 126 | if verbose: 127 | if shown_flops < 1: 128 | details += '' 129 | else: 130 | details += "{}{}{}{}{}{}{}{}{}{}".format( 131 | layer.name, 132 | ' ' * (space_len - len(layer.name)), 133 | layer.input_size, 134 | ' ' * (space_len - len(str(layer.input_size))), 135 | layer.output_size, 136 | ' ' * (space_len - len(str(layer.output_size))), 137 | layer.num_parameters, 138 | ' ' * (space_len - len(str(layer.num_parameters))), 139 | shown_flops, 140 | ' ' * (space_len - len(str(shown_flops)))) \ 141 | + os.linesep + '-' * space_len * 5 + os.linesep 142 | 143 | details += os.linesep \ 144 | + "Total Parameters: {:,}".format(params_sum) \ 145 | + os.linesep + '-' * space_len * 5 + os.linesep 146 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \ 147 | + os.linesep + '-' * space_len * 5 + os.linesep 148 | details += "Number of Layers" + os.linesep 149 | for layer in layer_instances: 150 | details += "{} : {} layers ".format(layer, layer_instances[layer]) 151 | 152 | return details, {'params': params_sum, 153 | 'flops': flops_sum, 154 | 'counts': counts_sum} 155 | -------------------------------------------------------------------------------- /lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | import time 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | import pdb, time 17 | 18 | def json_save(filename, json_obj): 19 | import json 20 | with open(filename, 'w') as f: 21 | json.dump(json_obj, f, indent=4) 22 | 23 | def json_read(filename): 24 | import json 25 | with open(filename, 'r') as f: 26 | data = json.load(f) 27 | 28 | return data 29 | 30 | 31 | class FullModel(nn.Module): 32 | def __init__(self, model, loss): 33 | super(FullModel, self).__init__() 34 | self.model = model 35 | self.loss = loss 36 | 37 | def forward(self, inputs, labels): 38 | outputs = self.model(inputs) 39 | loss = self.loss(outputs, labels) 40 | return torch.unsqueeze(loss,0), outputs 41 | 42 | 43 | class FullEEModel(nn.Module): 44 | def __init__(self, model, loss, config=None): 45 | super(FullEEModel, self).__init__() 46 | self.model = model 47 | self.loss = loss 48 | self.cfg = config 49 | 50 | def forward(self, inputs, labels): 51 | outputs = self.model(inputs) 52 | losses = [] 53 | for i, output in enumerate(outputs): 54 | losses.append(self.loss(outputs[i], labels)) 55 | 56 | return losses, outputs 57 | 58 | def get_world_size(): 59 | if not torch.distributed.is_initialized(): 60 | return 1 61 | return torch.distributed.get_world_size() 62 | 63 | def get_rank(): 64 | if not torch.distributed.is_initialized(): 65 | return 0 66 | return torch.distributed.get_rank() 67 | 68 | class AverageMeter(object): 69 | def __init__(self): 70 | self.initialized = False 71 | self.val = None 72 | self.avg = None 73 | self.sum = None 74 | self.count = None 75 | 
76 | def initialize(self, val, weight): 77 | self.val = val 78 | self.avg = val 79 | self.sum = val * weight 80 | self.count = weight 81 | self.initialized = True 82 | 83 | def update(self, val, weight=1): 84 | if not self.initialized: 85 | self.initialize(val, weight) 86 | else: 87 | self.add(val, weight) 88 | 89 | def add(self, val, weight): 90 | self.val = val 91 | self.sum += val * weight 92 | self.count += weight 93 | self.avg = self.sum / self.count 94 | 95 | def value(self): 96 | return self.val 97 | 98 | def average(self): 99 | return self.avg 100 | 101 | def create_logger(cfg, cfg_name, phase='train'): 102 | root_output_dir = Path(cfg.OUTPUT_DIR) 103 | os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) 104 | 105 | dataset = cfg.DATASET.DATASET 106 | model = cfg.MODEL.NAME 107 | cfg_name = os.path.basename(cfg_name).split('.')[0] 108 | final_output_dir = root_output_dir 109 | os.makedirs(final_output_dir, exist_ok=True) 110 | 111 | print('=> creating {}'.format(final_output_dir)) 112 | final_output_dir.mkdir(parents=True, exist_ok=True) 113 | 114 | time_str = time.strftime('%Y-%m-%d-%H-%M') 115 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase) 116 | final_log_file = final_output_dir / log_file 117 | head = '%(asctime)-15s %(message)s' 118 | logging.basicConfig(filename=str(final_log_file), 119 | format=head) 120 | logger = logging.getLogger() 121 | logger.setLevel(logging.INFO) 122 | console = logging.StreamHandler() 123 | logging.getLogger('').addHandler(console) 124 | tensorboard_log_dir = final_output_dir 125 | print('=> creating {}'.format(tensorboard_log_dir)) 126 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True) 127 | 128 | return logger, str(final_output_dir), str(tensorboard_log_dir) 129 | 130 | 131 | def get_confusion_matrix_gpu(label, pred, size, num_class, ignore=-1, device=0): 132 | output = pred.transpose(1,3).transpose(1,2) 133 | seg_pred = torch.max(output, dim=3)[1] 134 | seg_gt = label 135 | 136 | ignore_index = seg_gt != ignore 137 | 138 | seg_gt = seg_gt[ignore_index] 139 | 140 | seg_pred = seg_pred[ignore_index] 141 | if seg_gt.get_device() == -1: 142 | seg_gt = seg_gt.to(0) 143 | 144 | index = ((seg_gt * num_class).long()+ seg_pred) 145 | label_count = torch.bincount(index) 146 | confusion_matrix = np.zeros((num_class, num_class)) 147 | 148 | for i_label in range(num_class): 149 | for i_pred in range(num_class): 150 | cur_index = i_label * num_class + i_pred 151 | if cur_index < len(label_count): 152 | confusion_matrix[i_label, 153 | i_pred] = label_count[cur_index] 154 | return confusion_matrix 155 | 156 | def adjust_learning_rate(optimizer, base_lr, max_iters, 157 | cur_iters, power=0.9): 158 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power)) 159 | for i, param in enumerate(optimizer.param_groups): 160 | optimizer.param_groups[i]['lr'] = lr 161 | return lr -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cityscapesscripts 2 | numpy 3 | Pillow 4 | tensorboardX 5 | tqdm 6 | yacs -------------------------------------------------------------------------------- /tools/_init_paths.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path as osp 6 | import sys 7 | 8 | 9 | def add_path(path): 10 | if path not in sys.path: 11 | sys.path.insert(0, 
path) 12 | 13 | this_dir = osp.dirname(__file__) 14 | 15 | lib_path = osp.join(this_dir, '..', 'lib') 16 | add_path(lib_path) 17 | -------------------------------------------------------------------------------- /tools/test_ee.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pprint 4 | import shutil 5 | import sys 6 | 7 | import logging 8 | import time 9 | import timeit 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.backends.cudnn as cudnn 17 | 18 | import _init_paths 19 | import models 20 | import datasets 21 | from config import config 22 | from config import update_config 23 | from core.function import testval_ee, testval_ee_profiling, testval_ee_profiling_actual 24 | from utils.modelsummary import get_model_summary 25 | from utils.utils import create_logger, FullModel, FullEEModel, json_save 26 | 27 | import pdb 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description='Train segmentation network') 31 | 32 | parser.add_argument('--cfg', 33 | help='experiment configure file name', 34 | required=True, 35 | type=str) 36 | parser.add_argument('opts', 37 | help="Modify config options using the command-line", 38 | default=None, 39 | nargs=argparse.REMAINDER) 40 | 41 | args = parser.parse_args() 42 | 43 | update_config(config, args) 44 | 45 | return args 46 | 47 | def main(): 48 | args = parse_args() 49 | 50 | config.defrost() 51 | config.OUTPUT_DIR = args.cfg[:-len('config.yaml')] 52 | try: 53 | if config.TEST.SUB_DIR: 54 | config.OUTPUT_DIR = os.path.join(config.OUTPUT_DIR, config.TEST.SUB_DIR) 55 | except: 56 | pass 57 | config.freeze() 58 | 59 | logger, final_output_dir, _ = create_logger( 60 | config, args.cfg, 'test') 61 | 62 | logger.info(pprint.pformat(args)) 63 | logger.info(pprint.pformat(config)) 64 | 65 | cudnn.benchmark = config.CUDNN.BENCHMARK 66 | cudnn.deterministic = config.CUDNN.DETERMINISTIC 67 | cudnn.enabled = config.CUDNN.ENABLED 68 | 69 | model = eval('models.'+config.MODEL.NAME + 70 | '.get_seg_model')(config) 71 | 72 | device = 0 73 | model.eval() 74 | 75 | dump_input = torch.rand( 76 | (1, 3, config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0]) 77 | ) 78 | 79 | if config.PYRAMID_TEST.USE: 80 | dump_input = torch.rand( 81 | (1, 3, config.PYRAMID_TEST.SIZE, config.PYRAMID_TEST.SIZE // 2) 82 | ) 83 | dump_output = model.to(device)(dump_input.to(device)) 84 | del dump_output 85 | dump_output = model.to(device)(dump_input.to(device)) 86 | 87 | if not (config.MASK.USE and (config.MASK.CRIT == 'conf_thre' or config.MASK.CRIT == 'entropy_thre')): 88 | stats = {} 89 | saved_stats = {} 90 | for i in range(4): 91 | setattr(model, f"stop{i+1}", "anY_RanDOM_ThiNg") 92 | summary, stats[i+1] = get_model_summary(model.to(device), dump_input.to(device), verbose=True) 93 | delattr(model, f"stop{i+1}") 94 | 95 | logger.info(f'\n\n>>>>>>>>>>>>>>>>>>>>>>> EXIT {i+1} >>>>>>>>>>>>>>>>>>>>>>>>>> ') 96 | logger.info(summary) 97 | 98 | saved_stats['params'] = [stats[i+1]['params'] for i in range(4)] 99 | saved_stats['flops'] = [stats[i+1]['flops'] for i in range(4)] 100 | saved_stats['counts'] = [stats[i+1]['counts'] for i in range(4)] 101 | saved_stats['Gflops'] = [f/(1024**3) for f in saved_stats['flops']] 102 | saved_stats['Gflops_mean'] = np.mean(saved_stats['Gflops']) 103 | saved_stats['Mparams'] = [f/(10**6) for f in saved_stats['params']] 104 | json_save(os.path.join(final_output_dir, 'test_stats.json'), 
saved_stats) 105 | 106 | if config.TEST.MODEL_FILE: 107 | model_state_file = config.TEST.MODEL_FILE 108 | else: 109 | model_state_file = os.path.join(final_output_dir, 110 | 'final_state.pth') 111 | 112 | try: 113 | if config.TEST.SUB_DIR: 114 | model_state_file = args.cfg[:-len('config.yaml')] + 'final_state.pth' 115 | except: 116 | pass 117 | 118 | logger.info('=> loading model from {}'.format(model_state_file)) 119 | 120 | pretrained_dict = torch.load(model_state_file) 121 | model_dict = model.state_dict() 122 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() 123 | if k[6:] in model_dict.keys()} 124 | model_dict.update(pretrained_dict) 125 | model.load_state_dict(model_dict) 126 | 127 | gpus = [0] 128 | 129 | model = nn.DataParallel(model, device_ids=gpus).cuda() 130 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0]) 131 | test_dataset = eval('datasets.'+config.DATASET.DATASET)( 132 | root=config.DATASET.ROOT, 133 | list_path=config.DATASET.TEST_SET, 134 | num_samples=None, 135 | num_classes=config.DATASET.NUM_CLASSES, 136 | multi_scale=False, 137 | flip=False, 138 | ignore_label=config.TRAIN.IGNORE_LABEL, 139 | base_size=config.TEST.BASE_SIZE, 140 | crop_size=test_size, 141 | downsample_rate=1) 142 | 143 | testloader = torch.utils.data.DataLoader( 144 | test_dataset, 145 | batch_size=1, 146 | shuffle=False, 147 | num_workers=config.WORKERS, 148 | pin_memory=True) 149 | 150 | start = timeit.default_timer() 151 | 152 | if 'val' in config.DATASET.TEST_SET: 153 | results = testval_ee(config, 154 | test_dataset, 155 | testloader, 156 | model, sv_dir=final_output_dir, sv_pred=True) 157 | 158 | if config.MASK.USE and config.MASK.CRIT == 'conf_thre': 159 | results_profiling = testval_ee_profiling(config, 160 | test_dataset, 161 | testloader, 162 | model, sv_dir=final_output_dir, sv_pred=True) 163 | json_save(os.path.join(final_output_dir, 'test_stats.json'), results_profiling) 164 | 165 | mean_IoUs = [] 166 | for i, result in enumerate(results): 167 | mean_IoU, IoU_array, pixel_acc, mean_acc = result 168 | 169 | msg = 'Exit: {}, MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, \ 170 | Mean_Acc: {: 4.4f}, Class IoU: '.format(i+1, mean_IoU, 171 | pixel_acc, mean_acc) 172 | logging.info(msg) 173 | logging.info(IoU_array) 174 | 175 | mean_IoUs.append(mean_IoU) 176 | 177 | 178 | mean_IoUs.append(np.mean(mean_IoUs)) 179 | print_result = '\t'.join(['{:.2f}'.format(m*100) for m in mean_IoUs]) 180 | result_file_name = f'{final_output_dir}/result.txt' 181 | 182 | with open(result_file_name, 'w') as f: 183 | f.write(print_result) 184 | 185 | end = timeit.default_timer() 186 | logger.info('Mins: %d' % np.int((end-start)/60)) 187 | logger.info('Done') 188 | logging.info(print_result) 189 | 190 | if __name__ == '__main__': 191 | main() 192 | -------------------------------------------------------------------------------- /tools/train_ee.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pprint 4 | import shutil 5 | import sys 6 | 7 | import logging 8 | import time 9 | import timeit 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.backends.cudnn as cudnn 17 | import torch.optim 18 | from torch.utils.data.distributed import DistributedSampler 19 | from tensorboardX import SummaryWriter 20 | 21 | import _init_paths 22 | import models 23 | import datasets 24 | from config import config 25 | from config import update_config 26 | from 
core.criterion import CrossEntropy 27 | from core.function import train_ee, validate_ee 28 | from utils.modelsummary import get_model_summary 29 | from utils.utils import create_logger, FullModel, FullEEModel, get_rank, json_save 30 | 31 | import pdb, time 32 | import subprocess 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser(description='Train segmentation network') 36 | 37 | parser.add_argument('--cfg', 38 | help='experiment configure file name', 39 | required=True, 40 | type=str) 41 | parser.add_argument("--local_rank", type=int, default=0) 42 | parser.add_argument('opts', 43 | help="Modify config options using the command-line", 44 | default=None, 45 | nargs=argparse.REMAINDER) 46 | 47 | args = parser.parse_args() 48 | update_config(config, args) 49 | 50 | return args 51 | 52 | def main(): 53 | args = parse_args() 54 | 55 | logger, final_output_dir, tb_log_dir = create_logger( 56 | config, args.cfg, 'train') 57 | 58 | if args.local_rank == 0: 59 | logger.info(config) 60 | 61 | writer_dict = { 62 | 'writer': SummaryWriter(tb_log_dir), 63 | 'train_global_steps': 0, 64 | 'valid_global_steps': 0, 65 | } 66 | 67 | cudnn.benchmark = config.CUDNN.BENCHMARK 68 | cudnn.deterministic = config.CUDNN.DETERMINISTIC 69 | cudnn.enabled = config.CUDNN.ENABLED 70 | 71 | gpus = list(config.GPUS) 72 | distributed = len(gpus) > 1 73 | device = torch.device('cuda:{}'.format(args.local_rank)) 74 | 75 | model = eval('models.'+config.MODEL.NAME + 76 | '.get_seg_model')(config) 77 | 78 | 79 | 80 | if args.local_rank == 0: 81 | with open(f"{final_output_dir}/config.yaml", "w") as f: 82 | f.write(config.dump()) 83 | 84 | this_dir = os.path.dirname(__file__) 85 | models_dst_dir = os.path.join(final_output_dir, 'code') 86 | if os.path.exists(models_dst_dir): 87 | shutil.rmtree(models_dst_dir) 88 | shutil.copytree(os.path.join(this_dir, '../lib'), os.path.join(models_dst_dir, 'lib')) 89 | shutil.copytree(os.path.join(this_dir, '../tools'), os.path.join(models_dst_dir, 'tools')) 90 | shutil.copytree(os.path.join(this_dir, '../scripts'), os.path.join(models_dst_dir, 'scripts')) 91 | shutil.copytree(os.path.join(this_dir, '../experiments'), os.path.join(models_dst_dir, 'experiments')) 92 | 93 | if True: 94 | model.eval() 95 | dump_input = torch.rand( 96 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]) 97 | ) 98 | dump_output = model.to(device)(dump_input.to(device)) 99 | 100 | dump_output = model.to(device)(dump_input.to(device)) 101 | 102 | stats = {} 103 | saved_stats = {} 104 | for i in range(4): 105 | setattr(model, f"stop{i+1}", "anY_RanDOM_ThiNg") 106 | summary, stats[i+1] = get_model_summary(model.to(device), dump_input.to(device), verbose=False) 107 | delattr(model, f"stop{i+1}") 108 | 109 | if args.local_rank == 0: 110 | logger.info(f'\n\n>>>>>>>>>>>>>>>>>>>>>>> EXIT {i+1} >>>>>>>>>>>>>>>>>>>>>>>>>> ') 111 | logger.info(summary) 112 | 113 | saved_stats['params'] = [stats[i+1]['params'] for i in range(4)] 114 | saved_stats['flops'] = [stats[i+1]['flops'] for i in range(4)] 115 | saved_stats['counts'] = [stats[i+1]['counts'] for i in range(4)] 116 | saved_stats['Gflops'] = [f/(1024**3) for f in saved_stats['flops']] 117 | saved_stats['Mparams'] = [f/(10**6) for f in saved_stats['params']] 118 | 119 | json_save(os.path.join(final_output_dir, 'stats.json'), saved_stats) 120 | 121 | 122 | if distributed: 123 | torch.cuda.set_device(args.local_rank) 124 | torch.distributed.init_process_group( 125 | backend="nccl", init_method="env://", 126 | ) 127 | 128 | crop_size = 
(config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]) 129 | train_dataset = eval('datasets.'+config.DATASET.DATASET)( 130 | root=config.DATASET.ROOT, 131 | list_path=config.DATASET.TRAIN_SET, 132 | num_samples=None, 133 | num_classes=config.DATASET.NUM_CLASSES, 134 | multi_scale=config.TRAIN.MULTI_SCALE, 135 | flip=config.TRAIN.FLIP, 136 | ignore_label=config.TRAIN.IGNORE_LABEL, 137 | base_size=config.TRAIN.BASE_SIZE, 138 | crop_size=crop_size, 139 | downsample_rate=config.TRAIN.DOWNSAMPLERATE, 140 | scale_factor=config.TRAIN.SCALE_FACTOR) 141 | 142 | if distributed: 143 | train_sampler = DistributedSampler(train_dataset) 144 | else: 145 | train_sampler = None 146 | 147 | trainloader = torch.utils.data.DataLoader( 148 | train_dataset, 149 | batch_size=config.TRAIN.BATCH_SIZE_PER_GPU, 150 | shuffle=config.TRAIN.SHUFFLE and train_sampler is None, 151 | num_workers=config.WORKERS, 152 | pin_memory=True, 153 | drop_last=True, 154 | sampler=train_sampler) 155 | 156 | if config.DATASET.EXTRA_TRAIN_SET: 157 | extra_train_dataset = eval('datasets.'+config.DATASET.DATASET)( 158 | root=config.DATASET.ROOT, 159 | list_path=config.DATASET.EXTRA_TRAIN_SET, 160 | num_samples=None, 161 | num_classes=config.DATASET.NUM_CLASSES, 162 | multi_scale=config.TRAIN.MULTI_SCALE, 163 | flip=config.TRAIN.FLIP, 164 | ignore_label=config.TRAIN.IGNORE_LABEL, 165 | base_size=config.TRAIN.BASE_SIZE, 166 | crop_size=crop_size, 167 | downsample_rate=config.TRAIN.DOWNSAMPLERATE, 168 | scale_factor=config.TRAIN.SCALE_FACTOR) 169 | 170 | if distributed: 171 | extra_train_sampler = DistributedSampler(extra_train_dataset) 172 | else: 173 | extra_train_sampler = None 174 | 175 | extra_trainloader = torch.utils.data.DataLoader( 176 | extra_train_dataset, 177 | batch_size=config.TRAIN.BATCH_SIZE_PER_GPU, 178 | shuffle=config.TRAIN.SHUFFLE and extra_train_sampler is None, 179 | num_workers=config.WORKERS, 180 | pin_memory=True, 181 | drop_last=True, 182 | sampler=extra_train_sampler) 183 | 184 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0]) 185 | test_dataset = eval('datasets.'+config.DATASET.DATASET)( 186 | root=config.DATASET.ROOT, 187 | list_path=config.DATASET.TEST_SET, 188 | num_samples=config.TEST.NUM_SAMPLES, 189 | num_classes=config.DATASET.NUM_CLASSES, 190 | multi_scale=False, 191 | flip=False, 192 | ignore_label=config.TRAIN.IGNORE_LABEL, 193 | base_size=config.TEST.BASE_SIZE, 194 | crop_size=test_size, 195 | center_crop_test=config.TEST.CENTER_CROP_TEST, 196 | downsample_rate=1) 197 | 198 | if distributed: 199 | test_sampler = DistributedSampler(test_dataset) 200 | else: 201 | test_sampler = None 202 | 203 | testloader = torch.utils.data.DataLoader( 204 | test_dataset, 205 | batch_size=config.TEST.BATCH_SIZE_PER_GPU, 206 | shuffle=False, 207 | num_workers=config.WORKERS, 208 | pin_memory=True, 209 | sampler=test_sampler) 210 | 211 | criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL, 212 | weight=train_dataset.class_weights) 213 | 214 | model = FullEEModel(model, criterion, config=config) 215 | 216 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 217 | model = model.to(device) 218 | model = nn.parallel.DistributedDataParallel( 219 | model, device_ids=[args.local_rank], output_device=args.local_rank) 220 | 221 | if config.TRAIN.OPTIMIZER == 'sgd': 222 | if config.TRAIN.ALLE_ONLY: 223 | param = [ 224 | {'params': model.module.model.exit1.parameters(), 'lr': config.TRAIN.EXTRA_LR}, 225 | {'params': model.module.model.exit2.parameters(), 'lr': config.TRAIN.EXTRA_LR}, 226 | 
{'params': model.module.model.exit3.parameters(), 'lr': config.TRAIN.EXTRA_LR}, 227 | {'params': model.module.model.last_layer.parameters(), 'lr': config.TRAIN.EXTRA_LR}, 228 | ] 229 | elif config.TRAIN.EE_ONLY: 230 | param = [ 231 | {'params': model.module.model.exit1.parameters(), 'lr': config.TRAIN.EXTRA_LR}, 232 | {'params': model.module.model.exit2.parameters(), 'lr': config.TRAIN.EXTRA_LR}, 233 | {'params': model.module.model.exit3.parameters(), 'lr': config.TRAIN.EXTRA_LR} 234 | ] 235 | else: 236 | param = [ 237 | {'params': 238 | filter(lambda p: p.requires_grad, 239 | model.parameters()), 240 | 'lr': config.TRAIN.LR} 241 | ] 242 | 243 | optimizer = torch.optim.SGD(param, 244 | lr=config.TRAIN.LR, 245 | momentum=config.TRAIN.MOMENTUM, 246 | weight_decay=config.TRAIN.WD, 247 | nesterov=config.TRAIN.NESTEROV, 248 | ) 249 | else: 250 | raise ValueError('Only Support SGD optimizer') 251 | 252 | epoch_iters = np.int(train_dataset.__len__() / 253 | config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus)) 254 | best_mIoU = 0 255 | last_epoch = 0 256 | if config.TRAIN.RESUME: 257 | if config.DATASET.EXTRA_TRAIN_SET: 258 | model_state_file = os.path.join(config.RESUME_DIR, 'checkpoint.pth.tar') 259 | assert os.path.isfile(model_state_file) 260 | load_optimizer_dict = False 261 | else: 262 | model_state_file = os.path.join(final_output_dir, 263 | 'checkpoint.pth.tar') 264 | 265 | 266 | if os.path.isfile(model_state_file): 267 | checkpoint = torch.load(model_state_file, 268 | map_location=lambda storage, loc: storage) 269 | best_mIoU = checkpoint['best_mIoU'] 270 | last_epoch = checkpoint['epoch'] 271 | model.module.load_state_dict(checkpoint['state_dict']) 272 | if not config.DATASET.EXTRA_TRAIN_SET: 273 | optimizer.load_state_dict(checkpoint['optimizer']) 274 | logger.info("=> loaded checkpoint (epoch {})" 275 | .format(checkpoint['epoch'])) 276 | 277 | 278 | start = timeit.default_timer() 279 | end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH 280 | num_iters = config.TRAIN.END_EPOCH * epoch_iters 281 | extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters 282 | 283 | logger.info('Starting training at rank {}'.format(args.local_rank)) 284 | for epoch in range(last_epoch, end_epoch): 285 | if distributed: 286 | train_sampler.set_epoch(epoch) 287 | if epoch >= config.TRAIN.END_EPOCH: 288 | train_ee(config, epoch-config.TRAIN.END_EPOCH, 289 | config.TRAIN.EXTRA_EPOCH, epoch_iters, 290 | config.TRAIN.EXTRA_LR, extra_iters, 291 | extra_trainloader, optimizer, model, 292 | writer_dict, device) 293 | else: 294 | train_ee(config, epoch, config.TRAIN.END_EPOCH, 295 | epoch_iters, config.TRAIN.LR, num_iters, 296 | trainloader, optimizer, model, writer_dict, 297 | device) 298 | 299 | if args.local_rank == 0: 300 | logger.info('=> saving checkpoint to {}'.format( 301 | final_output_dir + 'checkpoint.pth.tar')) 302 | torch.save({ 303 | 'epoch': epoch+1, 304 | 'best_mIoU': best_mIoU, 305 | 'state_dict': model.module.state_dict(), 306 | 'optimizer': optimizer.state_dict(), 307 | }, os.path.join(final_output_dir,'checkpoint.pth.tar')) 308 | 309 | 310 | torch.save(model.module.state_dict(), os.path.join(final_output_dir,'checkpoint.pth')) 311 | 312 | 313 | if epoch == end_epoch - 1: 314 | torch.save(model.module.state_dict(), 315 | os.path.join(final_output_dir, 'final_state.pth')) 316 | 317 | writer_dict['writer'].close() 318 | end = timeit.default_timer() 319 | logger.info('Hours: {}'.format((end-start)/3600)) 320 | logger.info('Done') 321 | 322 | 323 | pid = os.getpid() 324 | torch.cuda.empty_cache() 325 
| devices = os.environ['CUDA_VISIBLE_DEVICES'] 326 | device = devices.split(',')[1] 327 | command = f'CUDA_VISIBLE_DEVICES={device} python tools/test_ee.py --cfg {final_output_dir}/config.yaml' 328 | print(command) 329 | 330 | subprocess.run(command, shell=True) 331 | 332 | if __name__ == '__main__': 333 | main() 334 | 335 | 336 | 337 | --------------------------------------------------------------------------------
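As a closing reference for the evaluation path (`get_confusion_matrix_gpu` in `lib/utils/utils.py` feeding the per-exit mIoU that `tools/test_ee.py` logs and averages into `result.txt`), here is a minimal sketch of how mean IoU falls out of the accumulated confusion matrix; it is illustrative and not part of the repository:

```
# Illustrative: mean IoU from a confusion matrix accumulated the way
# get_confusion_matrix_gpu builds it (rows = ground truth, cols = prediction).
import numpy as np

def mean_iou(confusion_matrix):
    tp = np.diag(confusion_matrix)            # per-class true positives
    gt = confusion_matrix.sum(axis=1)         # per-class ground-truth pixels
    pred = confusion_matrix.sum(axis=0)       # per-class predicted pixels
    iou = tp / np.maximum(gt + pred - tp, 1)  # guard against absent classes
    return iou.mean()

# Each early exit accumulates its own matrix; test_ee.py reports one mIoU per
# exit plus their mean.
```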